BatchNorm|LayerNorm|RMSNorm

最近在复习之前学过的Transformer的底层原理,又一次碰到了LayerNorm。之前在做一些深度估计项目的时候,使用ViT架构,也用到了LayerNorm。但仅仅将它作为了一个黑箱,并没有深刻的理解它的原理。而且与LayerNorm对应的BatchNorm,虽然之前阅读过原论文,但仍处于一知半解的阶段,大部分也忘记了。借助这个blog,重新学习和回忆一下。并对比两者的不同,并动手实现相应的模块。

update 2025-01-15:加入RMSNorm

Batch Normalization

训练一个神经网络其实是比较困难的,因为神经网络的学习过程是反向传播(Back-propagation):通过梯度下降(Gradient Descent)的方式一步步反向更新每一个神经元的权重。神经网络一般由多个层构成,这样就造成了一个问题:

假设我们有两层神经网络,A和B,其中A的输入是x,输出是y,B以y为输入并输出结果z。我们通过计算预测结果z和真实标签之间的损失来进行反向传播,进而更新参数。更新之后,A和B的参数分别变为A'和B'。当我们再次将x输入网络时,A'的输出将不再是之前的y,可能变成了另外一个值。

这就导致了一个问题:B原本是基于旧的y输入学习的,现在输入变了,他之前学到的规律不再适用,需要重新适应新的数据分布。这种现象被称为Internal covariate shift

Batch Normalization提出的动机就是缓解这种偏移,提高训练的稳定性。

但Santurkar等人的工作How Does Batch Normalization Help Optimization中指出,使得Batch Normalization成功的原因并非因为Internal covariate shift,甚至在某种情形下BatchNorm没有减少internal covariate shift。这篇blog不讨论谁的观点绝对正确,仅仅对BatchNorm的思想做总结。

Why Batch Normalization?

一定程度消除Internal covariate shift:在论文Introduction部分讲述了什么是internal covariate shift(ICS)。ICS导致每一层在更新参数后重新适应新的输入的变化,降低了学习的效率。采用BatchNorm,将输入归一化为均值为0和方差为1的数据,不仅仅可以一定程度解决梯度消失问题(不归一化数值会进入saturated zone,就是非线性激活函数的两端),还能使不同层之间近乎于独立学习,降低了层与层之间的耦合性。

BatchNorm的平滑效应:BatchNorm将优化问题的landscape从“很瘦长的椭圆”转换为更平滑和对称的正圆(提升了损失函数的Lipschitzness和其梯度的Lipschitzness,关于Lipschitzness还在补充)。确保了问题的平滑性,就减少了初始值和学习率对神经网络的影响。我们可以采用更大的学习率加速网络学习,不必太过担心网络进入局部极小值。

Mathematical Description

假设对一个layer的输入是,那么对这个输入的归一化如下: 其中

表示mini-batch的batch size,表示输入的第几个特征。

但如果就这样硬生生将每一层的activation后的输出归一化的话,会让模型丧失原有的表达能力。以下是ChatGPT的回答:

The linear transformation process in Batch Normalization, which involves the learnable parameters (scale factor and shift ), is crucial for the following reasons:

  1. Restoring the Representational Power of the Network: After applying Batch Normalization, the activations of the layer are normalized to have zero mean and unit variance. While this normalization process helps to stabilize learning and reduce internal covariate shift, it can also limit what the layer can represent. For instance, in some cases, the network might learn that the best representation of the data for the subsequent layers is not zero-mean/unit-variance. The scale and shift transformation allows the network to learn the most suitable scale and location of the activations, thereby restoring the representational power of the network.

  2. Preserving the Expressive Power of Activation Functions: Certain activation functions like ReLU and its variants have different behaviors in different regions of the input space. For instance, the ReLU function is sensitive to positive inputs and insensitive to negative inputs. If Batch Normalization is used without the scale and shift, the activations would be mostly confined to the region where ReLU is active, thereby limiting the expressive power of the activation function. The scale and shift transformation allows the network to learn to use the full expressive power of the activation function.

  3. Flexibility: The learnable parameters and provide the network with the flexibility to learn the optimal scale and mean of the activations. If the optimal scale and mean are indeed 1 and 0 respectively, the network can learn γ close to 1 and β close to 0. But if they are not, the network has the flexibility to learn other values.

In summary, the linear transformation process in Batch Normalization, governed by the learnable parameters and , is crucial for preserving the expressive power of the network and providing it with the flexibility to learn the most suitable representations.

因此在归一化之后,需要加入两个可学习的参数,给予模型自由学习的能力。让其在学习的过程中自己寻找最适合的分布状态。其数学描述为: 在训练阶段,BatchNorm按照上述的数学描述进行学习。

但在验证和测试阶段,BatchNorm则有所不同。因为在训练阶段,我们是以mini-batch的形式将数据喂入模型的。训练过程中,会实时计算mini-batch的均值和方差。而如果在测试阶段也这样做,尤其是每次输入一个测试数据的时候,mini-batch的size是1,最后会得0。这显然是不合理的。因此在工程上,采用running mean和running variance的方法,会在训练阶段实时更新。

更新过程为:

Layer Normalization

加速神经网络训练的方法可以用之前提到的batch normalization。但是有两点需要注意:

  • Batch normalization比较依赖于mini-batch的大小,mini-batch越大,batch norm的效果越好
  • Batch normalization似乎对于RNN这种处理sequence数据的模型适用难度较高

基于这两点因素,作者提出了Layer Normalization。

Why Layer Normalization?

  • 因为Batch normalization使用的方差和均值是基于mini-batch对整体的估计,这说明其受限于mini-batch的大小。
  • Batch normalization在序列模型中不太适用,因为序列模型的输入经常是变长的。

Mathematical Description

假设对一个layer的输入是,那么对这个输入的归一化如下: 其中

为了保留模型的表达能力,还是加入两个可学习的参数做一个线性变换。 Layer Normalization的训练和测试阶段的行为一直,因此不需要额外加入其他变量来做记录。仅需要在推理的时候,计算输入数据的均值和方差就可以。

Root Mean Square Layer Normalization

RMSNorm是另一种归一化方法,其核心思想就是通过对输入的向量进行归一化缩放,提升训练的稳定性和效率。

Why RMS Normalization

  • 简化计算、提高效率:相比于Layer Normalization,RMSNorm节省了计算均值的步骤,仅仅计算平方均值的根(RMS)
  • 在Transformer等对计算效率要求高的场景中,可以显著加速训练
  • 通过实验表明,在语言建模任务重,RMSNorm的梯度方差比LN更低,训练曲线更平滑

Mathematical Description

假设某一层的输入为向量 ,RMSNorm 的归一化过程如下: 计算均方根(RMS): 其中, 是一个小的常数,用于防止除以零。 归一化输入: 应用可学习的缩放参数: 这里, 是可学习的参数,用于恢复模型的表达能力。

Reference

  • Batch Normalizaiton: Accelerating Deep Network Training by Reducing Internal Covariate Shift
  • How Does Batch Normalization Help Optimization
  • https://www.cnblogs.com/smartljy/p/18747158