

Gradient Approximation for Neural Networks via Directional Derivatives

2022-02-20 (Sun) - Posted by 山本 in Tech Blog    tag:Python tag:Machine Learning tag:Paper Summary

This article explains and implements (Silver, et al., "Learning by Directional Gradient Descent." ICLR. 2021) and (Baydin, et al., "Gradients without Backpropagation", arXiv, 2022).

Both papers propose methods that train neural networks without backpropagation by approximating the gradient with perturbations and directional derivatives (directional gradients). You could also call this a kind of gradient-free optimization. In the following, we refer to the method proposed by Silver et al. as DODGE (Deep Online Directional Gradient Estimate) and the method proposed by Baydin et al. as FGD (Forward gradient descent).

An advantage of these methods is that computing the approximate gradient requires only forward passes of the neural network, so the parameters can be updated in parallel. Baydin et al. also touch on learning rules of the brain (neural circuits); if we were to cast this into a biological model, I think it could be interpreted as follows:

At each synapse, the direction of a random, tiny change in synaptic strength (e.g., a change in spine head size) is "remembered"; the value of the directional derivative of the loss is then fed back to every synapse as a global factor, multiplied by the "remembered" tiny change in synaptic strength, and used to change the synaptic strength once more, this time by a larger amount.

However, this is only an interpretation. Whether such a mechanism is actually feasible is beyond the scope of this article.

Preliminaries 1: Learning with perturbations

This is a bit of a digression, but let us first introduce a simple learning method that uses perturbations and no gradients at all. Let the parameters of the neural network be $\theta \in \mathbb{R}^p$, a data sample be $x$, and the loss function be $L(\theta, x)$. Let $v \in \mathbb{R}^p$ be a perturbation applied to the parameters. The simple learning rule is: "add a perturbation to the parameters, and if the loss decreases, accept the perturbed parameters."

$$
\Delta L = L(\theta + v, x) - L(\theta, x), \qquad \text{if } \Delta L < 0:\ \theta \leftarrow \theta + v
$$

Learning does make progress, but it is not efficient. This rule is often used as a baseline in research on learning without backpropagation.
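
As an illustration, here is a minimal sketch of this perturbation-based rule on a toy quadratic loss (the loss function, perturbation scale, and number of steps are arbitrary choices for illustration; they do not come from either paper):

import torch

def toy_loss(theta):
    return ((theta - 1.0)**2).sum()  # arbitrary quadratic loss for illustration

theta = torch.zeros(2)  # parameters
sigma = 0.1             # scale of the perturbation
for step in range(1000):
    v = sigma * torch.randn(2)                    # random perturbation
    delta_L = toy_loss(theta + v) - toy_loss(theta)
    if delta_L < 0:                               # keep the perturbation only if the loss decreased
        theta = theta + v
print(theta)  # should end up close to [1., 1.]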

Preliminaries 2: Directional derivatives and computing the Jacobian-vector product

The learning rules introduced in this article use the directional derivative (directional gradient). For a function $f$, the directional derivative at a point $u$ in the direction $v$ is defined as

$$
\nabla_v f(u) = \lim_{h \to 0} \frac{f(u + hv) - f(u)}{h}
$$

If $f$ is differentiable at $u$, then

$$
\nabla_v f(u) = \partial f(u)\, v
$$

holds, where the right-hand side is called the Jacobian-vector product (JVP). To compute JVPs, Silver et al. use jax.jvp, which evaluates them with forward-mode automatic differentiation. Baydin et al. implemented their method in PyTorch and apparently wrote the automatic differentiation part themselves. PyTorch also provides torch.autograd.functional.jvp, but it relies on the "double backwards trick," which calls backward twice and therefore still requires gradients.
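
For reference, here is a minimal sketch of computing a JVP with forward-mode AD via jax.jvp (this assumes jax is installed and uses an arbitrary toy function; the rest of this article stays with PyTorch):

import jax
import jax.numpy as jnp

def g(u):
    return jnp.sum(jnp.cos(u))  # arbitrary scalar function for illustration

u = jnp.array([1.5, -0.1])
v = jnp.ones(2)  # direction
g_u, jvp = jax.jvp(g, (u,), (v,))  # value and directional derivative from a single forward pass
print(g_u, jvp)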

As a compromise, we approximate the Jacobian-vector product with a finite difference ($\epsilon$ is a small value):

$$
\partial f(u)\, v \approx \frac{f(u + \epsilon v) - f(u)}{\epsilon}
$$

Note that when $f(u) \in \mathbb{R}$, we have $\partial f(u)\, v \in \mathbb{R}$. Let's check that the finite-difference approximation works on a simple function (the Beale function defined below). First, we import all the libraries used in this article.

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import numpy as np
import matplotlib.pyplot as plt

We define the function. Here we use the Beale function, given by

$$
f(x, y) = (1.5 - x + xy)^2 + (2.25 - x + xy^2)^2 + (2.625 - x + xy^3)^2
$$
In [2]:
def func(x):
    return (1.5 - x[0]+x[0]*x[1])**2+(2.25 - x[0]+x[0]*x[1]**2)**2+ (2.625 - x[0]+x[0]*x[1]**3)**2

We compute the directional derivative in two ways.

In [3]:
u, v = torch.tensor([1.5, -0.1], requires_grad=True), torch.ones(2)
# torch func.
func_output, jvp = torch.autograd.functional.jvp(func, u, v)

# finite difference
eps = 1e-3
f_v, f = func(u + eps*v), func(u)
jvp_fd = (f_v - f) / eps
print("torch.autograd.functional.jvp: ", jvp) 
print("finite diff.:", jvp_fd)
torch.autograd.functional.jvp:  tensor(-4.2418)
finite diff.: tensor(-4.2384, grad_fn=<DivBackward0>)

The automatic differentiation result and the finite-difference result roughly agree. We will use the finite difference from here on, but note that it requires two forward passes when applied to the learning rules introduced below, so it is not efficient.

Gradient approximation with perturbations and directional derivatives

Now for the main topic. When training with backpropagation and stochastic gradient descent (SGD), we compute $\nabla L(\theta, x) = \frac{\partial L(\theta, x)}{\partial \theta}$ with backpropagation, and SGD updates the parameters as

$$
\theta \leftarrow \theta - \eta \nabla L(\theta, x)
$$

where $\eta$ is the learning rate. FGD and DODGE, on the other hand, update the parameters as follows:

$$
\begin{aligned}
&\text{if DODGE:}\quad v \sim \{-1, 1\}^p\\
&\text{else if FGD:}\quad v \sim \mathcal{N}(0, I)\\
&g(\theta, x) = \left(\nabla L(\theta, x) \cdot v\right) v\\
&\theta \leftarrow \theta - \eta\, g(\theta, x)
\end{aligned}
$$

The two methods differ only in the distribution from which the perturbation is sampled. Note that computing $\nabla L(\theta, x) \cdot v$ (a JVP) does not require computing $\nabla L(\theta, x)$ itself. The most important point is that $g(\theta, x)$ is an unbiased estimator of $\nabla L(\theta, x)$. The proofs are given in the respective papers, so please refer to them.
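
To make the connection to the implementation later more concrete, here is a minimal sketch of a single FGD-style update step on the Beale function defined above, using the finite-difference JVP (eps, lr, and the starting point are arbitrary illustrative values):

theta = torch.tensor([1.5, -0.1])  # parameters
eps, lr = 1e-3, 1e-2

v = torch.randn(2)  # perturbation (FGD); for DODGE use 2*(torch.rand(2) > 0.5).float() - 1
jvp = (func(theta + eps*v) - func(theta)) / eps  # directional derivative via finite difference
g = jvp * v                                      # forward gradient: estimate of the true gradient
theta = theta - lr * g                           # SGD-style parameter update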

Below we confirm numerically that this procedure approximates the gradient. Continuing from the earlier code, we first compute the true gradient of the Beale function at $u$.

In [4]:
grad = torch.autograd.grad(f, u)[0].numpy()
print("True grad: ", grad)
True grad:  [-3.433947   -0.80788505]

We now write a function that estimates the gradient.

In [5]:
def grad_estimation(f, u, n, mode="dodge"):
    if mode == "dodge":
        v = 2*(torch.rand(n, 2) > 0.5) - 1   
    elif mode == "fgd":
        v = torch.randn(n, 2)
    else:
        assert False, "mode is dodge or fgd"
    estimate_grad = 0
    for i in range(n):
        _, jvp = torch.autograd.functional.jvp(f, u, v[i])
        estimate_grad += jvp*v[i]
    estimate_grad /= n
    return estimate_grad
In [6]:
num_directions = np.array([1, 10, 100, 1000])
estimated_grad_dodge, estimated_grad_fgd  = np.zeros((len(num_directions), 2)), np.zeros((len(num_directions), 2))
for i, n in enumerate(num_directions):
    estimated_grad_dodge[i] = grad_estimation(func, u, n, mode="dodge")
    estimated_grad_fgd[i] = grad_estimation(func, u, n, mode="fgd")

Let's plot the results. For both DODGE and FGD, the estimated gradient approaches the true gradient as the number of perturbations (num_directions) increases. If you run this several times, you will see that DODGE estimates more accurately (this should be provable mathematically; presumably because sign perturbations satisfy $v_i^2 = 1$ exactly, which gives the estimator a lower variance than Gaussian perturbations).

In [7]:
scale = 0.1
px, py = u.detach().numpy()
Xb = np.mgrid[-2:2:1000j, -2:2:1000j]
Zb = func(Xb)

titles = [r"$x$ grad.", r"$y$ grad."]; idx_dir = [0, -1]
plt.figure(figsize=(6, 6), dpi=100)
plt.suptitle("Gradients of Beale function")
for i in range(2):
    plt.subplot(2,2,i+1)
    plt.title(titles[i])
    plt.semilogx(num_directions, estimated_grad_dodge[:, i], "o-", label="Estimated w/ DODGE")
    plt.semilogx(num_directions, estimated_grad_fgd[:, i], "o-", label="Estimated w/ FGD")
    plt.axhline(grad[i], linestyle="--", color="tab:red", label="Actual")
    plt.xlabel("Num. directions"); 
    if i == 0:
        plt.legend()

    plt.subplot(2,2,i+3)
    plt.title("Num. directions:"+str(num_directions[idx_dir[i]]))
    plt.contour(Xb[0], Xb[1], Zb, levels=50)
    plt.scatter(px, py, color="k")
    plt.arrow(px, py, estimated_grad_dodge[idx_dir[i], 0]*scale,  estimated_grad_dodge[idx_dir[i], 1]*scale, head_width=0.1, color='tab:blue', label="Estimated w/ DODGE")
    plt.arrow(px, py, estimated_grad_fgd[idx_dir[i], 0]*scale,  estimated_grad_fgd[idx_dir[i], 1]*scale, head_width=0.1, color='tab:orange', label="Estimated w/ FGD")
    plt.arrow(px, py, grad[0]*scale, grad[1]*scale, head_width=0.1, color='tab:red', label="Actual")
    plt.xlabel(r"$x$"); plt.ylabel(r"$y$") 
    if i == 0:
        plt.legend()
plt.tight_layout()

Implementation in PyTorch

From here we move on to the implementation. We train on the MNIST dataset with backprop, DODGE, and FGD. This code is based on the PyTorch tutorial.

In [8]:
training_data = datasets.MNIST(root="data", train=True, download=True, transform=ToTensor())
test_data = datasets.MNIST(root="data", train=False, download=True, transform=ToTensor())

batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
In [9]:
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
Using cuda device

Training with backpropagation

For comparison, we first train the model with backpropagation, which provides exact gradients. Training runs for 15 epochs with a batch size of 64.

In [10]:
loss_fn = nn.CrossEntropyLoss()
In [11]:
def train(dataloader, model, loss_fn, optimizer):
    loss_list = []
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            loss_list.append(loss)
    return np.array(loss_list)
In [12]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss
In [13]:
model_bp = NeuralNetwork().to(device)
optimizer_bp = torch.optim.SGD(model_bp.parameters(), lr=1e-2)
train_loss_bp, test_loss_bp = [], []
epochs = 15
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loss = train(train_dataloader, model_bp, loss_fn, optimizer_bp)
    test_loss = test(test_dataloader, model_bp, loss_fn)
    train_loss_bp.append(train_loss)
    test_loss_bp.append(test_loss)
train_loss_bp = np.concatenate(train_loss_bp)
test_loss_bp = np.array(test_loss_bp)
Epoch 1
-------------------------------
loss: 2.299546  [    0/60000]
loss: 2.243087  [ 6400/60000]
loss: 2.193906  [12800/60000]
loss: 1.990515  [19200/60000]
loss: 1.825056  [25600/60000]
loss: 1.497010  [32000/60000]
loss: 1.088965  [38400/60000]
loss: 1.071147  [44800/60000]
loss: 0.839060  [51200/60000]
loss: 0.702301  [57600/60000]
Test Error: 
 Accuracy: 83.7%, Avg loss: 0.645017 

Epoch 2
-------------------------------
loss: 0.738222  [    0/60000]
loss: 0.536721  [ 6400/60000]
loss: 0.543878  [12800/60000]
loss: 0.521499  [19200/60000]
loss: 0.463977  [25600/60000]
loss: 0.448645  [32000/60000]
loss: 0.326735  [38400/60000]
loss: 0.522257  [44800/60000]
loss: 0.468751  [51200/60000]
loss: 0.475883  [57600/60000]
Test Error: 
 Accuracy: 88.9%, Avg loss: 0.391872 

Epoch 3
-------------------------------
loss: 0.428448  [    0/60000]
loss: 0.321109  [ 6400/60000]
loss: 0.325715  [12800/60000]
loss: 0.404429  [19200/60000]
loss: 0.327921  [25600/60000]
loss: 0.372084  [32000/60000]
loss: 0.230002  [38400/60000]
loss: 0.442992  [44800/60000]
loss: 0.386030  [51200/60000]
loss: 0.438659  [57600/60000]
Test Error: 
 Accuracy: 90.4%, Avg loss: 0.332632 

Epoch 4
-------------------------------
loss: 0.320814  [    0/60000]
loss: 0.276162  [ 6400/60000]
loss: 0.247757  [12800/60000]
loss: 0.369032  [19200/60000]
loss: 0.275181  [25600/60000]
loss: 0.333127  [32000/60000]
loss: 0.196659  [38400/60000]
loss: 0.405020  [44800/60000]
loss: 0.338511  [51200/60000]
loss: 0.415403  [57600/60000]
Test Error: 
 Accuracy: 91.3%, Avg loss: 0.302255 

Epoch 5
-------------------------------
loss: 0.263066  [    0/60000]
loss: 0.256540  [ 6400/60000]
loss: 0.207443  [12800/60000]
loss: 0.350262  [19200/60000]
loss: 0.241405  [25600/60000]
loss: 0.307048  [32000/60000]
loss: 0.179711  [38400/60000]
loss: 0.379755  [44800/60000]
loss: 0.300815  [51200/60000]
loss: 0.394544  [57600/60000]
Test Error: 
 Accuracy: 92.0%, Avg loss: 0.280586 

Epoch 6
-------------------------------
loss: 0.223930  [    0/60000]
loss: 0.242454  [ 6400/60000]
loss: 0.182857  [12800/60000]
loss: 0.334236  [19200/60000]
loss: 0.217351  [25600/60000]
loss: 0.287452  [32000/60000]
loss: 0.167590  [38400/60000]
loss: 0.359716  [44800/60000]
loss: 0.269568  [51200/60000]
loss: 0.374653  [57600/60000]
Test Error: 
 Accuracy: 92.5%, Avg loss: 0.262961 

Epoch 7
-------------------------------
loss: 0.195676  [    0/60000]
loss: 0.230322  [ 6400/60000]
loss: 0.164844  [12800/60000]
loss: 0.319183  [19200/60000]
loss: 0.198686  [25600/60000]
loss: 0.272288  [32000/60000]
loss: 0.157346  [38400/60000]
loss: 0.342784  [44800/60000]
loss: 0.242133  [51200/60000]
loss: 0.355913  [57600/60000]
Test Error: 
 Accuracy: 92.8%, Avg loss: 0.247321 

Epoch 8
-------------------------------
loss: 0.174269  [    0/60000]
loss: 0.219391  [ 6400/60000]
loss: 0.150918  [12800/60000]
loss: 0.305854  [19200/60000]
loss: 0.182818  [25600/60000]
loss: 0.261800  [32000/60000]
loss: 0.147795  [38400/60000]
loss: 0.325437  [44800/60000]
loss: 0.218672  [51200/60000]
loss: 0.337907  [57600/60000]
Test Error: 
 Accuracy: 93.2%, Avg loss: 0.233048 

Epoch 9
-------------------------------
loss: 0.157361  [    0/60000]
loss: 0.209914  [ 6400/60000]
loss: 0.139500  [12800/60000]
loss: 0.292533  [19200/60000]
loss: 0.169732  [25600/60000]
loss: 0.253424  [32000/60000]
loss: 0.138459  [38400/60000]
loss: 0.308890  [44800/60000]
loss: 0.198611  [51200/60000]
loss: 0.321007  [57600/60000]
Test Error: 
 Accuracy: 93.6%, Avg loss: 0.219827 

Epoch 10
-------------------------------
loss: 0.143458  [    0/60000]
loss: 0.200823  [ 6400/60000]
loss: 0.130008  [12800/60000]
loss: 0.279433  [19200/60000]
loss: 0.159168  [25600/60000]
loss: 0.246027  [32000/60000]
loss: 0.129360  [38400/60000]
loss: 0.293391  [44800/60000]
loss: 0.181554  [51200/60000]
loss: 0.305260  [57600/60000]
Test Error: 
 Accuracy: 93.9%, Avg loss: 0.207685 

Epoch 11
-------------------------------
loss: 0.131745  [    0/60000]
loss: 0.192202  [ 6400/60000]
loss: 0.121887  [12800/60000]
loss: 0.265702  [19200/60000]
loss: 0.150041  [25600/60000]
loss: 0.239148  [32000/60000]
loss: 0.120741  [38400/60000]
loss: 0.279369  [44800/60000]
loss: 0.168906  [51200/60000]
loss: 0.290600  [57600/60000]
Test Error: 
 Accuracy: 94.1%, Avg loss: 0.196576 

Epoch 12
-------------------------------
loss: 0.121396  [    0/60000]
loss: 0.184836  [ 6400/60000]
loss: 0.114914  [12800/60000]
loss: 0.252048  [19200/60000]
loss: 0.142015  [25600/60000]
loss: 0.232030  [32000/60000]
loss: 0.113253  [38400/60000]
loss: 0.266646  [44800/60000]
loss: 0.160878  [51200/60000]
loss: 0.278118  [57600/60000]
Test Error: 
 Accuracy: 94.5%, Avg loss: 0.186402 

Epoch 13
-------------------------------
loss: 0.112123  [    0/60000]
loss: 0.178441  [ 6400/60000]
loss: 0.108902  [12800/60000]
loss: 0.239651  [19200/60000]
loss: 0.134386  [25600/60000]
loss: 0.224691  [32000/60000]
loss: 0.106507  [38400/60000]
loss: 0.254671  [44800/60000]
loss: 0.155029  [51200/60000]
loss: 0.267427  [57600/60000]
Test Error: 
 Accuracy: 94.7%, Avg loss: 0.177047 

Epoch 14
-------------------------------
loss: 0.104444  [    0/60000]
loss: 0.172944  [ 6400/60000]
loss: 0.103483  [12800/60000]
loss: 0.227711  [19200/60000]
loss: 0.127313  [25600/60000]
loss: 0.216363  [32000/60000]
loss: 0.100120  [38400/60000]
loss: 0.243856  [44800/60000]
loss: 0.150323  [51200/60000]
loss: 0.257463  [57600/60000]
Test Error: 
 Accuracy: 94.9%, Avg loss: 0.168542 

Epoch 15
-------------------------------
loss: 0.097420  [    0/60000]
loss: 0.167871  [ 6400/60000]
loss: 0.098468  [12800/60000]
loss: 0.216863  [19200/60000]
loss: 0.120679  [25600/60000]
loss: 0.208238  [32000/60000]
loss: 0.094022  [38400/60000]
loss: 0.233274  [44800/60000]
loss: 0.146389  [51200/60000]
loss: 0.248148  [57600/60000]
Test Error: 
 Accuracy: 95.2%, Avg loss: 0.160844 

Training with directional derivatives

We prepare two models with the same architecture (model and model_v) and update the parameters with the following procedure:

  1. Run a forward pass with model and compute loss.
  2. Prepare a dict grad_estimate to hold the estimated gradient.
  3. Generate a perturbation v. The parameter keys of model are obtained from model.state_dict() and v is stored under the same keys. At the same time, the parameters of model_v, which has the same architecture as model, are replaced with the parameters of model plus the perturbation v scaled by eps.
  4. Run a forward pass with model_v and compute loss_v.
  5. Compute the directional derivative from loss_v and loss with the finite difference.
  6. Add the gradient estimate to grad_estimate.
  7. Repeat steps 3-6 num_directions times.
  8. Average the values in grad_estimate over num_directions and apply gradient clipping with torch.clamp() (the estimate is numerically unstable otherwise).
  9. Zero the gradients of model's parameters with optimizer.zero_grad().
  10. Assign the estimated gradient to param.grad.
  11. Update the parameters with optimizer.step().

In the simulation in the previous section, the estimated gradient did not approach the true gradient unless num_directions was increased, but learning still proceeds even with num_directions=1. You can of course increase it, at the cost of more computation. Also, the learning rate lr is set to 0.001, smaller than the 0.01 used for backprop, because training diverged with lr=0.01.

In [14]:
def train_nograd(dataloader, model, model_v, loss_fn, optimizer, mode="dodge", num_directions=1, 
                 eps=1e-3, grad_clip=1):
    size = len(dataloader.dataset)
    loss_list = []
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        with torch.no_grad():
            X, y = X.to(device), y.to(device)
            pred = model(X) # prediction
            loss = loss_fn(pred, y) # loss
            grad_estimate = {} # Dict for estimated gradient
            for i in range(num_directions):
                v = {} # Dict for perturbation
                for param_v, key in zip(model_v.parameters(), model.state_dict()):
                    if mode == "dodge":
                        v[key] = 2*(torch.rand(model.state_dict()[key].shape).to(device) > 0.5) - 1 
                    elif mode == "fgd":
                        v[key] = torch.randn(model.state_dict()[key].shape).to(device)
                    else:
                        assert False, "mode is dodge or fgd"

                    param_v.data = model.state_dict()[key] + eps*v[key] # substitution with perturbed param.

                pred_v = model_v(X) # perturbed prediction
                loss_v = loss_fn(pred_v, y) # perturbed loss
                jvp = (loss_v - loss) / eps  # directional derivative of loss at point params in direction v
            
                # gradient estimation
                for key in model.state_dict():
                    if key in grad_estimate:
                        grad_estimate[key] += jvp * v[key]
                    else:
                        grad_estimate[key] = jvp * v[key]

            # averaging & gradient clipping for estimated gradient
            for key in model.state_dict():
                grad_estimate[key] = torch.clamp(grad_estimate[key]/num_directions, -grad_clip, grad_clip)

            optimizer.zero_grad()
            # replacement of loss.backward()
            for param, key in zip(model.parameters(), model.state_dict()):
                param.grad = grad_estimate[key]        
            optimizer.step()
            
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            loss_list.append(loss)
    return np.array(loss_list)

DODGE

In [15]:
model_dodge, model_dodgev = NeuralNetwork().to(device), NeuralNetwork().to(device)
for param in model_dodge.parameters():
    param.requires_grad = False
for param in model_dodgev.parameters():
    param.requires_grad = False
#optimizer = torch.optim.Adam(model_dodge.parameters(), lr=1e-3)
optimizer_dodge = torch.optim.SGD(model_dodge.parameters(), lr=1e-3)

train_loss_dodge, test_loss_dodge = [], []
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loss = train_nograd(train_dataloader, model_dodge, model_dodgev, loss_fn, optimizer_dodge, mode="dodge")
    test_loss = test(test_dataloader, model_dodge, loss_fn)
    train_loss_dodge.append(train_loss)
    test_loss_dodge.append(test_loss)
train_loss_dodge = np.concatenate(train_loss_dodge)
test_loss_dodge = np.array(test_loss_dodge)
Epoch 1
-------------------------------
loss: 2.293013  [    0/60000]
loss: 2.278605  [ 6400/60000]
loss: 2.305957  [12800/60000]
loss: 2.261985  [19200/60000]
loss: 2.278600  [25600/60000]
loss: 2.274848  [32000/60000]
loss: 2.264979  [38400/60000]
loss: 2.264772  [44800/60000]
loss: 2.243997  [51200/60000]
loss: 2.230047  [57600/60000]
Test Error: 
 Accuracy: 27.0%, Avg loss: 2.238401 

Epoch 2
-------------------------------
loss: 2.238425  [    0/60000]
loss: 2.213653  [ 6400/60000]
loss: 2.250131  [12800/60000]
loss: 2.172260  [19200/60000]
loss: 2.199060  [25600/60000]
loss: 2.191102  [32000/60000]
loss: 2.158249  [38400/60000]
loss: 2.203896  [44800/60000]
loss: 2.136338  [51200/60000]
loss: 2.109188  [57600/60000]
Test Error: 
 Accuracy: 51.2%, Avg loss: 2.119787 

Epoch 3
-------------------------------
loss: 2.115482  [    0/60000]
loss: 2.104982  [ 6400/60000]
loss: 2.143744  [12800/60000]
loss: 2.017807  [19200/60000]
loss: 2.061916  [25600/60000]
loss: 2.045247  [32000/60000]
loss: 2.010787  [38400/60000]
loss: 2.053323  [44800/60000]
loss: 1.954429  [51200/60000]
loss: 1.921283  [57600/60000]
Test Error: 
 Accuracy: 60.1%, Avg loss: 1.933119 

Epoch 4
-------------------------------
loss: 1.946108  [    0/60000]
loss: 1.871284  [ 6400/60000]
loss: 1.959304  [12800/60000]
loss: 1.783304  [19200/60000]
loss: 1.854810  [25600/60000]
loss: 1.832346  [32000/60000]
loss: 1.746651  [38400/60000]
loss: 1.863106  [44800/60000]
loss: 1.712567  [51200/60000]
loss: 1.637663  [57600/60000]
Test Error: 
 Accuracy: 63.5%, Avg loss: 1.674136 

Epoch 5
-------------------------------
loss: 1.716559  [    0/60000]
loss: 1.619153  [ 6400/60000]
loss: 1.696653  [12800/60000]
loss: 1.460011  [19200/60000]
loss: 1.512218  [25600/60000]
loss: 1.541011  [32000/60000]
loss: 1.445283  [38400/60000]
loss: 1.622717  [44800/60000]
loss: 1.419001  [51200/60000]
loss: 1.323012  [57600/60000]
Test Error: 
 Accuracy: 68.3%, Avg loss: 1.363791 

Epoch 6
-------------------------------
loss: 1.388289  [    0/60000]
loss: 1.309341  [ 6400/60000]
loss: 1.442154  [12800/60000]
loss: 1.223235  [19200/60000]
loss: 1.275618  [25600/60000]
loss: 1.317449  [32000/60000]
loss: 1.236379  [38400/60000]
loss: 1.375583  [44800/60000]
loss: 1.237414  [51200/60000]
loss: 1.128840  [57600/60000]
Test Error: 
 Accuracy: 70.2%, Avg loss: 1.175376 

Epoch 7
-------------------------------
loss: 1.216448  [    0/60000]
loss: 1.130485  [ 6400/60000]
loss: 1.316219  [12800/60000]
loss: 1.005424  [19200/60000]
loss: 1.101283  [25600/60000]
loss: 1.078783  [32000/60000]
loss: 0.955078  [38400/60000]
loss: 1.135801  [44800/60000]
loss: 1.091042  [51200/60000]
loss: 1.044853  [57600/60000]
Test Error: 
 Accuracy: 73.0%, Avg loss: 1.016321 

Epoch 8
-------------------------------
loss: 1.043064  [    0/60000]
loss: 1.013704  [ 6400/60000]
loss: 1.106562  [12800/60000]
loss: 0.749295  [19200/60000]
loss: 0.878776  [25600/60000]
loss: 1.041634  [32000/60000]
loss: 0.925950  [38400/60000]
loss: 1.054278  [44800/60000]
loss: 1.059773  [51200/60000]
loss: 0.884302  [57600/60000]
Test Error: 
 Accuracy: 75.0%, Avg loss: 0.929963 

Epoch 9
-------------------------------
loss: 0.957271  [    0/60000]
loss: 0.913348  [ 6400/60000]
loss: 1.025013  [12800/60000]
loss: 0.706904  [19200/60000]
loss: 0.794365  [25600/60000]
loss: 0.886365  [32000/60000]
loss: 0.795565  [38400/60000]
loss: 1.042791  [44800/60000]
loss: 0.918600  [51200/60000]
loss: 0.846945  [57600/60000]
Test Error: 
 Accuracy: 76.0%, Avg loss: 0.840766 

Epoch 10
-------------------------------
loss: 0.865723  [    0/60000]
loss: 0.866054  [ 6400/60000]
loss: 0.886908  [12800/60000]
loss: 0.670420  [19200/60000]
loss: 0.691289  [25600/60000]
loss: 0.806679  [32000/60000]
loss: 0.669625  [38400/60000]
loss: 0.989060  [44800/60000]
loss: 0.883575  [51200/60000]
loss: 0.833742  [57600/60000]
Test Error: 
 Accuracy: 77.1%, Avg loss: 0.779454 

Epoch 11
-------------------------------
loss: 0.862591  [    0/60000]
loss: 0.878451  [ 6400/60000]
loss: 0.820222  [12800/60000]
loss: 0.643583  [19200/60000]
loss: 0.638248  [25600/60000]
loss: 0.738340  [32000/60000]
loss: 0.646759  [38400/60000]
loss: 0.894081  [44800/60000]
loss: 0.858162  [51200/60000]
loss: 0.829792  [57600/60000]
Test Error: 
 Accuracy: 76.5%, Avg loss: 0.770023 

Epoch 12
-------------------------------
loss: 0.835574  [    0/60000]
loss: 0.877235  [ 6400/60000]
loss: 0.857782  [12800/60000]
loss: 0.690486  [19200/60000]
loss: 0.617188  [25600/60000]
loss: 0.711040  [32000/60000]
loss: 0.721151  [38400/60000]
loss: 0.966661  [44800/60000]
loss: 0.844607  [51200/60000]
loss: 0.791053  [57600/60000]
Test Error: 
 Accuracy: 76.8%, Avg loss: 0.756590 

Epoch 13
-------------------------------
loss: 0.799021  [    0/60000]
loss: 0.806324  [ 6400/60000]
loss: 0.743709  [12800/60000]
loss: 0.620406  [19200/60000]
loss: 0.595102  [25600/60000]
loss: 0.694497  [32000/60000]
loss: 0.610044  [38400/60000]
loss: 0.986340  [44800/60000]
loss: 0.834452  [51200/60000]
loss: 0.761270  [57600/60000]
Test Error: 
 Accuracy: 75.1%, Avg loss: 0.781446 

Epoch 14
-------------------------------
loss: 0.803079  [    0/60000]
loss: 0.924776  [ 6400/60000]
loss: 0.730276  [12800/60000]
loss: 0.619014  [19200/60000]
loss: 0.689474  [25600/60000]
loss: 0.743778  [32000/60000]
loss: 0.648899  [38400/60000]
loss: 0.996702  [44800/60000]
loss: 0.776125  [51200/60000]
loss: 0.803975  [57600/60000]
Test Error: 
 Accuracy: 75.8%, Avg loss: 0.755003 

Epoch 15
-------------------------------
loss: 0.811652  [    0/60000]
loss: 0.835664  [ 6400/60000]
loss: 0.662330  [12800/60000]
loss: 0.673836  [19200/60000]
loss: 0.564571  [25600/60000]
loss: 0.718223  [32000/60000]
loss: 0.634960  [38400/60000]
loss: 0.828920  [44800/60000]
loss: 0.767297  [51200/60000]
loss: 0.876423  [57600/60000]
Test Error: 
 Accuracy: 75.1%, Avg loss: 0.764193 

FGD

In [16]:
model_fgd, model_fgdv = NeuralNetwork().to(device), NeuralNetwork().to(device)
for param in model_fgd.parameters():
    param.requires_grad = False
for param in model_fgdv.parameters():
    param.requires_grad = False
optimizer_fgd = torch.optim.SGD(model_fgd.parameters(), lr=1e-3)

train_loss_fgd, test_loss_fgd = [], []
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loss = train_nograd(train_dataloader, model_fgd, model_fgdv, loss_fn, optimizer_fgd, mode="fgd")
    test_loss = test(test_dataloader, model_fgd, loss_fn)
    train_loss_fgd.append(train_loss)
    test_loss_fgd.append(test_loss)
train_loss_fgd = np.concatenate(train_loss_fgd)
test_loss_fgd = np.array(test_loss_fgd)
Epoch 1
-------------------------------
loss: 2.296826  [    0/60000]
loss: 2.283787  [ 6400/60000]
loss: 2.293741  [12800/60000]
loss: 2.277387  [19200/60000]
loss: 2.291229  [25600/60000]
loss: 2.284855  [32000/60000]
loss: 2.269791  [38400/60000]
loss: 2.282986  [44800/60000]
loss: 2.260365  [51200/60000]
loss: 2.257126  [57600/60000]
Test Error: 
 Accuracy: 28.9%, Avg loss: 2.261620 

Epoch 2
-------------------------------
loss: 2.260387  [    0/60000]
loss: 2.243361  [ 6400/60000]
loss: 2.258446  [12800/60000]
loss: 2.217143  [19200/60000]
loss: 2.243744  [25600/60000]
loss: 2.238254  [32000/60000]
loss: 2.216753  [38400/60000]
loss: 2.238729  [44800/60000]
loss: 2.210622  [51200/60000]
loss: 2.196457  [57600/60000]
Test Error: 
 Accuracy: 31.1%, Avg loss: 2.202523 

Epoch 3
-------------------------------
loss: 2.204142  [    0/60000]
loss: 2.184506  [ 6400/60000]
loss: 2.217151  [12800/60000]
loss: 2.137058  [19200/60000]
loss: 2.168483  [25600/60000]
loss: 2.161743  [32000/60000]
loss: 2.135093  [38400/60000]
loss: 2.179393  [44800/60000]
loss: 2.139675  [51200/60000]
loss: 2.107305  [57600/60000]
Test Error: 
 Accuracy: 47.7%, Avg loss: 2.113572 

Epoch 4
-------------------------------
loss: 2.093888  [    0/60000]
loss: 2.081323  [ 6400/60000]
loss: 2.133465  [12800/60000]
loss: 2.013201  [19200/60000]
loss: 2.053054  [25600/60000]
loss: 2.050395  [32000/60000]
loss: 2.011332  [38400/60000]
loss: 2.068287  [44800/60000]
loss: 2.015212  [51200/60000]
loss: 1.967663  [57600/60000]
Test Error: 
 Accuracy: 53.5%, Avg loss: 1.981888 

Epoch 5
-------------------------------
loss: 1.974730  [    0/60000]
loss: 1.947753  [ 6400/60000]
loss: 1.988854  [12800/60000]
loss: 1.846123  [19200/60000]
loss: 1.906788  [25600/60000]
loss: 1.896905  [32000/60000]
loss: 1.839852  [38400/60000]
loss: 1.927831  [44800/60000]
loss: 1.821989  [51200/60000]
loss: 1.782795  [57600/60000]
Test Error: 
 Accuracy: 65.5%, Avg loss: 1.789055 

Epoch 6
-------------------------------
loss: 1.783425  [    0/60000]
loss: 1.762594  [ 6400/60000]
loss: 1.817015  [12800/60000]
loss: 1.638735  [19200/60000]
loss: 1.747962  [25600/60000]
loss: 1.718608  [32000/60000]
loss: 1.656324  [38400/60000]
loss: 1.748703  [44800/60000]
loss: 1.658347  [51200/60000]
loss: 1.596796  [57600/60000]
Test Error: 
 Accuracy: 68.5%, Avg loss: 1.607141 

Epoch 7
-------------------------------
loss: 1.634365  [    0/60000]
loss: 1.542256  [ 6400/60000]
loss: 1.632949  [12800/60000]
loss: 1.462924  [19200/60000]
loss: 1.525259  [25600/60000]
loss: 1.459265  [32000/60000]
loss: 1.416846  [38400/60000]
loss: 1.571794  [44800/60000]
loss: 1.409128  [51200/60000]
loss: 1.359775  [57600/60000]
Test Error: 
 Accuracy: 70.4%, Avg loss: 1.389032 

Epoch 8
-------------------------------
loss: 1.412317  [    0/60000]
loss: 1.306484  [ 6400/60000]
loss: 1.405024  [12800/60000]
loss: 1.290675  [19200/60000]
loss: 1.304878  [25600/60000]
loss: 1.247403  [32000/60000]
loss: 1.253526  [38400/60000]
loss: 1.363694  [44800/60000]
loss: 1.201337  [51200/60000]
loss: 1.205393  [57600/60000]
Test Error: 
 Accuracy: 74.7%, Avg loss: 1.200503 

Epoch 9
-------------------------------
loss: 1.217961  [    0/60000]
loss: 1.156513  [ 6400/60000]
loss: 1.214847  [12800/60000]
loss: 1.072343  [19200/60000]
loss: 1.085704  [25600/60000]
loss: 1.088957  [32000/60000]
loss: 1.117683  [38400/60000]
loss: 1.231151  [44800/60000]
loss: 1.047312  [51200/60000]
loss: 1.061814  [57600/60000]
Test Error: 
 Accuracy: 76.5%, Avg loss: 1.047569 

Epoch 10
-------------------------------
loss: 1.082913  [    0/60000]
loss: 0.924575  [ 6400/60000]
loss: 1.074471  [12800/60000]
loss: 0.963778  [19200/60000]
loss: 0.914800  [25600/60000]
loss: 0.905806  [32000/60000]
loss: 0.930664  [38400/60000]
loss: 1.089383  [44800/60000]
loss: 1.017955  [51200/60000]
loss: 0.919721  [57600/60000]
Test Error: 
 Accuracy: 78.4%, Avg loss: 0.918273 

Epoch 11
-------------------------------
loss: 0.951236  [    0/60000]
loss: 0.796322  [ 6400/60000]
loss: 0.907597  [12800/60000]
loss: 0.826737  [19200/60000]
loss: 0.859733  [25600/60000]
loss: 0.843981  [32000/60000]
loss: 0.836354  [38400/60000]
loss: 1.045860  [44800/60000]
loss: 0.936381  [51200/60000]
loss: 0.862719  [57600/60000]
Test Error: 
 Accuracy: 77.9%, Avg loss: 0.851654 

Epoch 12
-------------------------------
loss: 0.876074  [    0/60000]
loss: 0.755115  [ 6400/60000]
loss: 0.857361  [12800/60000]
loss: 0.767043  [19200/60000]
loss: 0.818982  [25600/60000]
loss: 0.817685  [32000/60000]
loss: 0.769299  [38400/60000]
loss: 0.890136  [44800/60000]
loss: 0.880792  [51200/60000]
loss: 0.812595  [57600/60000]
Test Error: 
 Accuracy: 78.6%, Avg loss: 0.781197 

Epoch 13
-------------------------------
loss: 0.823996  [    0/60000]
loss: 0.696730  [ 6400/60000]
loss: 0.779315  [12800/60000]
loss: 0.726534  [19200/60000]
loss: 0.784454  [25600/60000]
loss: 0.753779  [32000/60000]
loss: 0.698076  [38400/60000]
loss: 0.857928  [44800/60000]
loss: 0.906583  [51200/60000]
loss: 0.706785  [57600/60000]
Test Error: 
 Accuracy: 79.5%, Avg loss: 0.745699 

Epoch 14
-------------------------------
loss: 0.801857  [    0/60000]
loss: 0.696863  [ 6400/60000]
loss: 0.638847  [12800/60000]
loss: 0.805713  [19200/60000]
loss: 0.751706  [25600/60000]
loss: 0.675668  [32000/60000]
loss: 0.638079  [38400/60000]
loss: 0.893166  [44800/60000]
loss: 0.885967  [51200/60000]
loss: 0.713033  [57600/60000]
Test Error: 
 Accuracy: 78.5%, Avg loss: 0.720166 

Epoch 15
-------------------------------
loss: 0.735028  [    0/60000]
loss: 0.600313  [ 6400/60000]
loss: 0.638438  [12800/60000]
loss: 0.746406  [19200/60000]
loss: 0.738696  [25600/60000]
loss: 0.622368  [32000/60000]
loss: 0.647141  [38400/60000]
loss: 0.877152  [44800/60000]
loss: 0.862342  [51200/60000]
loss: 0.672747  [57600/60000]
Test Error: 
 Accuracy: 78.6%, Avg loss: 0.708830 

Comparison of the learning rules

Finally, let's compare the training loss of each learning rule.

In [17]:
iterations = np.arange(len(train_loss_bp)) * 100
plt.figure(figsize=(5, 4), dpi=100)
plt.plot(iterations, train_loss_bp, label="Backprop. w/ lr=1e-2")
plt.plot(iterations, train_loss_dodge, label="DODGE w/ lr=1e-3")
plt.plot(iterations, train_loss_fgd, label="FGD w/ lr=1e-3")
plt.xlabel("Iterations")
plt.ylabel("Train loss")
plt.legend()
plt.tight_layout()

DODGE learned faster in the early phase of training, but both methods eventually converged to a similar loss and accuracy. The implementation here was written so that PyTorch's own optimizers can be used; using Adam did not change the test accuracy (around 80%; to check, change the optimizer part yourself). Increasing num_directions should make the estimate numerically more stable, so a larger learning rate might also work. Alternatively, lowering the learning rate and training for longer might improve the test accuracy. In any case, I have not verified these points, so please try them if you are interested. Neither the DODGE nor the FGD authors have released an implementation, but judging from what is reported in the papers, more efficient and stable implementations may exist.