Linearity of LayerNormalization

LayerNormalization is defined by the following equation: \[ \begin{equation} \text{LayerNorm}(\bm{x}) = \frac{\bm{x} - \mu(\bm{x})\bm{1}}{s(\bm{x})}\odot \bm{\gamma} + \bm{\beta} \label{eq:lm_f} \end{equation} \] Here, \(\mu(\bm{x})\) and \(s(\bm{x})\) are the mean and the standard deviation, respectively, of the elements of \(\bm{x} \in \mathbb{R}^{d}\). \(\bm{1}\) is a vector of ones with the same shape as \(\bm{x}\), \(\bm{\gamma}, \bm{\beta} \in \mathbb{R}^{d}\) are learnable parameters, and \(\odot\) denotes element-wise multiplication.
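The definition can be sketched in NumPy as follows. This is a minimal illustration, not a reference implementation. The function name `layer_norm` is hypothetical. It assumes that \(s(\bm{x})\) is the population standard deviation (dividing by \(d\), NumPy's default `ddof=0`), and it adds no epsilon to the denominator. Practical implementations such as PyTorch's `nn.LayerNorm` do add a small `eps` for numerical stability.

```python
import numpy as np

def layer_norm(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray) -> np.ndarray:
    """LayerNorm of a single vector x in R^d, following the equation above.

    Assumptions (not from the original text): s(x) is the population
    standard deviation (ddof=0) and no epsilon is added to the denominator.
    """
    mu = x.mean()   # mu(x): mean of the d elements of x
    s = x.std()     # s(x): standard deviation of the elements (ddof=0)
    # (x - mu(x) 1) / s(x), then element-wise multiply by gamma and add beta
    return (x - mu * np.ones_like(x)) / s * gamma + beta

rng = np.random.default_rng(0)
d = 8
x = rng.normal(size=d)
gamma = rng.normal(size=d)
beta = rng.normal(size=d)
y = layer_norm(x, gamma, beta)
```

With \(\bm{\gamma} = \bm{1}\) and \(\bm{\beta} = \bm{0}\), the output has mean 0 and standard deviation 1.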

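Now, treating \(\bm{x}\) and \(\bm{1}\) as \(1 \times d\) row vectors, \[ \begin{align} \mu(\bm{x})\bm{1} &= \left(\frac{1}{d}\sum_k\bm{x}^{(k)}\right)\bm{1}\\ &=\left(\frac{1}{d}\bm{x}\cdot\bm{1}\right)\bm{1}\\ &=\left(\frac{1}{d}\bm{x}\bm{1}^\top\right)\bm{1}\\ &=\bm{x}\left(\frac{1}{d}\bm{1}^\top\bm{1}\right) \end{align} \] where \(\bm{x}^{(k)}\) is the \(k\)-th element of \(\bm{x}\). The last step uses the associativity of matrix multiplication. Note that \(\bm{1}^\top\bm{1}\) is the \(d \times d\) matrix of ones.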
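Each line of the derivation can be checked numerically. The sketch below is a minimal NumPy check and assumes the same row-vector convention, storing \(\bm{x}\) and \(\bm{1}\) as `(1, d)` arrays. The variables `step1` to `step4` mirror the four lines of the derivation in order.

```python
import numpy as np

rng = np.random.default_rng(1)
d = 5
x = rng.normal(size=(1, d))   # x as a 1 x d row vector
one = np.ones((1, d))         # the all-ones row vector 1

step1 = x.mean() * one                           # (1/d sum_k x^(k)) 1
step2 = (np.dot(x.ravel(), one.ravel()) / d) * one  # (1/d x . 1) 1
step3 = (x @ one.T / d) @ one                    # (1/d x 1^T) 1, where x 1^T is 1 x 1
step4 = x @ (one.T @ one / d)                    # x (1/d 1^T 1), where 1^T 1 is d x d

for step in (step2, step3, step4):
    assert np.allclose(step1, step)
```

The shapes make the convention explicit. `one.T @ one` is a `(d, d)` matrix of ones, so right-multiplying the row vector `x` by it (divided by `d`) puts the mean in every position.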

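Also, element-wise multiplication by \(\bm{\gamma}\) can be expressed as right-multiplication by the diagonal matrix \(\text{diag}(\bm{\gamma})\). Therefore, with \(I\) the \(d \times d\) identity matrix, LayerNorm can be rewritten as: \[ \begin{align} \text{LayerNorm}(\bm{x}) &= \frac{1}{s(\bm{x})}\left(\bm{x} - \mu(\bm{x})\bm{1}\right)\odot \bm{\gamma} + \bm{\beta}\\ &= \frac{\bm{x}}{s(\bm{x})}\left(I - \frac{1}{d}\bm{1}^\top\bm{1}\right)\text{diag}(\bm{\gamma}) + \bm{\beta} \end{align} \] Apart from the division by \(s(\bm{x})\), this is an affine map: right-multiplication by the fixed matrix \(\left(I - \frac{1}{d}\bm{1}^\top\bm{1}\right)\text{diag}(\bm{\gamma})\), followed by adding \(\bm{\beta}\). The only non-linear operation in LayerNorm is therefore the division by \(s(\bm{x})\).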
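The following NumPy sketch checks three things numerically. First, the matrix form agrees with the definition. Second, the part without \(s(\bm{x})\) and \(\bm{\beta}\) is linear. Third, the full map is not linear, because scaling \(\bm{x}\) by \(c > 0\) leaves the output unchanged. All function names are hypothetical. As before, it assumes the row-vector convention and a population standard deviation with no epsilon.

```python
import numpy as np

def layer_norm_direct(x, gamma, beta):
    # (x - mu(x) 1) / s(x) * gamma + beta, with population std (ddof=0)
    return (x - x.mean()) / x.std() * gamma + beta

def centering_matrix(d):
    one = np.ones((1, d))
    return np.eye(d) - one.T @ one / d   # I - (1/d) 1^T 1

def layer_norm_matrix(x, gamma, beta):
    # x / s(x) (I - (1/d) 1^T 1) diag(gamma) + beta, for a 1 x d row vector x
    d = x.shape[-1]
    return (x / x.std()) @ centering_matrix(d) @ np.diag(gamma) + beta

def linear_part(x, gamma):
    # Everything except the division by s(x) and the bias beta
    return x @ centering_matrix(x.shape[-1]) @ np.diag(gamma)

rng = np.random.default_rng(2)
d = 6
x = rng.normal(size=(1, d))
y = rng.normal(size=(1, d))
gamma = rng.normal(size=d)
beta = rng.normal(size=d)

# The matrix form agrees with the definition.
assert np.allclose(layer_norm_direct(x, gamma, beta), layer_norm_matrix(x, gamma, beta))

# Without s(x), the map is linear: f(a x + b y) = a f(x) + b f(y).
a, b = 2.0, -3.0
assert np.allclose(linear_part(a * x + b * y, gamma),
                   a * linear_part(x, gamma) + b * linear_part(y, gamma))

# With s(x), it is not: since s(c x) = c s(x) for c > 0, the scale cancels out.
assert np.allclose(layer_norm_matrix(3.0 * x, gamma, beta), layer_norm_matrix(x, gamma, beta))
```

The last check shows concretely how the division by \(s(\bm{x})\) breaks linearity. A linear map would triple its output when \(\bm{x}\) is tripled, but LayerNorm returns the same output.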

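By the way, the \(d \times d\) matrix \(I - \frac{1}{d}\bm{1}^\top\bm{1}\) is called the centering matrix, because right-multiplying a row vector by it subtracts the vector's mean from every element.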
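Several standard properties of the centering matrix \(C = I - \frac{1}{d}\bm{1}^\top\bm{1}\) can be confirmed numerically. It is symmetric. It is idempotent, so centering twice is the same as centering once. It maps the constant row vector \(\bm{1}\) to zero. The NumPy snippet below is a small illustrative check under the same row-vector convention, using an arbitrary example vector.

```python
import numpy as np

d = 4
one = np.ones((1, d))
C = np.eye(d) - one.T @ one / d   # centering matrix I - (1/d) 1^T 1

assert np.allclose(C, C.T)       # symmetric
assert np.allclose(C @ C, C)     # idempotent: centering twice = centering once
assert np.allclose(one @ C, 0)   # constant row vectors are mapped to zero

x = np.arange(d, dtype=float).reshape(1, d)  # example row vector [0, 1, 2, 3]
centered = x @ C
assert np.allclose(centered, x - x.mean())   # each element minus the mean
```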