Approximating Neural Network Gradients with Directional Derivatives

2022-02-20 (Sun) - Posted by 山本 in Tech Blog    tag:Python tag:Machine Learning tag:Paper Summary

    This article explains and implements (Silver et al., "Learning by Directional Gradient Descent", ICLR, 2021) and (Baydin et al., "Gradients without Backpropagation", arXiv, 2022).

    Both papers propose training neural networks without backpropagation by approximating the gradient with perturbations and directional derivatives; the approach can also be seen as a kind of gradient-free optimization. In what follows, we call Silver et al.'s method DODGE (Deep Online Directional Gradient Estimate) and Baydin et al.'s method FGD (forward gradient descent).

    An advantage of these methods is that the approximate gradient is computed using only forward passes through the network, so the parameter updates can be carried out in parallel. Baydin et al. also touch on learning rules of the brain (biological neural networks); cast into a biological model, the idea could be interpreted as follows:

    At each synapse, the direction of a small random change in synaptic strength (e.g. a change in spine head size) is "remembered"; the value of the directional derivative of the loss, a global factor, is fed back to every synapse and multiplied by the "remembered" small change, and the synaptic strength is then changed again, this time by a larger amount in that direction.

    Note, however, that this is only an interpretation; whether such a mechanism could actually be implemented is not discussed in this article.

    Preliminaries 1: Learning with perturbations

    As a slight digression, we first introduce a simple learning method that uses perturbations but no gradients. Let $\boldsymbol{\theta} \in \mathbb{R}^p$ be the parameters of the neural network, $\mathbf{x}$ a data sample, $L(\boldsymbol{\theta}, \mathbf{x})$ the loss function, and $\mathbf{v}\in \mathbb{R}^p$ a perturbation of the parameters. The simple rule is: "add a perturbation to the parameters, and keep the perturbed parameters if the loss decreases."

    $$ \begin{align*} &\Delta L = L(\boldsymbol{\theta}+\mathbf{v}, \mathbf{x}) - L(\boldsymbol{\theta}, \mathbf{x})\\ &\textbf{if}\ \Delta L < 0\ \text{:}\\ &\quad \boldsymbol{\theta} \leftarrow \boldsymbol{\theta}+\mathbf{v} \end{align*} $$

    Learning does proceed with this rule, but not efficiently. It is often used as a baseline in research on learning methods that do not rely on backpropagation.
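    As a point of reference, the following is a minimal, self-contained sketch of this perturbation-only rule (our own illustration, not code from either paper), applied to a toy quadratic loss; the target vector, perturbation scale, and step count are arbitrary choices.

    import numpy as np

    rng = np.random.default_rng(0)
    target = np.array([1.0, -2.0, 0.5])  # arbitrary target for this toy problem

    def toy_loss(theta):
        # toy loss: squared distance to the target vector
        return np.sum((theta - target)**2)

    theta = np.zeros(3)
    sigma = 0.1  # scale of the random perturbation
    for step in range(1000):
        v = sigma * rng.standard_normal(3)         # random perturbation of the parameters
        if toy_loss(theta + v) < toy_loss(theta):  # keep the perturbation only if the loss decreases
            theta = theta + v

    print(theta, toy_loss(theta))  # theta should typically end up close to the target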

    Preliminaries 2: Directional derivatives and computing the Jacobian-vector product

    The learning rules introduced in this article use the directional derivative. For a function $f$, the directional derivative at a point $\mathbf{u}$ in the direction $\mathbf{v}$ is defined as

    $$ \nabla_\mathbf{v}f(\mathbf{u})= \lim_{h\to 0} \frac{f(\mathbf{u}+h\mathbf{v}) - f(\mathbf{u})}{h} $$

    Moreover, if $f$ is differentiable at $\mathbf{u}$, then

    $$ \nabla_\mathbf{v}f(\mathbf{u})=\nabla f(\mathbf{u})\cdot \mathbf{v} $$

    holds; the right-hand side is called the Jacobian-vector product (JVP). To compute the JVP, Silver et al. use jax.jvp, which evaluates it with forward-mode AD. Baydin et al. implemented their method in PyTorch and appear to have written the automatic-differentiation part themselves. PyTorch does provide torch.autograd.functional.jvp, but it relies on the "double backwards trick", which calls backward twice and therefore still requires gradients.

    As a compromise, we approximate the Jacobian-vector product with a finite difference ($\epsilon$ is a small value):

    $$ \nabla f(\mathbf{u})\cdot \mathbf{v} \approx \frac{f(\mathbf{u}+\epsilon \mathbf{v}) - f(\mathbf{u})}{\epsilon} $$

    Note that if $f(\mathbf{u})\in \mathbb{R}$, then $\nabla f(\mathbf{u})\cdot \mathbf{v}\in \mathbb{R}$. Let us check that the finite-difference approximation works on a simple function (the Beale function defined below). First, we import all the libraries used in this article.

    In [1]:
    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    from torchvision import datasets
    from torchvision.transforms import ToTensor
    import numpy as np
    import matplotlib.pyplot as plt
    

    We now define the function. Here we use the Beale function, defined by

    $$ f(x,y) = (1.5 - x+xy)^2+(2.25 - x+xy^2)^2+ (2.625 - x+xy^3)^2 $$
    In [2]:
    def func(x):
        return (1.5 - x[0]+x[0]*x[1])**2+(2.25 - x[0]+x[0]*x[1]**2)**2+ (2.625 - x[0]+x[0]*x[1]**3)**2
    

    We compute the directional derivative in two ways.

    In [3]:
    u, v = torch.tensor([1.5, -0.1], requires_grad=True), torch.ones(2)
    # torch func.
    func_output, jvp = torch.autograd.functional.jvp(func, u, v)
    
    # finite difference
    eps = 1e-3
    f_v, f = func(u + eps*v), func(u)
    jvp_fd = (f_v - f) / eps
    print("torch.autograd.functional.jvp: ", jvp) 
    print("finite diff.:", jvp_fd)
    
    torch.autograd.functional.jvp:  tensor(-4.2418)
    finite diff.: tensor(-4.2384, grad_fn=<DivBackward0>)
    

    The automatic-differentiation result and the finite-difference result agree reasonably well. We will use finite differences from here on; note, however, that applying them to the learning rules introduced in this article requires two forward passes, so it is not efficient.
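    For reference, the same JVP can also be obtained with forward-mode AD via jax.jvp, which is what Silver et al. use. The following is a minimal sketch assuming JAX is installed (it is not needed for the rest of this article); the Beale function is simply redefined with jax.numpy.

    import jax
    import jax.numpy as jnp

    def func_jax(x):
        return ((1.5 - x[0] + x[0]*x[1])**2
                + (2.25 - x[0] + x[0]*x[1]**2)**2
                + (2.625 - x[0] + x[0]*x[1]**3)**2)

    u_jax, v_jax = jnp.array([1.5, -0.1]), jnp.ones(2)
    # forward-mode JVP: a single forward pass, no backward pass required
    f_u, jvp_jax = jax.jvp(func_jax, (u_jax,), (v_jax,))
    print(jvp_jax)  # should roughly match the two values above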

    Gradient approximation with perturbations and directional derivatives

    Now for the main topic. When training with backpropagation and stochastic gradient descent (SGD), backpropagation computes $\nabla L(\boldsymbol{\theta}, \mathbf{x})=\dfrac{\partial L(\boldsymbol{\theta}, \mathbf{x})}{\partial \boldsymbol{\theta}}$, and SGD updates the parameters as

    $$ \boldsymbol{\theta} \leftarrow \boldsymbol{\theta} - \eta \cdot \nabla L(\boldsymbol{\theta}, \mathbf{x}) $$

    where $\eta$ is the learning rate. FGD and DODGE, by contrast, update the parameters as follows.

    $$ \begin{align*} &\textbf{if}\ \text{DODGE:}\\ &\quad \mathbf{v} \sim \{-1, 1\}^p\\ &\textbf{else if}\ \text{FGD:}\\ &\quad \mathbf{v} \sim \mathcal{N}(0, \mathbf{I})\\ &g(\boldsymbol{\theta}, \mathbf{x}) = (\nabla L(\boldsymbol{\theta}, \mathbf{x})\cdot \mathbf{v})\cdot \mathbf{v}\\ &\boldsymbol{\theta} \leftarrow \boldsymbol{\theta} - \eta \cdot g(\boldsymbol{\theta}, \mathbf{x}) \end{align*} $$

    The two methods differ only in the distribution from which the perturbation is sampled. Note that evaluating $\nabla L(\boldsymbol{\theta}, \mathbf{x})\cdot \mathbf{v}$ does not require computing $\nabla L(\boldsymbol{\theta}, \mathbf{x})$ itself. The most important point is that $g(\boldsymbol{\theta}, \mathbf{x})$ is an unbiased estimator of $\nabla L(\boldsymbol{\theta}, \mathbf{x})$; see the respective papers for the proofs.
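    Before the numerical check, here is a minimal illustration of the update rule itself (our own sketch, not code from either paper), using the Beale function defined above with $(x, y)$ playing the role of $\boldsymbol{\theta}$; the learning rate and number of steps are arbitrary choices for this illustration.

    theta = torch.tensor([1.5, -0.1])
    eta = 1e-3  # learning rate (ad hoc choice)
    for step in range(2000):
        w = torch.randn(2)                           # FGD: Gaussian perturbation
        # w = (2*(torch.rand(2) > 0.5) - 1).float()  # DODGE: Rademacher perturbation
        _, dd = torch.autograd.functional.jvp(func, theta, w)  # directional derivative
        theta = theta - eta * dd * w                 # g = (∇f·v) v, then a gradient step
    print(theta, func(theta))  # the loss should typically have decreased from its starting value (runs vary)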

    Below, we show with a numerical experiment that this procedure approximates the gradient. Continuing from the cells above, we first compute the gradient of the Beale function at u.

    In [4]:
    grad = torch.autograd.grad(f, u)[0].numpy()
    print("True grad: ", grad)
    
    True grad:  [-3.433947   -0.80788505]
    

    We now write a function that estimates the gradient.

    In [5]:
    def grad_estimation(f, u, n, mode="dodge"):
        # sample n perturbation directions
        if mode == "dodge":
            v = (2*(torch.rand(n, 2) > 0.5) - 1).float()  # Rademacher (+1/-1) directions
        elif mode == "fgd":
            v = torch.randn(n, 2)  # standard Gaussian directions
        else:
            assert False, "mode is dodge or fgd"
        estimate_grad = 0
        for i in range(n):
            # directional derivative of f at u in direction v[i]
            _, jvp = torch.autograd.functional.jvp(f, u, v[i])
            estimate_grad += jvp*v[i]  # g = (∇f·v) v
        estimate_grad /= n  # average over the n directions
        return estimate_grad
    
    In [6]:
    num_directions = np.array([1, 10, 100, 1000])
    estimated_grad_dodge, estimated_grad_fgd  = np.zeros((len(num_directions), 2)), np.zeros((len(num_directions), 2))
    for i, n in enumerate(num_directions):
        estimated_grad_dodge[i] = grad_estimation(func, u, n, mode="dodge")
        estimated_grad_fgd[i] = grad_estimation(func, u, n, mode="fgd")
    

    Let us plot the results. For both DODGE and FGD, increasing the number of perturbations (num_directions) brings the estimated gradient closer to the true gradient. Running the estimation several times also shows that DODGE gives the more accurate estimates (this can be shown mathematically; a rough sketch is given after the plotting code below).

    In [7]:
    scale = 0.1
    px, py = u.detach().numpy()
    Xb = np.mgrid[-2:2:1000j, -2:2:1000j]
    Zb = func(Xb)
    
    titles = [r"$x$ grad.", r"$y$ grad."]; idx_dir = [0, -1]
    plt.figure(figsize=(6, 6), dpi=100)
    plt.suptitle("Gradients of Beale function")
    for i in range(2):
        plt.subplot(2,2,i+1)
        plt.title(titles[i])
        plt.semilogx(num_directions, estimated_grad_dodge[:, i], "o-", label="Estimated w/ DODGE")
        plt.semilogx(num_directions, estimated_grad_fgd[:, i], "o-", label="Estimated w/ FGD")
        plt.axhline(grad[i], linestyle="--", color="tab:red", label="Actual")
        plt.xlabel("Num. directions"); 
        if i == 0:
            plt.legend()
    
        plt.subplot(2,2,i+3)
        plt.title("Num. directions:"+str(num_directions[idx_dir[i]]))
        plt.contour(Xb[0], Xb[1], Zb, levels=50)
        plt.scatter(px, py, color="k")
        plt.arrow(px, py, estimated_grad_dodge[idx_dir[i], 0]*scale,  estimated_grad_dodge[idx_dir[i], 1]*scale, head_width=0.1, color='tab:blue', label="Estimated w/ DODGE")
        plt.arrow(px, py, estimated_grad_fgd[idx_dir[i], 0]*scale,  estimated_grad_fgd[idx_dir[i], 1]*scale, head_width=0.1, color='tab:orange', label="Estimated w/ FGD")
        plt.arrow(px, py, grad[0]*scale, grad[1]*scale, head_width=0.1, color='tab:red', label="Actual")
        plt.xlabel(r"$x$"); plt.ylabel(r"$y$") 
        if i == 0:
            plt.legend()
    plt.tight_layout()
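
    As an aside, here is a rough sketch (a back-of-the-envelope calculation of ours, not taken from the papers) of why the Rademacher perturbations used by DODGE should give a lower-variance estimate than the Gaussian ones used by FGD. For a single direction, write $g_i = (\nabla L\cdot \mathbf{v})\, v_i$ for the $i$-th component of the estimate. Both choices of $\mathbf{v}$ give $\mathbb{E}[g_i]=\partial_i L$, but the variances differ:

    $$ \begin{align*} \mathrm{Var}[g_i] &= \sum_{j\neq i} (\partial_j L)^2 && \mathbf{v} \sim \{-1, 1\}^p\ \text{(DODGE)}\\ \mathrm{Var}[g_i] &= 2(\partial_i L)^2 + \sum_{j\neq i} (\partial_j L)^2 && \mathbf{v} \sim \mathcal{N}(0, \mathbf{I})\ \text{(FGD)} \end{align*} $$

    Each component of the Gaussian estimate thus carries an extra $2(\partial_i L)^2$ of variance; averaging over num_directions directions divides both variances by that number.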
    

    Implementation in PyTorch

    Now we move on to the implementation: we train models on the MNIST dataset with backprop, DODGE, and FGD. The code is based on the PyTorch tutorial.

    In [8]:
    training_data = datasets.MNIST(root="data", train=True, download=True, transform=ToTensor())
    test_data = datasets.MNIST(root="data", train=False, download=True, transform=ToTensor())
    
    batch_size = 64
    train_dataloader = DataLoader(training_data, batch_size=batch_size)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)
    
    for X, y in test_dataloader:
        print(f"Shape of X [N, C, H, W]: {X.shape}")
        print(f"Shape of y: {y.shape} {y.dtype}")
        break
    
    Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
    Shape of y: torch.Size([64]) torch.int64
    
    In [9]:
    # Get cpu or gpu device for training.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")
    
    # Define model
    class NeuralNetwork(nn.Module):
        def __init__(self):
            super(NeuralNetwork, self).__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28*28, 512),
                nn.ReLU(),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Linear(512, 10)
            )
    
        def forward(self, x):
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)
            return logits
    
    Using cuda device
    

    Training with backpropagation

    For comparison, we first train a model with backpropagation, which provides exact gradients, for 15 epochs with a batch size of 64.

    In [10]:
    loss_fn = nn.CrossEntropyLoss()
    
    In [11]:
    def train(dataloader, model, loss_fn, optimizer):
        loss_list = []
        size = len(dataloader.dataset)
        model.train()
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
    
            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)
    
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
            if batch % 100 == 0:
                loss, current = loss.item(), batch * len(X)
                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
                loss_list.append(loss)
        return np.array(loss_list)
    
    In [12]:
    def test(dataloader, model, loss_fn):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        model.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
        return test_loss
    
    In [13]:
    model_bp = NeuralNetwork().to(device)
    optimizer_bp = torch.optim.SGD(model_bp.parameters(), lr=1e-2)
    train_loss_bp, test_loss_bp = [], []
    epochs = 15
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loss = train(train_dataloader, model_bp, loss_fn, optimizer_bp)
        test_loss = test(test_dataloader, model_bp, loss_fn)
        train_loss_bp.append(train_loss)
        test_loss_bp.append(test_loss)
    train_loss_bp = np.concatenate(train_loss_bp)
    test_loss_bp = np.array(test_loss_bp)
    
    Epoch 1
    -------------------------------
    loss: 2.299546  [    0/60000]
    loss: 2.243087  [ 6400/60000]
    loss: 2.193906  [12800/60000]
    loss: 1.990515  [19200/60000]
    loss: 1.825056  [25600/60000]
    loss: 1.497010  [32000/60000]
    loss: 1.088965  [38400/60000]
    loss: 1.071147  [44800/60000]
    loss: 0.839060  [51200/60000]
    loss: 0.702301  [57600/60000]
    Test Error: 
     Accuracy: 83.7%, Avg loss: 0.645017 
    
    Epoch 2
    -------------------------------
    loss: 0.738222  [    0/60000]
    loss: 0.536721  [ 6400/60000]
    loss: 0.543878  [12800/60000]
    loss: 0.521499  [19200/60000]
    loss: 0.463977  [25600/60000]
    loss: 0.448645  [32000/60000]
    loss: 0.326735  [38400/60000]
    loss: 0.522257  [44800/60000]
    loss: 0.468751  [51200/60000]
    loss: 0.475883  [57600/60000]
    Test Error: 
     Accuracy: 88.9%, Avg loss: 0.391872 
    
    Epoch 3
    -------------------------------
    loss: 0.428448  [    0/60000]
    loss: 0.321109  [ 6400/60000]
    loss: 0.325715  [12800/60000]
    loss: 0.404429  [19200/60000]
    loss: 0.327921  [25600/60000]
    loss: 0.372084  [32000/60000]
    loss: 0.230002  [38400/60000]
    loss: 0.442992  [44800/60000]
    loss: 0.386030  [51200/60000]
    loss: 0.438659  [57600/60000]
    Test Error: 
     Accuracy: 90.4%, Avg loss: 0.332632 
    
    Epoch 4
    -------------------------------
    loss: 0.320814  [    0/60000]
    loss: 0.276162  [ 6400/60000]
    loss: 0.247757  [12800/60000]
    loss: 0.369032  [19200/60000]
    loss: 0.275181  [25600/60000]
    loss: 0.333127  [32000/60000]
    loss: 0.196659  [38400/60000]
    loss: 0.405020  [44800/60000]
    loss: 0.338511  [51200/60000]
    loss: 0.415403  [57600/60000]
    Test Error: 
     Accuracy: 91.3%, Avg loss: 0.302255 
    
    Epoch 5
    -------------------------------
    loss: 0.263066  [    0/60000]
    loss: 0.256540  [ 6400/60000]
    loss: 0.207443  [12800/60000]
    loss: 0.350262  [19200/60000]
    loss: 0.241405  [25600/60000]
    loss: 0.307048  [32000/60000]
    loss: 0.179711  [38400/60000]
    loss: 0.379755  [44800/60000]
    loss: 0.300815  [51200/60000]
    loss: 0.394544  [57600/60000]
    Test Error: 
     Accuracy: 92.0%, Avg loss: 0.280586 
    
    Epoch 6
    -------------------------------
    loss: 0.223930  [    0/60000]
    loss: 0.242454  [ 6400/60000]
    loss: 0.182857  [12800/60000]
    loss: 0.334236  [19200/60000]
    loss: 0.217351  [25600/60000]
    loss: 0.287452  [32000/60000]
    loss: 0.167590  [38400/60000]
    loss: 0.359716  [44800/60000]
    loss: 0.269568  [51200/60000]
    loss: 0.374653  [57600/60000]
    Test Error: 
     Accuracy: 92.5%, Avg loss: 0.262961 
    
    Epoch 7
    -------------------------------
    loss: 0.195676  [    0/60000]
    loss: 0.230322  [ 6400/60000]
    loss: 0.164844  [12800/60000]
    loss: 0.319183  [19200/60000]
    loss: 0.198686  [25600/60000]
    loss: 0.272288  [32000/60000]
    loss: 0.157346  [38400/60000]
    loss: 0.342784  [44800/60000]
    loss: 0.242133  [51200/60000]
    loss: 0.355913  [57600/60000]
    Test Error: 
     Accuracy: 92.8%, Avg loss: 0.247321 
    
    Epoch 8
    -------------------------------
    loss: 0.174269  [    0/60000]
    loss: 0.219391  [ 6400/60000]
    loss: 0.150918  [12800/60000]
    loss: 0.305854  [19200/60000]
    loss: 0.182818  [25600/60000]
    loss: 0.261800  [32000/60000]
    loss: 0.147795  [38400/60000]
    loss: 0.325437  [44800/60000]
    loss: 0.218672  [51200/60000]
    loss: 0.337907  [57600/60000]
    Test Error: 
     Accuracy: 93.2%, Avg loss: 0.233048 
    
    Epoch 9
    -------------------------------
    loss: 0.157361  [    0/60000]
    loss: 0.209914  [ 6400/60000]
    loss: 0.139500  [12800/60000]
    loss: 0.292533  [19200/60000]
    loss: 0.169732  [25600/60000]
    loss: 0.253424  [32000/60000]
    loss: 0.138459  [38400/60000]
    loss: 0.308890  [44800/60000]
    loss: 0.198611  [51200/60000]
    loss: 0.321007  [57600/60000]
    Test Error: 
     Accuracy: 93.6%, Avg loss: 0.219827 
    
    Epoch 10
    -------------------------------
    loss: 0.143458  [    0/60000]
    loss: 0.200823  [ 6400/60000]
    loss: 0.130008  [12800/60000]
    loss: 0.279433  [19200/60000]
    loss: 0.159168  [25600/60000]
    loss: 0.246027  [32000/60000]
    loss: 0.129360  [38400/60000]
    loss: 0.293391  [44800/60000]
    loss: 0.181554  [51200/60000]
    loss: 0.305260  [57600/60000]
    Test Error: 
     Accuracy: 93.9%, Avg loss: 0.207685 
    
    Epoch 11
    -------------------------------
    loss: 0.131745  [    0/60000]
    loss: 0.192202  [ 6400/60000]
    loss: 0.121887  [12800/60000]
    loss: 0.265702  [19200/60000]
    loss: 0.150041  [25600/60000]
    loss: 0.239148  [32000/60000]
    loss: 0.120741  [38400/60000]
    loss: 0.279369  [44800/60000]
    loss: 0.168906  [51200/60000]
    loss: 0.290600  [57600/60000]
    Test Error: 
     Accuracy: 94.1%, Avg loss: 0.196576 
    
    Epoch 12
    -------------------------------
    loss: 0.121396  [    0/60000]
    loss: 0.184836  [ 6400/60000]
    loss: 0.114914  [12800/60000]
    loss: 0.252048  [19200/60000]
    loss: 0.142015  [25600/60000]
    loss: 0.232030  [32000/60000]
    loss: 0.113253  [38400/60000]
    loss: 0.266646  [44800/60000]
    loss: 0.160878  [51200/60000]
    loss: 0.278118  [57600/60000]
    Test Error: 
     Accuracy: 94.5%, Avg loss: 0.186402 
    
    Epoch 13
    -------------------------------
    loss: 0.112123  [    0/60000]
    loss: 0.178441  [ 6400/60000]
    loss: 0.108902  [12800/60000]
    loss: 0.239651  [19200/60000]
    loss: 0.134386  [25600/60000]
    loss: 0.224691  [32000/60000]
    loss: 0.106507  [38400/60000]
    loss: 0.254671  [44800/60000]
    loss: 0.155029  [51200/60000]
    loss: 0.267427  [57600/60000]
    Test Error: 
     Accuracy: 94.7%, Avg loss: 0.177047 
    
    Epoch 14
    -------------------------------
    loss: 0.104444  [    0/60000]
    loss: 0.172944  [ 6400/60000]
    loss: 0.103483  [12800/60000]
    loss: 0.227711  [19200/60000]
    loss: 0.127313  [25600/60000]
    loss: 0.216363  [32000/60000]
    loss: 0.100120  [38400/60000]
    loss: 0.243856  [44800/60000]
    loss: 0.150323  [51200/60000]
    loss: 0.257463  [57600/60000]
    Test Error: 
     Accuracy: 94.9%, Avg loss: 0.168542 
    
    Epoch 15
    -------------------------------
    loss: 0.097420  [    0/60000]
    loss: 0.167871  [ 6400/60000]
    loss: 0.098468  [12800/60000]
    loss: 0.216863  [19200/60000]
    loss: 0.120679  [25600/60000]
    loss: 0.208238  [32000/60000]
    loss: 0.094022  [38400/60000]
    loss: 0.233274  [44800/60000]
    loss: 0.146389  [51200/60000]
    loss: 0.248148  [57600/60000]
    Test Error: 
     Accuracy: 95.2%, Avg loss: 0.160844 
    
    

    Training with directional derivatives

    We prepare two models with identical architecture (model and model_v) and update the parameters with the following procedure.

    1. Run a forward pass with model and compute loss.
    2. Prepare a dict grad_estimate for storing the gradient estimate.
    3. Generate a perturbation v. Its entries are registered under the same keys as the parameters of model, obtained from model.state_dict(); at the same time, the parameters of model_v, which has the same architecture as model, are replaced by the parameters of model with the (eps-scaled) perturbation v added.
    4. Run a forward pass with model_v and compute loss_v.
    5. Compute the directional derivative from loss_v and loss with a finite difference.
    6. Add the resulting gradient estimate to grad_estimate.
    7. Repeat steps 3-6 num_directions times.
    8. Average grad_estimate over num_directions and apply gradient clipping with torch.clamp() (for numerical stability).
    9. Zero the gradients of the parameters of model with optimizer.zero_grad().
    10. Assign the estimated gradients to param.grad.
    11. Update the parameters with optimizer.step().

    In the simulation of the previous section the estimated gradient only approached the true gradient when num_directions was increased, but learning still progresses with num_directions=1. You can of course increase it, at the cost of extra computation. We also use a learning rate of 0.001, smaller than the 0.01 used for backprop, because training diverged with a learning rate of 0.01.

    In [14]:
    def train_nograd(dataloader, model, model_v, loss_fn, optimizer, mode="dodge", num_directions=1, 
                     eps=1e-3, grad_clip=1):
        size = len(dataloader.dataset)
        loss_list = []
        model.train()
        for batch, (X, y) in enumerate(dataloader):
            with torch.no_grad():
                X, y = X.to(device), y.to(device)
                pred = model(X) # prediction
                loss = loss_fn(pred, y) # loss
                grad_estimate = {} # Dict for estimated gradient
                for i in range(num_directions):
                    v = {} # Dict for perturbation
                    for param_v, key in zip(model_v.parameters(), model.state_dict()):
                        if mode == "dodge":
                            v[key] = 2*(torch.rand(model.state_dict()[key].shape).to(device) > 0.5) - 1 
                        elif mode == "fgd":
                            v[key] = torch.randn(model.state_dict()[key].shape).to(device)
                        else:
                            assert False, "mode is dodge or fgd"
    
                        param_v.data = model.state_dict()[key] + eps*v[key] # substitution with perturbated param.
    
                    pred_v = model_v(X) # perturbated prediction
                    loss_v = loss_fn(pred_v, y) # perturbated loss
                    jvp = (loss_v - loss) / eps  # directional derivative of loss at point params in direction v
                
                    # gradient estimation
                    for key in model.state_dict():
                        if key in grad_estimate:
                            grad_estimate[key] += jvp * v[key]
                        else:
                            grad_estimate[key] = jvp * v[key]
    
                # averaging & gradient clipping for estimated gradient
                for key in model.state_dict():
                    grad_estimate[key] = torch.clamp(grad_estimate[key]/num_directions, -grad_clip, grad_clip)
    
                optimizer.zero_grad()
                # replacement of loss.backward()
                for param, key in zip(model.parameters(), model.state_dict()):
                    param.grad = grad_estimate[key]        
                optimizer.step()
                
            if batch % 100 == 0:
                loss, current = loss.item(), batch * len(X)
                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
                loss_list.append(loss)
        return np.array(loss_list)
    

    DODGE

    In [15]:
    model_dodge, model_dodgev = NeuralNetwork().to(device), NeuralNetwork().to(device)
    for param in model_dodge.parameters():
        param.requires_grad = False
    for param in model_dodgev.parameters():
        param.requires_grad = False
    #optimizer = torch.optim.Adam(model_dodge.parameters(), lr=1e-3)
    optimizer_dodge = torch.optim.SGD(model_dodge.parameters(), lr=1e-3)
    
    train_loss_dodge, test_loss_dodge = [], []
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loss = train_nograd(train_dataloader, model_dodge, model_dodgev, loss_fn, optimizer_dodge, mode="dodge")
        test_loss = test(test_dataloader, model_dodge, loss_fn)
        train_loss_dodge.append(train_loss)
        test_loss_dodge.append(test_loss)
    train_loss_dodge = np.concatenate(train_loss_dodge)
    test_loss_dodge = np.array(test_loss_dodge)
    
    Epoch 1
    -------------------------------
    loss: 2.293013  [    0/60000]
    loss: 2.278605  [ 6400/60000]
    loss: 2.305957  [12800/60000]
    loss: 2.261985  [19200/60000]
    loss: 2.278600  [25600/60000]
    loss: 2.274848  [32000/60000]
    loss: 2.264979  [38400/60000]
    loss: 2.264772  [44800/60000]
    loss: 2.243997  [51200/60000]
    loss: 2.230047  [57600/60000]
    Test Error: 
     Accuracy: 27.0%, Avg loss: 2.238401 
    
    Epoch 2
    -------------------------------
    loss: 2.238425  [    0/60000]
    loss: 2.213653  [ 6400/60000]
    loss: 2.250131  [12800/60000]
    loss: 2.172260  [19200/60000]
    loss: 2.199060  [25600/60000]
    loss: 2.191102  [32000/60000]
    loss: 2.158249  [38400/60000]
    loss: 2.203896  [44800/60000]
    loss: 2.136338  [51200/60000]
    loss: 2.109188  [57600/60000]
    Test Error: 
     Accuracy: 51.2%, Avg loss: 2.119787 
    
    Epoch 3
    -------------------------------
    loss: 2.115482  [    0/60000]
    loss: 2.104982  [ 6400/60000]
    loss: 2.143744  [12800/60000]
    loss: 2.017807  [19200/60000]
    loss: 2.061916  [25600/60000]
    loss: 2.045247  [32000/60000]
    loss: 2.010787  [38400/60000]
    loss: 2.053323  [44800/60000]
    loss: 1.954429  [51200/60000]
    loss: 1.921283  [57600/60000]
    Test Error: 
     Accuracy: 60.1%, Avg loss: 1.933119 
    
    Epoch 4
    -------------------------------
    loss: 1.946108  [    0/60000]
    loss: 1.871284  [ 6400/60000]
    loss: 1.959304  [12800/60000]
    loss: 1.783304  [19200/60000]
    loss: 1.854810  [25600/60000]
    loss: 1.832346  [32000/60000]
    loss: 1.746651  [38400/60000]
    loss: 1.863106  [44800/60000]
    loss: 1.712567  [51200/60000]
    loss: 1.637663  [57600/60000]
    Test Error: 
     Accuracy: 63.5%, Avg loss: 1.674136 
    
    Epoch 5
    -------------------------------
    loss: 1.716559  [    0/60000]
    loss: 1.619153  [ 6400/60000]
    loss: 1.696653  [12800/60000]
    loss: 1.460011  [19200/60000]
    loss: 1.512218  [25600/60000]
    loss: 1.541011  [32000/60000]
    loss: 1.445283  [38400/60000]
    loss: 1.622717  [44800/60000]
    loss: 1.419001  [51200/60000]
    loss: 1.323012  [57600/60000]
    Test Error: 
     Accuracy: 68.3%, Avg loss: 1.363791 
    
    Epoch 6
    -------------------------------
    loss: 1.388289  [    0/60000]
    loss: 1.309341  [ 6400/60000]
    loss: 1.442154  [12800/60000]
    loss: 1.223235  [19200/60000]
    loss: 1.275618  [25600/60000]
    loss: 1.317449  [32000/60000]
    loss: 1.236379  [38400/60000]
    loss: 1.375583  [44800/60000]
    loss: 1.237414  [51200/60000]
    loss: 1.128840  [57600/60000]
    Test Error: 
     Accuracy: 70.2%, Avg loss: 1.175376 
    
    Epoch 7
    -------------------------------
    loss: 1.216448  [    0/60000]
    loss: 1.130485  [ 6400/60000]
    loss: 1.316219  [12800/60000]
    loss: 1.005424  [19200/60000]
    loss: 1.101283  [25600/60000]
    loss: 1.078783  [32000/60000]
    loss: 0.955078  [38400/60000]
    loss: 1.135801  [44800/60000]
    loss: 1.091042  [51200/60000]
    loss: 1.044853  [57600/60000]
    Test Error: 
     Accuracy: 73.0%, Avg loss: 1.016321 
    
    Epoch 8
    -------------------------------
    loss: 1.043064  [    0/60000]
    loss: 1.013704  [ 6400/60000]
    loss: 1.106562  [12800/60000]
    loss: 0.749295  [19200/60000]
    loss: 0.878776  [25600/60000]
    loss: 1.041634  [32000/60000]
    loss: 0.925950  [38400/60000]
    loss: 1.054278  [44800/60000]
    loss: 1.059773  [51200/60000]
    loss: 0.884302  [57600/60000]
    Test Error: 
     Accuracy: 75.0%, Avg loss: 0.929963 
    
    Epoch 9
    -------------------------------
    loss: 0.957271  [    0/60000]
    loss: 0.913348  [ 6400/60000]
    loss: 1.025013  [12800/60000]
    loss: 0.706904  [19200/60000]
    loss: 0.794365  [25600/60000]
    loss: 0.886365  [32000/60000]
    loss: 0.795565  [38400/60000]
    loss: 1.042791  [44800/60000]
    loss: 0.918600  [51200/60000]
    loss: 0.846945  [57600/60000]
    Test Error: 
     Accuracy: 76.0%, Avg loss: 0.840766 
    
    Epoch 10
    -------------------------------
    loss: 0.865723  [    0/60000]
    loss: 0.866054  [ 6400/60000]
    loss: 0.886908  [12800/60000]
    loss: 0.670420  [19200/60000]
    loss: 0.691289  [25600/60000]
    loss: 0.806679  [32000/60000]
    loss: 0.669625  [38400/60000]
    loss: 0.989060  [44800/60000]
    loss: 0.883575  [51200/60000]
    loss: 0.833742  [57600/60000]
    Test Error: 
     Accuracy: 77.1%, Avg loss: 0.779454 
    
    Epoch 11
    -------------------------------
    loss: 0.862591  [    0/60000]
    loss: 0.878451  [ 6400/60000]
    loss: 0.820222  [12800/60000]
    loss: 0.643583  [19200/60000]
    loss: 0.638248  [25600/60000]
    loss: 0.738340  [32000/60000]
    loss: 0.646759  [38400/60000]
    loss: 0.894081  [44800/60000]
    loss: 0.858162  [51200/60000]
    loss: 0.829792  [57600/60000]
    Test Error: 
     Accuracy: 76.5%, Avg loss: 0.770023 
    
    Epoch 12
    -------------------------------
    loss: 0.835574  [    0/60000]
    loss: 0.877235  [ 6400/60000]
    loss: 0.857782  [12800/60000]
    loss: 0.690486  [19200/60000]
    loss: 0.617188  [25600/60000]
    loss: 0.711040  [32000/60000]
    loss: 0.721151  [38400/60000]
    loss: 0.966661  [44800/60000]
    loss: 0.844607  [51200/60000]
    loss: 0.791053  [57600/60000]
    Test Error: 
     Accuracy: 76.8%, Avg loss: 0.756590 
    
    Epoch 13
    -------------------------------
    loss: 0.799021  [    0/60000]
    loss: 0.806324  [ 6400/60000]
    loss: 0.743709  [12800/60000]
    loss: 0.620406  [19200/60000]
    loss: 0.595102  [25600/60000]
    loss: 0.694497  [32000/60000]
    loss: 0.610044  [38400/60000]
    loss: 0.986340  [44800/60000]
    loss: 0.834452  [51200/60000]
    loss: 0.761270  [57600/60000]
    Test Error: 
     Accuracy: 75.1%, Avg loss: 0.781446 
    
    Epoch 14
    -------------------------------
    loss: 0.803079  [    0/60000]
    loss: 0.924776  [ 6400/60000]
    loss: 0.730276  [12800/60000]
    loss: 0.619014  [19200/60000]
    loss: 0.689474  [25600/60000]
    loss: 0.743778  [32000/60000]
    loss: 0.648899  [38400/60000]
    loss: 0.996702  [44800/60000]
    loss: 0.776125  [51200/60000]
    loss: 0.803975  [57600/60000]
    Test Error: 
     Accuracy: 75.8%, Avg loss: 0.755003 
    
    Epoch 15
    -------------------------------
    loss: 0.811652  [    0/60000]
    loss: 0.835664  [ 6400/60000]
    loss: 0.662330  [12800/60000]
    loss: 0.673836  [19200/60000]
    loss: 0.564571  [25600/60000]
    loss: 0.718223  [32000/60000]
    loss: 0.634960  [38400/60000]
    loss: 0.828920  [44800/60000]
    loss: 0.767297  [51200/60000]
    loss: 0.876423  [57600/60000]
    Test Error: 
     Accuracy: 75.1%, Avg loss: 0.764193 
    
    

    FGD

    In [16]:
    model_fgd, model_fgdv = NeuralNetwork().to(device), NeuralNetwork().to(device)
    for param in model_fgd.parameters():
        param.requires_grad = False
    for param in model_fgdv.parameters():
        param.requires_grad = False
    optimizer_fgd = torch.optim.SGD(model_fgd.parameters(), lr=1e-3)
    
    train_loss_fgd, test_loss_fgd = [], []
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loss = train_nograd(train_dataloader, model_fgd, model_fgdv, loss_fn, optimizer_fgd, mode="fgd")
        test_loss = test(test_dataloader, model_fgd, loss_fn)
        train_loss_fgd.append(train_loss)
        test_loss_fgd.append(test_loss)
    train_loss_fgd = np.concatenate(train_loss_fgd)
    test_loss_fgd = np.array(test_loss_fgd)
    
    Epoch 1
    -------------------------------
    loss: 2.296826  [    0/60000]
    loss: 2.283787  [ 6400/60000]
    loss: 2.293741  [12800/60000]
    loss: 2.277387  [19200/60000]
    loss: 2.291229  [25600/60000]
    loss: 2.284855  [32000/60000]
    loss: 2.269791  [38400/60000]
    loss: 2.282986  [44800/60000]
    loss: 2.260365  [51200/60000]
    loss: 2.257126  [57600/60000]
    Test Error: 
     Accuracy: 28.9%, Avg loss: 2.261620 
    
    Epoch 2
    -------------------------------
    loss: 2.260387  [    0/60000]
    loss: 2.243361  [ 6400/60000]
    loss: 2.258446  [12800/60000]
    loss: 2.217143  [19200/60000]
    loss: 2.243744  [25600/60000]
    loss: 2.238254  [32000/60000]
    loss: 2.216753  [38400/60000]
    loss: 2.238729  [44800/60000]
    loss: 2.210622  [51200/60000]
    loss: 2.196457  [57600/60000]
    Test Error: 
     Accuracy: 31.1%, Avg loss: 2.202523 
    
    Epoch 3
    -------------------------------
    loss: 2.204142  [    0/60000]
    loss: 2.184506  [ 6400/60000]
    loss: 2.217151  [12800/60000]
    loss: 2.137058  [19200/60000]
    loss: 2.168483  [25600/60000]
    loss: 2.161743  [32000/60000]
    loss: 2.135093  [38400/60000]
    loss: 2.179393  [44800/60000]
    loss: 2.139675  [51200/60000]
    loss: 2.107305  [57600/60000]
    Test Error: 
     Accuracy: 47.7%, Avg loss: 2.113572 
    
    Epoch 4
    -------------------------------
    loss: 2.093888  [    0/60000]
    loss: 2.081323  [ 6400/60000]
    loss: 2.133465  [12800/60000]
    loss: 2.013201  [19200/60000]
    loss: 2.053054  [25600/60000]
    loss: 2.050395  [32000/60000]
    loss: 2.011332  [38400/60000]
    loss: 2.068287  [44800/60000]
    loss: 2.015212  [51200/60000]
    loss: 1.967663  [57600/60000]
    Test Error: 
     Accuracy: 53.5%, Avg loss: 1.981888 
    
    Epoch 5
    -------------------------------
    loss: 1.974730  [    0/60000]
    loss: 1.947753  [ 6400/60000]
    loss: 1.988854  [12800/60000]
    loss: 1.846123  [19200/60000]
    loss: 1.906788  [25600/60000]
    loss: 1.896905  [32000/60000]
    loss: 1.839852  [38400/60000]
    loss: 1.927831  [44800/60000]
    loss: 1.821989  [51200/60000]
    loss: 1.782795  [57600/60000]
    Test Error: 
     Accuracy: 65.5%, Avg loss: 1.789055 
    
    Epoch 6
    -------------------------------
    loss: 1.783425  [    0/60000]
    loss: 1.762594  [ 6400/60000]
    loss: 1.817015  [12800/60000]
    loss: 1.638735  [19200/60000]
    loss: 1.747962  [25600/60000]
    loss: 1.718608  [32000/60000]
    loss: 1.656324  [38400/60000]
    loss: 1.748703  [44800/60000]
    loss: 1.658347  [51200/60000]
    loss: 1.596796  [57600/60000]
    Test Error: 
     Accuracy: 68.5%, Avg loss: 1.607141 
    
    Epoch 7
    -------------------------------
    loss: 1.634365  [    0/60000]
    loss: 1.542256  [ 6400/60000]
    loss: 1.632949  [12800/60000]
    loss: 1.462924  [19200/60000]
    loss: 1.525259  [25600/60000]
    loss: 1.459265  [32000/60000]
    loss: 1.416846  [38400/60000]
    loss: 1.571794  [44800/60000]
    loss: 1.409128  [51200/60000]
    loss: 1.359775  [57600/60000]
    Test Error: 
     Accuracy: 70.4%, Avg loss: 1.389032 
    
    Epoch 8
    -------------------------------
    loss: 1.412317  [    0/60000]
    loss: 1.306484  [ 6400/60000]
    loss: 1.405024  [12800/60000]
    loss: 1.290675  [19200/60000]
    loss: 1.304878  [25600/60000]
    loss: 1.247403  [32000/60000]
    loss: 1.253526  [38400/60000]
    loss: 1.363694  [44800/60000]
    loss: 1.201337  [51200/60000]
    loss: 1.205393  [57600/60000]
    Test Error: 
     Accuracy: 74.7%, Avg loss: 1.200503 
    
    Epoch 9
    -------------------------------
    loss: 1.217961  [    0/60000]
    loss: 1.156513  [ 6400/60000]
    loss: 1.214847  [12800/60000]
    loss: 1.072343  [19200/60000]
    loss: 1.085704  [25600/60000]
    loss: 1.088957  [32000/60000]
    loss: 1.117683  [38400/60000]
    loss: 1.231151  [44800/60000]
    loss: 1.047312  [51200/60000]
    loss: 1.061814  [57600/60000]
    Test Error: 
     Accuracy: 76.5%, Avg loss: 1.047569 
    
    Epoch 10
    -------------------------------
    loss: 1.082913  [    0/60000]
    loss: 0.924575  [ 6400/60000]
    loss: 1.074471  [12800/60000]
    loss: 0.963778  [19200/60000]
    loss: 0.914800  [25600/60000]
    loss: 0.905806  [32000/60000]
    loss: 0.930664  [38400/60000]
    loss: 1.089383  [44800/60000]
    loss: 1.017955  [51200/60000]
    loss: 0.919721  [57600/60000]
    Test Error: 
     Accuracy: 78.4%, Avg loss: 0.918273 
    
    Epoch 11
    -------------------------------
    loss: 0.951236  [    0/60000]
    loss: 0.796322  [ 6400/60000]
    loss: 0.907597  [12800/60000]
    loss: 0.826737  [19200/60000]
    loss: 0.859733  [25600/60000]
    loss: 0.843981  [32000/60000]
    loss: 0.836354  [38400/60000]
    loss: 1.045860  [44800/60000]
    loss: 0.936381  [51200/60000]
    loss: 0.862719  [57600/60000]
    Test Error: 
     Accuracy: 77.9%, Avg loss: 0.851654 
    
    Epoch 12
    -------------------------------
    loss: 0.876074  [    0/60000]
    loss: 0.755115  [ 6400/60000]
    loss: 0.857361  [12800/60000]
    loss: 0.767043  [19200/60000]
    loss: 0.818982  [25600/60000]
    loss: 0.817685  [32000/60000]
    loss: 0.769299  [38400/60000]
    loss: 0.890136  [44800/60000]
    loss: 0.880792  [51200/60000]
    loss: 0.812595  [57600/60000]
    Test Error: 
     Accuracy: 78.6%, Avg loss: 0.781197 
    
    Epoch 13
    -------------------------------
    loss: 0.823996  [    0/60000]
    loss: 0.696730  [ 6400/60000]
    loss: 0.779315  [12800/60000]
    loss: 0.726534  [19200/60000]
    loss: 0.784454  [25600/60000]
    loss: 0.753779  [32000/60000]
    loss: 0.698076  [38400/60000]
    loss: 0.857928  [44800/60000]
    loss: 0.906583  [51200/60000]
    loss: 0.706785  [57600/60000]
    Test Error: 
     Accuracy: 79.5%, Avg loss: 0.745699 
    
    Epoch 14
    -------------------------------
    loss: 0.801857  [    0/60000]
    loss: 0.696863  [ 6400/60000]
    loss: 0.638847  [12800/60000]
    loss: 0.805713  [19200/60000]
    loss: 0.751706  [25600/60000]
    loss: 0.675668  [32000/60000]
    loss: 0.638079  [38400/60000]
    loss: 0.893166  [44800/60000]
    loss: 0.885967  [51200/60000]
    loss: 0.713033  [57600/60000]
    Test Error: 
     Accuracy: 78.5%, Avg loss: 0.720166 
    
    Epoch 15
    -------------------------------
    loss: 0.735028  [    0/60000]
    loss: 0.600313  [ 6400/60000]
    loss: 0.638438  [12800/60000]
    loss: 0.746406  [19200/60000]
    loss: 0.738696  [25600/60000]
    loss: 0.622368  [32000/60000]
    loss: 0.647141  [38400/60000]
    loss: 0.877152  [44800/60000]
    loss: 0.862342  [51200/60000]
    loss: 0.672747  [57600/60000]
    Test Error: 
     Accuracy: 78.6%, Avg loss: 0.708830 
    
    

    Comparison of the learning rules

    Finally, let us compare the training losses of the different learning rules.

    In [17]:
    iterations = np.arange(len(train_loss_bp)) * 100
    plt.figure(figsize=(5, 4), dpi=100)
    plt.plot(iterations, train_loss_bp, label="Backprop. w/ lr=1e-2")
    plt.plot(iterations, train_loss_dodge, label="DODGE w/ lr=1e-3")
    plt.plot(iterations, train_loss_fgd, label="FGD w/ lr=1e-3")
    plt.xlabel("Iterations")
    plt.ylabel("Train loss")
    plt.legend()
    plt.tight_layout()
    

    DODGE learned faster early in training, but the two methods eventually converged to a similar loss and accuracy. The implementation here was written so that PyTorch's built-in optimizers can be used; switching to Adam did not change the test accuracy (around 80%; to check this, change the optimizer part of the code above). Increasing num_directions might stabilize the estimate numerically and allow a larger learning rate, and lowering the learning rate and training for longer might improve the test accuracy; we have not verified either, so try it if you are interested. Neither DODGE nor FGD has a publicly available author implementation, but judging from the reported results, more efficient and stable implementations may well exist.
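    If you would like to try the variations mentioned above, a sketch along the following lines should work (the names model_fgd2, model_fgd2v, and optimizer_fgd2 are placeholders introduced here, and these settings are untested):

    model_fgd2, model_fgd2v = NeuralNetwork().to(device), NeuralNetwork().to(device)
    for param in list(model_fgd2.parameters()) + list(model_fgd2v.parameters()):
        param.requires_grad = False
    optimizer_fgd2 = torch.optim.Adam(model_fgd2.parameters(), lr=1e-3)  # Adam instead of SGD

    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        # average over more directions for a less noisy (but more expensive) gradient estimate
        train_nograd(train_dataloader, model_fgd2, model_fgd2v, loss_fn, optimizer_fgd2,
                     mode="fgd", num_directions=10)
        test(test_dataloader, model_fgd2, loss_fn)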