Dlaczego musimy wywoływać zero_grad () w PyTorch?

Odpowiedzi:

165

W programie PyTorchmusimy ustawić gradienty na zero, zanim zaczniemy wykonywać propagację wsteczną, ponieważ PyTorch gromadzi gradienty przy kolejnych przebiegach wstecz. Jest to wygodne podczas szkolenia RNN. Zatem domyślną akcją jest gromadzenie (tj. Sumowanie) gradientów przy każdym loss.backward()wywołaniu.

Z tego powodu, rozpoczynając pętlę treningową, najlepiej byłoby, gdybyś zero out the gradientspoprawnie zaktualizował parametry. W przeciwnym razie gradient wskazywałby inny kierunek niż zamierzony kierunek w kierunku minimum (lub maksimum , w przypadku celów maksymalizacji).

Oto prosty przykład:

import torch
from torch.autograd import Variable
import torch.optim as optim

def linear_model(x, W, b):
    return torch.matmul(x, W) + b

data, targets = ...

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

optimizer = optim.Adam([W, b])

for sample, target in zip(data, targets):
    # clear out the gradients of all Variables 
    # in this optimizer (i.e. W, b)
    optimizer.zero_grad()
    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()
    optimizer.step()

Alternatywnie, jeśli robisz zejście w gradiencie waniliowym , to:

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

for sample, target in zip(data, targets):
    # clear out the gradients of Variables 
    # (i.e. W, b)
    W.grad.data.zero_()
    b.grad.data.zero_()

    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()

    W -= learning_rate * W.grad.data
    b -= learning_rate * b.grad.data

Uwaga : Akumulacja (tj. Suma ) gradientów ma miejsce, gdy .backward()zostanie wywołana przez losstensor .

kmario23
źródło
3
bardzo dziękuję, to jest naprawdę pomocne! Czy wiesz, czy tensorflow ma takie zachowanie?
layser
Dla pewności… jeśli tego nie zrobisz, napotkasz problem z eksplodującym gradientem, prawda?
zwep
3
@zwep Jeśli gromadzimy gradienty, nie oznacza to, że ich wielkość wzrasta: przykładem może być sytuacja, w której znak gradientu będzie się zmieniał. Więc to nie gwarantuje, że napotkasz problem z eksplodującym gradientem. Poza tym wybuchające gradienty istnieją, nawet jeśli wyzerujesz poprawnie.
Tom Roth,
Czy po uruchomieniu gradientu waniliowego nie pojawia się błąd „Zmienna typu liść, która wymaga użycia gradientu w operacji w miejscu”, gdy próbujesz zaktualizować wagi?
MUAS
3

zero_grad() uruchamia ponownie pętlę bez strat od ostatniego kroku, jeśli używasz metody gradientu w celu zmniejszenia błędu (lub strat).

Jeśli nie używasz, zero_grad()strata zmniejszy się, a nie zwiększy się zgodnie z wymaganiami.

Na przykład:

Jeśli używasz zero_grad(), otrzymasz następujące dane wyjściowe:

model training loss is 1.5
model training loss is 1.4
model training loss is 1.3
model training loss is 1.2

Jeśli nie używasz zero_grad(), otrzymasz następujące dane wyjściowe:

model training loss is 1.4
model training loss is 1.9
model training loss is 2
model training loss is 2.8
model training loss is 3.5
Youssri Abo Elseod
źródło