Jakie jest zastosowanie torch.no_grad w pytorch?

21

Jestem nowy w pytorch i zacząłem z tym kodem github. Nie rozumiem komentarza w wierszu 60–61 w kodzie "because weights have requires_grad=True, but we don't need to track this in autograd". Zrozumiałem, że wspominamy requires_grad=Trueo zmiennych, które musimy obliczyć gradienty przy użyciu autograd, ale co to znaczy być "tracked by autograd"?

flyingDope
źródło

Odpowiedzi:

24

Opakowanie „z torch.no_grad ()” tymczasowo ustawia wszystkie flagi wymaga_grada na wartość false. Przykład z oficjalnego samouczka PyTorch ( https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients ):

x = torch.randn(3, requires_grad=True)
print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)

Na zewnątrz:

True
True
False

Polecam przeczytanie wszystkich samouczków ze strony powyżej.

W twoim przykładzie: Myślę, że autor nie chce, aby PyTorch obliczał gradienty nowych zdefiniowanych zmiennych w1 i w2, ponieważ po prostu chce zaktualizować ich wartości.

Adrien D.
źródło
6
with torch.no_grad()

sprawi, że wszystkie operacje w bloku nie będą miały gradientów.

W pytorch nie można dokonywać zmian zastępczych w1 i w2, które są dwiema zmiennymi require_grad = True. Myślę, że unikanie zmiany położenia w1 i w2 wynika z tego, że spowoduje to błąd w obliczeniach propagacji wstecznej. Ponieważ zmiana zastępcza całkowicie zmieni w1 i w2.

Jeśli jednak użyjesz tego no_grad(), możesz kontrolować, że nowy w1 i nowy w2 nie mają gradientów, ponieważ są one generowane przez operacje, co oznacza, że ​​zmieniasz tylko wartość w1 i w2, a nie część gradientową, nadal mają one wcześniej zdefiniowane informacje o zmiennej gradientowej i wsteczna propagacja może być kontynuowana.

Jianing Lu
źródło