
上QQ阅读APP看书,第一时间看更新
1.2.3 计算损失
损失函数需要一对输入:模型输出和目标,用来评估输出距离目标有多远。损失用loss来表示,损失函数的作用就是计算神经网络每次迭代的前向计算结果与真实值之间的差距,从而指导模型下一步训练往正确的方向进行。常见的损失函数有交叉熵损失函数和均方误差损失函数。
在PyTorch中,nn库模块提供了多种损失函数,常用的有以下几种:处理回归问题的nn.MSELoss函数,处理二分类的nn.BCELoss函数,处理多分类的nn.CrossEntropyLoss函数,由于本次MNIST数据集是10个分类,因此选择nn.CrossEntropyLoss函数,代码如下:

输出为tensor(2.2725,grad_fn=<NllLossBackward>),表明当前两个样本通过网络输出后与实际差距仍有2.2725,我们的训练目标就是最小化loss值。