Tensorの自動微分
PyTorchなどのディープラーニングライブラリの特徴は、自動微分を備えている事です。自動微分を行うTensorをつくるためには、Tensorの初期化時に requires_grad=True
を指定します。
In [1]: import torch
In [2]: torch.ones(2, 2, requires_grad=True)
Out[2]:
tensor([[1., 1.],
[1., 1.]], requires_grad=True)
実際に、((x+1)**2).sum()
の微分を計算してみましょう。まず、 y
に ((x+1)**2).sum()
を代入します。PyTorchの自動微分では、一度 sum()
や.mean()
をとってスカラー値にしないと自動微分を実行できません。
In [3]: x = torch.ones(2, 2, requires_grad=True)
In [4]: y = ((x+1)**2).sum()
In [5]: y
Out[5]:
tensor([[4., 4.],
[4., 4.]], grad_fn=<PowBackward0>)
y
を表示すると、grad_fn=<PowBackward0>
のように勾配を計算する関数が書かれています。具体的にどの関数で微分を計算するのかは利用者はあまり気にする必要はありません。微分を計算するには y.backward()
とします。すると、x.grad
に勾配が記録されます。
In [16]: y.backward()
In [17]: x.grad
Out[17]:
tensor([[5., 5.],
[5., 5.]])
このように出力されました。 手計算と照合してみましょう。
正しく計算されています。他は同様ですから省略します。
PyTorchでは、用意されている様々な関数の微分を計算できます。次のl
のような複雑な関数の微分を求める事もできます。
In [153]: W1 = torch.randn([3, 10], requires_grad=True)
In [154]: W2 = torch.randn([10, 10], requires_grad=True)
In [155]: W3 = torch.randn([10, 1], requires_grad=True)
In [156]: x = torch.Tensor([1., -2., 3.])
In [158]: y = F.celu(torch.matmul(W3.T, F.celu(torch.matmul(W2.T, F.celu(torch.matmul(W1.T, x))))))
In [159]: t = torch.zeros([1])
In [160]: l = torch.mean((y - t)**2)
In [161]: l
Out[161]: tensor(0.9990, grad_fn=<MeanBackward0>)
In [162]: l.backward()
In [163]: W1.grad
Out[163]:
tensor([[ 0.0008, 0.0010, -0.0024, -0.0006, 0.0006, -0.0003, 0.0010, -0.0014,
0.0054, -0.0044],
[-0.0015, -0.0021, 0.0047, 0.0013, -0.0013, 0.0007, -0.0020, 0.0028,
-0.0109, 0.0089],
[ 0.0023, 0.0031, -0.0071, -0.0019, 0.0019, -0.0010, 0.0030, -0.0042,
0.0163, -0.0133]])
In [164]: W2.grad
Out[164]:
tensor([[-3.8166e-05, 6.0304e-04, 1.0604e-05, -4.4769e-06, 6.6196e-05,
4.3861e-04, 4.5371e-04, -3.2250e-04, -8.1333e-04, -9.1162e-04],
[ 8.3692e-05, -1.3224e-03, -2.3252e-05, 9.8171e-06, -1.4516e-04,
-9.6181e-04, -9.9492e-04, 7.0720e-04, 1.7835e-03, 1.9991e-03],
[ 7.9589e-05, -1.2575e-03, -2.2112e-05, 9.3358e-06, -1.3804e-04,
-9.1466e-04, -9.4615e-04, 6.7253e-04, 1.6961e-03, 1.9011e-03],
[-3.5049e-05, 5.5379e-04, 9.7376e-06, -4.1112e-06, 6.0790e-05,
4.0279e-04, 4.1665e-04, -2.9616e-04, -7.4690e-04, -8.3717e-04],
[ 1.1731e-04, -1.8536e-03, -3.2593e-05, 1.3761e-05, -2.0347e-04,
-1.3482e-03, -1.3946e-03, 9.9131e-04, 2.5000e-03, 2.8022e-03],
[-4.5904e-05, 7.2530e-04, 1.2753e-05, -5.3846e-06, 7.9618e-05,
5.2754e-04, 5.4570e-04, -3.8789e-04, -9.7823e-04, -1.0965e-03],
[ 8.8772e-05, -1.4026e-03, -2.4664e-05, 1.0413e-05, -1.5397e-04,
-1.0202e-03, -1.0553e-03, 7.5013e-04, 1.8918e-03, 2.1204e-03],
[ 5.1163e-05, -8.0839e-04, -1.4215e-05, 6.0014e-06, -8.8739e-05,
-5.8798e-04, -6.0822e-04, 4.3233e-04, 1.0903e-03, 1.2221e-03],
[ 1.8080e-04, -2.8568e-03, -5.0233e-05, 2.1209e-05, -3.1360e-04,
-2.0779e-03, -2.1494e-03, 1.5278e-03, 3.8530e-03, 4.3187e-03],
[ 1.5403e-04, -2.4338e-03, -4.2795e-05, 1.8068e-05, -2.6716e-04,
-1.7702e-03, -1.8311e-03, 1.3016e-03, 3.2825e-03, 3.6792e-03]])
In [165]: W3.grad
Out[165]:
tensor([[ 0.0010],
[-0.0019],
[ 0.0010],
[ 0.0010],
[ 0.0009],
[-0.0029],
[-0.0017],
[-0.0077],
[-0.0005],
[-0.0069]])