![Python元学习:通用人工智能的实现](https://wfqqreader-1252317822.image.myqcloud.com/cover/899/32670899/b_32670899.jpg)
1.4 少样本学习的优化模型
我们知道,少样本学习基于较少的数据点,那么如何将梯度下降应用到少样本学习中呢?在少样本学习中,梯度下降会由于数据点非常少而突然失效。梯度下降优化需要更多的数据点来达到收敛和损失最小化。因此,在少样本学习中需要一种更好的优化技术。假设有一个由参数θ影响的模型f。我们用一些随机值来初始化参数θ,并尝试使用梯度下降法找到最优值。让我们回忆一下梯度下降的更新方程:
![](https://epubservercos.yuewen.com/8713C6/17594707006801106/epubprivate/OEBPS/Images/figure_0018_0002.jpg?sign=1738889923-PlIqtDHNCZBpmpe0yqpmMtUOGLbkUv6z-0-2bc2c492f8b46beadb846777cea82794)
以上方程的参数解释如下:
❑ θt是更新参数;
❑ θt-1是上一步的参数值;
❑ αt是学习率;
❑ 是相对于θt-1的损失函数的梯度。
梯度下降的更新方程是不是看起来很熟悉?是的,你猜对了,它类似于长短期记忆网络(LSTM)的细胞状态更新方程,可以写成:
![](https://epubservercos.yuewen.com/8713C6/17594707006801106/epubprivate/OEBPS/Images/figure_0019_0001.jpg?sign=1738889923-T2VXHIoAL0nOh99la7k4vpkKUICrSnjC-0-7d1b16442bd9688a3aec3d2c512c7b2b)
可以将LSTM细胞更新方程与梯度下降完全对应起来,设ft =1,可得:
![](https://epubservercos.yuewen.com/8713C6/17594707006801106/epubprivate/OEBPS/Images/figure_0019_0002.jpg?sign=1738889923-rcYQtFuQ6ihcNurbVju8dRZaBn0jVJir-0-3454c3e33067ab8601044b51a04d3a11)
因此,在少样本学习中,可以使用LSTM而非梯度下降作为优化器。LSTM是元学习器,它将学习用于训练模型的更新规则。因此,我们使用两个网络:一个是基学习器,它学会执行任务;另一个是元学习器,它试图找到最优的参数。这是如何实现的呢?
我们知道,LSTM使用遗忘门(forget gate)来丢弃存储器中不需要的信息,它可以表示为
ft=σ(wf⋅[ht-1,xt]+bf)
这个遗忘门在我们的优化场景中有什么用呢?假设我们处在一个损失很大,梯度接近于零的位置。怎样才能摆脱这种局面呢?在这种情况下,可以收缩模型的参数,并忘记其前一个值的某些部分。我们可以使用遗忘门来实现这一点,它以当前参数值θt-1、当前损失Lt、当前梯度以及前一个遗忘门作为输入。它可以表示为
![](https://epubservercos.yuewen.com/8713C6/17594707006801106/epubprivate/OEBPS/Images/figure_0019_0004.jpg?sign=1738889923-ycAguclaCBr8DGsOZWUGUhxsmQsnvDEb-0-6c1921291fa88663e52dc5f24e8ff20f)
下面来看看输入门(input gate)。我们知道LSTM中的输入门是用来决定更新什么值的,它可以表示为
it=σ(wi⋅[ht-1, xt]+bi)
在少样本学习中,可以使用这个输入门来调整学习率,从而在防止发散的同时快速学习:
![](https://epubservercos.yuewen.com/8713C6/17594707006801106/epubprivate/OEBPS/Images/figure_0019_0005.jpg?sign=1738889923-ABRAYs2UW97dFXFbsPp9I1Jqk3vf489Q-0-2a002b3e609d323d8fe3e2cce7bed170)
因此,元学习器在多次更新之后得到了it与ft的最优值。
可是,这是如何运作的呢?
假设有一个由θ影响的基网络M、由ϕ影响的LSTM元学习器R,以及数据集D。我们将数据集分割为训练集Dtrain和测试集Dtest。首先随机初始化元学习器参数ϕ。
在T次迭代中,随机从Dtrain中抽取数据点,计算损失以及相对于模型参数θ的损失梯度。将这个梯度、损失和元学习器参数ϕ提供给元学习器。元学习器R会返回细胞状态ct,然后在时间t将基网络M的参数θt更新为ct。重复N次,如图1-3所示。
![](https://epubservercos.yuewen.com/8713C6/17594707006801106/epubprivate/OEBPS/Images/figure_0020_0001.jpg?sign=1738889923-6vWDURJKuCBRUnLiKmNpOfrDm9e8KMX4-0-b43d5e0dd757c6b4f2aa719e401ef703)
图1-3
因此,经过T次迭代,我们会得到一个最优参数θT。不过如何检查θT的性能并更新元学习器参数呢?使用测试集和参数θT计算测试集的损失。然后,计算相对于元学习器参数ϕ的损失梯度,并更新ϕ,如图1-4所示。
![](https://epubservercos.yuewen.com/8713C6/17594707006801106/epubprivate/OEBPS/Images/figure_0020_0002.jpg?sign=1738889923-CbEmJHqhBquMk4MAYOCMik8ixck0iKK9-0-6d00674b8d72d528e0fb2beb1c33a186)
图1-4
迭代n次,并更新元学习器。完整的算法如图1-5所示。
![](https://epubservercos.yuewen.com/8713C6/17594707006801106/epubprivate/OEBPS/Images/figure_0020_0003.jpg?sign=1738889923-c6bHMkCi6GElR8zWOF8L6jXyaqdyCKLC-0-ffcd447927008028558d697f6ad05986)
图1-5