PyTorch基本用法(九)——优化器 | | PyTorch基本用法(九)——优化器 文章作者:Tyan博客:noahsnail.com | CSDN | 简书 本文主要是关于PyTorch的一些用法。 1234567891011121314151617181920import torchimport matplotlib.pyplot as pltimport torch.nn.functional as Fimport torch.utils.data as Datafrom torch.autograd import Variable# 定义超参数LR = 0.01BATCH_SIZE = 32EPOCH = 10# 生成数据x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim = 1)y = x.pow(2) + 0.1 * torch.normal(torch.zeros(x.size()))# 绘制数据图像plt.scatter(x.numpy(), y.numpy())plt.show() 123456789101112131415161718# 定义数据库dataset = Data.TensorDataset(data_tensor = x, target_tensor = y)# 定义数据加载器loader = Data.DataLoader(dataset = dataset, batch_size = BATCH_SIZE, shuffle = True, num_workers = 2)# 定义pytorch网络class Net(torch.nn.Module): def __init__(self, n_features, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_features, n_hidden) self.predict = torch.nn.Linear(n_hidden, n_output) def forward(self, x): x = F.relu(self.hidden(x)) y = self.predict(x) return y 123456789101112131415161718192021222324252627282930313233343536373839404142434445# 定义不同的优化器网络net_SGD = Net(1, 10, 1)net_Momentum = Net(1, 10, 1)net_RMSprop = Net(1, 10, 1)net_Adam = Net(1, 10, 1)# 选择不同的优化方法opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr = LR)opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr = LR, momentum = 0.9)opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr = LR, alpha = 0.9)opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr = LR, betas= (0.9, 0.99))nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]# 选择损失函数loss_func = torch.nn.MSELoss()# 不同方法的lossloss_SGD = []loss_Momentum = []loss_RMSprop =[]loss_Adam = []# 保存所有losslosses = [loss_SGD, loss_Momentum, loss_RMSprop, loss_Adam]# 执行训练for epoch in xrange(EPOCH): for step, (batch_x, batch_y) in enumerate(loader): var_x = Variable(batch_x) var_y = Variable(batch_y) for net, optimizer, loss_history in zip(nets, optimizers, losses): # 对x进行预测 prediction = net(var_x) # 计算损失 loss = loss_func(prediction, var_y) # 每次迭代清空上一次的梯度 optimizer.zero_grad() # 反向传播 loss.backward() # 更新梯度 optimizer.step() # 保存loss记录 loss_history.append(loss.data[0]) 123456789# 画图labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']for i, loss_history in enumerate(losses): plt.plot(loss_history, label = labels[i])plt.legend(loc = 'best')plt.xlabel('Steps')plt.ylabel('Loss')plt.ylim((0, 0.2))plt.show() 参考资料 https://www.youtube.com/user/MorvanZhou 如果有收获,可以请我喝杯咖啡! 赏 微信打赏 支付宝打赏