Tensorの自動微分

PyTorchなどのディープラーニングライブラリの特徴は、自動微分を備えている事です。自動微分を行うTensorをつくるためには、Tensorの初期化時に requires_grad=True を指定します。

In [1]: import torch

In [2]: torch.ones(2, 2, requires_grad=True)
Out[2]: 
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)

実際に、((x+1)**2).sum()の微分を計算してみましょう。まず、 y((x+1)**2).sum()を代入します。PyTorchの自動微分では、一度 sum().mean()をとってスカラー値にしないと自動微分を実行できません。

In [3]: x = torch.ones(2, 2, requires_grad=True)

In [4]: y = ((x+1)**2).sum()

In [5]: y
Out[5]: 
tensor([[4., 4.],
        [4., 4.]], grad_fn=<PowBackward0>)

yを表示すると、grad_fn=<PowBackward0>のように勾配を計算する関数が書かれています。具体的にどの関数で微分を計算するのかは利用者はあまり気にする必要はありません。微分を計算するには y.backward()とします。すると、x.gradに勾配が記録されます。

In [16]: y.backward()

In [17]: x.grad
Out[17]: 
tensor([[5., 5.],
        [5., 5.]])

このように出力されました。 手計算と照合してみましょう。

y=(x+1)2yx00=2x+1x00=1yx00=5 \begin{aligned} \mathbf{y} &= \sum (\mathbf {x}+1)^2\\ \frac{\partial y}{\partial x_{00}} &= 2x + 1\\ x_{00} &= 1 \text{より}\\ \frac{\partial y}{\partial x_{00}} &= 5\\ \end{aligned}

正しく計算されています。他は同様ですから省略します。

PyTorchでは、用意されている様々な関数の微分を計算できます。次のlのような複雑な関数の微分を求める事もできます。

In [153]: W1 = torch.randn([3, 10], requires_grad=True)

In [154]: W2 = torch.randn([10, 10], requires_grad=True)

In [155]: W3 = torch.randn([10, 1], requires_grad=True)

In [156]: x = torch.Tensor([1., -2., 3.])

In [158]: y = F.celu(torch.matmul(W3.T, F.celu(torch.matmul(W2.T, F.celu(torch.matmul(W1.T, x))))))

In [159]: t = torch.zeros([1])

In [160]: l = torch.mean((y - t)**2)

In [161]: l
Out[161]: tensor(0.9990, grad_fn=<MeanBackward0>)

In [162]: l.backward()

In [163]: W1.grad
Out[163]: 
tensor([[ 0.0008,  0.0010, -0.0024, -0.0006,  0.0006, -0.0003,  0.0010, -0.0014,
          0.0054, -0.0044],
        [-0.0015, -0.0021,  0.0047,  0.0013, -0.0013,  0.0007, -0.0020,  0.0028,
         -0.0109,  0.0089],
        [ 0.0023,  0.0031, -0.0071, -0.0019,  0.0019, -0.0010,  0.0030, -0.0042,
          0.0163, -0.0133]])

In [164]: W2.grad
Out[164]: 
tensor([[-3.8166e-05,  6.0304e-04,  1.0604e-05, -4.4769e-06,  6.6196e-05,
          4.3861e-04,  4.5371e-04, -3.2250e-04, -8.1333e-04, -9.1162e-04],
        [ 8.3692e-05, -1.3224e-03, -2.3252e-05,  9.8171e-06, -1.4516e-04,
         -9.6181e-04, -9.9492e-04,  7.0720e-04,  1.7835e-03,  1.9991e-03],
        [ 7.9589e-05, -1.2575e-03, -2.2112e-05,  9.3358e-06, -1.3804e-04,
         -9.1466e-04, -9.4615e-04,  6.7253e-04,  1.6961e-03,  1.9011e-03],
        [-3.5049e-05,  5.5379e-04,  9.7376e-06, -4.1112e-06,  6.0790e-05,
          4.0279e-04,  4.1665e-04, -2.9616e-04, -7.4690e-04, -8.3717e-04],
        [ 1.1731e-04, -1.8536e-03, -3.2593e-05,  1.3761e-05, -2.0347e-04,
         -1.3482e-03, -1.3946e-03,  9.9131e-04,  2.5000e-03,  2.8022e-03],
        [-4.5904e-05,  7.2530e-04,  1.2753e-05, -5.3846e-06,  7.9618e-05,
          5.2754e-04,  5.4570e-04, -3.8789e-04, -9.7823e-04, -1.0965e-03],
        [ 8.8772e-05, -1.4026e-03, -2.4664e-05,  1.0413e-05, -1.5397e-04,
         -1.0202e-03, -1.0553e-03,  7.5013e-04,  1.8918e-03,  2.1204e-03],
        [ 5.1163e-05, -8.0839e-04, -1.4215e-05,  6.0014e-06, -8.8739e-05,
         -5.8798e-04, -6.0822e-04,  4.3233e-04,  1.0903e-03,  1.2221e-03],
        [ 1.8080e-04, -2.8568e-03, -5.0233e-05,  2.1209e-05, -3.1360e-04,
         -2.0779e-03, -2.1494e-03,  1.5278e-03,  3.8530e-03,  4.3187e-03],
        [ 1.5403e-04, -2.4338e-03, -4.2795e-05,  1.8068e-05, -2.6716e-04,
         -1.7702e-03, -1.8311e-03,  1.3016e-03,  3.2825e-03,  3.6792e-03]])

In [165]: W3.grad
Out[165]: 
tensor([[ 0.0010],
        [-0.0019],
        [ 0.0010],
        [ 0.0010],
        [ 0.0009],
        [-0.0029],
        [-0.0017],
        [-0.0077],
        [-0.0005],
        [-0.0069]])

results matching ""

    No results matching ""