Java 中的梯度下降

今天的大多数人工智能都是使用某种形式的神经网络实现的。在我的前两篇文章中,我介绍了神经网络并向您展示了如何在 Java 中构建神经网络。神经网络的力量主要来自于它的深度学习能力,而这种能力建立在梯度下降反向传播的概念和执行之上。我将通过快速深入了解 Java 中的反向传播和梯度下降来结束这个简短的系列文章。
有人说人工智能并不是那么智能,它主要是反向传播。那么,现代机器学习的基石是什么?
要了解反向传播,您必须首先了解神经网络的工作原理。基本上,神经网络是称为神经元的节点的有向图。神经元有一个特定的结构,它接受输入,将它们与权重相乘,添加一个偏差值,并通过激活函数运行所有这些。神经元将它们的输出馈送到其他神经元,直到到达输出神经元。输出神经元产生网络的输出。 (有关更完整的介绍,请参阅机器学习风格:神经网络简介。)
从这里开始,我假设您了解网络及其神经元的结构,包括前馈。示例和讨论将集中在梯度下降的反向传播上。我们的神经网络将有一个输出节点、两个“隐藏”节点和两个输入节点。使用一个相对简单的例子可以更容易地理解算法所涉及的数学。图 1 显示了示例神经网络的图表。
图 1. 我们将用于示例的神经网络图。
梯度下降反向传播的思想是将整个网络视为一个多元函数,为损失函数提供输入。损失函数通过将网络输出与已知的良好结果进行比较来计算一个表示网络执行情况的数字。与良好结果配对的输入数据集称为训练集。损失函数旨在随着网络行为远离正确而增加数值。
梯度下降算法采用损失函数并使用偏导数来确定网络中每个变量(权重和偏差)对损失值的贡献。然后它向后移动,访问每个变量并调整它以减少损失值。
理解梯度下降涉及微积分中的一些概念。首先是导数的概念。 MathsIsFun.com 对导数有很好的介绍。简而言之,导数为您提供函数在单个点处的斜率(或变化率)。换句话说,函数的导数为我们提供了给定输入的变化率。 (微积分的美妙之处在于它可以让我们在没有其他参考点的情况下找到变化——或者更确切地说,它可以让我们假设输入的变化非常小。)
下一个重要的概念是偏导数。偏导数让我们采用多维(也称为多变量)函数并仅隔离其中一个变量以找到给定维度的斜率。
导数回答了以下问题:函数在特定点的变化率(或斜率)是多少?偏导数回答了以下问题:给定方程的多个输入变量,仅这一个变量的变化率是多少?
梯度下降使用这些思想来访问方程中的每个变量并对其进行调整以最小化方程的输出。这正是我们训练网络时想要的。如果我们将损失函数视为绘制在图表上,我们希望以增量方式向函数的最小值移动。也就是说,我们要找到全局最小值。
请注意,增量的大小在机器学习中称为“学习率”。
当我们探索梯度下降反向传播的数学时,我们将紧贴代码。当数学变得过于抽象时,查看代码将有助于让我们脚踏实地。让我们首先查看我们的 Neuron 类,如清单 1 所示。
Neuron 类只有三个 Double 成员:weight1、weight2 和 bias。它也有一些方法。用于前馈的方法是 compute()。它接受两个输入并执行神经元的工作:将每个输入乘以适当的权重,加上偏差,然后通过 sigmoid 函数运行它。
在我们继续之前,让我们重新审视一下 sigmoid 激活的概念,我在神经网络简介中也讨论过它。清单 2 显示了一个基于 Java 的 sigmoid 激活函数。
sigmoid 函数接受输入并将欧拉数 (Math.exp) 提高到负数,加 1 再除以 1。效果是将输出压缩在 0 和 1 之间,越来越大和越来越小的数字逐渐接近极限。
DeepAI.org 对机器学习中的 sigmoid 函数有很好的介绍。
回到清单 1 中的 Neuron 类,除了 compute() 方法之外,我们还有 getSum() 和 getDerivedOutput()。 getSum() 只是进行权重 * 输入 + 偏差计算。请注意,compute() 采用 getSum() 并通过 sigmoid() 运行它。 getDerivedOutput() 方法通过一个不同的函数运行 getSum():sigmoid 函数的导数。
现在看一下清单 3,它显示了 Java 中的 sigmoid 导数函数。我们已经从概念上讨论了衍生品,下面是实际应用。
记住导数告诉我们函数在其图中的单个点的变化是什么,我们可以感受一下这个导数在说什么:告诉我给定输入的 sigmoid 函数的变化率。您可以说它告诉我们清单 1 中的预激活神经元对最终激活结果有何影响。
您可能想知道我们如何知道清单 3 中的 sigmoid 导数函数是正确的。答案是,如果它已经被其他人验证过,并且如果我们知道根据特定规则正确微分的函数是准确的,我们就会知道导数函数是正确的。一旦我们理解了它们在说什么并相信它们是准确的,我们就不必回到第一原理并重新发现这些规则——就像我们接受并应用简化代数方程式的规则一样。
所以,在实践中,我们是按照求导法则来求导数的。如果您查看 sigmoid 函数及其导数,您会发现后者可以通过遵循这些规则得出。出于梯度下降的目的,我们需要了解导数规则,相信它们有效,并了解它们的应用方式。我们将使用它们来找出每个权重和偏差在网络最终损失结果中所扮演的角色。
符号 f prime f'(x) 是“f 对 x 的导数”的一种表达方式。另一个是:
两者是等价的:
您很快就会看到的另一种表示法是偏导数表示法:
这就是说,给我变量 x 的 f 的导数。
最令人好奇的衍生规则是链式法则。它说当一个函数是复合的(函数中的函数,又名高阶函数)时,你可以像这样扩展它:
我们将使用链式法则来解压我们的网络并获得每个权重和偏差的偏导数。

关注公众号“大模型全栈程序员”回复“小程序”获取1000个小程序打包源码。更多免费资源在http://www.gitweixin.com/?p=2627