import torch
from torch import nn
from d2l import torch as d2l
def corr2d(X, K):
    """Compute the 2-D cross-correlation of ``X`` with kernel ``K``.

    The kernel is slid over ``X`` with stride 1 and no padding; each output
    element is the sum of the element-wise product of ``K`` with the window
    of ``X`` it currently covers.

    Args:
        X: Input tensor; only its first two dimensions are used to size the
            output, so a 2-D tensor of shape ``(n_h, n_w)`` is expected.
        K: 2-D kernel tensor of shape ``(h, w)``.

    Returns:
        Tensor of shape ``(n_h - h + 1, n_w - w + 1)`` in the default floating
        dtype, allocated on ``X``'s device. It is filled by element-wise
        assignment, so autograd still propagates gradients back to ``X`` and
        ``K`` when they require grad.
    """
    h, w = K.shape
    # Allocate on X's device: otherwise a CUDA input yields a CPU result and
    # every Y[i, j] assignment below performs a device-to-host copy.
    Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1), device=X.device)
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    return Y
# Reshape X and Y into 4-D tensors with two leading singleton dimensions
# (X -> (1, 1, 6, 8), Y -> (1, 1, 6, 7)).
# NOTE(review): presumably (batch, channel, height, width) as expected by the
# `conv2d` layer defined elsewhere; the 6x7 target is consistent with a (1, 2)
# kernel over a 6x8 input — confirm against where X, Y and conv2d are built.
X = X.reshape((1, 1, 6, 8))
Y = Y.reshape((1, 1, 6, 7))
# Step size for the hand-written gradient-descent update below.
lr = 3e-2

# Fit the kernel of `conv2d` so that conv2d(X) reproduces Y, printing the loss
# every second epoch.
# NOTE(review): `conv2d`, `X` and `Y` are defined outside this view; presumably
# conv2d is an nn.Conv2d whose only trainable parameter is `weight` — confirm.
for i in range(10):
    Y_hat = conv2d(X)
    # Squared error summed to a scalar: used both for backward() and reporting.
    loss = ((Y_hat - Y) ** 2).sum()
    conv2d.zero_grad()
    loss.backward()
    # Gradient-descent step on the kernel. Done under no_grad() so autograd does
    # not record the in-place update of a leaf parameter (preferred over
    # mutating `.data`, which bypasses autograd's safety checks).
    with torch.no_grad():
        conv2d.weight -= lr * conv2d.weight.grad
    if (i + 1) % 2 == 0:
        print(f'epoch {i + 1}, loss {loss:.3f}')