1.optimizer構築
optimizer = torch.optim.SGD(list(encoder.parameters())+list(classifier.parameters()), lr=0.01)
2.計算グラフの構築
z = encoder(input_src_img)
digit_pred = classifier(z)
3.loss算出
loss = nn.CrossEntropyLoss()(digit_pred,digit_label)
4.誤差逆伝播
loss.backward()
損失に関して計算グラフを微分して各変数のgradに勾配を入れる。この時点ではパラメータは更新されていない。なお、パラメータはprint(x.grad)で確認可能。
注意点は計算グラフは揮発性のため、backword()を実行すると消えてしまう。lossを複数回に分けて計算したいならばbackward(retain_variables=True)とすること。
5.パラメータ更新
optimizer.step()
補足1.
パラメータ更新後はoptimizer.zero_grad()で勾配初期化しておく。そうしないとloss.backward()で計算された勾配が蓄積してしまう。呼び出すタイミングはloss.backward()の直前でもよい。
補足2.マルチタスク
次のようにすると効率的な計算が可能。ただしGANのような順番に学習することに意味があるタスクには不適切なので要注意。
(loss1 + loss2).backward()