Attention Backpropagation: Step-by-Step Derivation


Attention Backpropagation | Liyuan’s Log


I recently revisited the FlashAttention[1] and FlashAttention2[2] papers, and it was really fun to derive the backward pass of attention by hand.

In this post, I will walk through the derivation with a concrete example, which I hope makes it easy to follow.

Forward Pass

Attention[3] involves three matrices: $Q$, $K$, and $V$, each of shape [batch_size, num_heads, seq_len, head_dim]. It is calculated as follows:

\[\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{head\_dim}})V\]

Let me use a simple example to illustrate this process. We will ignore the $batch\_size$ and $num\_heads$ dimensions, because the matrix multiplications act only on the $seq\_len$ and $head\_dim$ dimensions and are applied independently for each batch element and head. We will also ignore the scaling factor $\frac{1}{\sqrt{head\_dim}}$ for simplicity. In this example, $seq\_len = head\_dim = 3$.

\[Q = \begin{bmatrix}
q_{11} & q_{12} & q_{13} \\
q_{21} & q_{22} & q_{23} \\
q_{31} & q_{32} & q_{33}
\end{bmatrix}\]

\[K = \begin{bmatrix}
k_{11} & k_{12} & k_{13} \\
k_{21} & k_{22} & k_{23} \\
k_{31} & k_{32} & k_{33}
\end{bmatrix}\]

\[V = \begin{bmatrix}
v_{11} & v_{12} & v_{13} \\
v_{21} & v_{22} & v_{23} \\
v_{31} & v_{32} & v_{33}
\end{bmatrix}\]

Since $K^T$ swaps the row and column indices of $K$, entry $s_{ij}$ is the dot product of row $i$ of $Q$ with row $j$ of $K$. So

\[QK^T = S = \begin{bmatrix}
q_{11}k_{11} + q_{12}k_{12} + q_{13}k_{13} & q_{11}k_{21} + q_{12}k_{22} + q_{13}k_{23} & q_{11}k_{31} + q_{12}k_{32} + q_{13}k_{33} \\
q_{21}k_{11} + q_{22}k_{12} + q_{23}k_{13} & q_{21}k_{21} + q_{22}k_{22} + q_{23}k_{23} & q_{21}k_{31} + q_{22}k_{32} + q_{23}k_{33} \\
q_{31}k_{11} + q_{32}k_{12} + q_{33}k_{13} & q_{31}k_{21} + q_{32}k_{22} + q_{33}k_{23} & q_{31}k_{31} + q_{32}k_{32} + q_{33}k_{33}
\end{bmatrix} = \begin{bmatrix}
s_{11} & s_{12} & s_{13} \\
s_{21} & s_{22} & s_{23} \\
s_{31} & s_{32} & s_{33}
\end{bmatrix}\]

\[P = \text{softmax}(S) = \begin{bmatrix}
\frac{\exp(s_{11})}{\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})} & \frac{\exp(s_{12})}{\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})} & \frac{\exp(s_{13})}{\exp(s_{11}) + \exp(s_{12}) + \exp(s_{13})} \\
\frac{\exp(s_{21})}{\exp(s_{21}) + \exp(s_{22}) + \exp(s_{23})} & \frac{\exp(s_{22})}{\exp(s_{21}) + \exp(s_{22}) + \exp(s_{23})} & \frac{\exp(s_{23})}{\exp(s_{21}) + \exp(s_{22}) + \exp(s_{23})} \\
\frac{\exp(s_{31})}{\exp(s_{31}) + \exp(s_{32}) + \exp(s_{33})} & \frac{\exp(s_{32})}{\exp(s_{31}) + \exp(s_{32}) + \exp(s_{33})} & \frac{\exp(s_{33})}{\exp(s_{31}) + \exp(s_{32}) + \exp(s_{33})}
\end{bmatrix} = \begin{bmatrix}
p_{11} & p_{12} & p_{13} \\
p_{21} & p_{22} & p_{23} \\
p_{31} & p_{32} & p_{33}
\end{bmatrix}\]

\[O = PV = \begin{bmatrix}
p_{11}v_{11} + p_{12}v_{21} + p_{13}v_{31} & p_{11}v_{12} + p_{12}v_{22} + p_{13}v_{32} & p_{11}v_{13} + p_{12}v_{23} + p_{13}v_{33} \\
p_{21}v_{11} + p_{22}v_{21} + p_{23}v_{31} & p_{21}v_{12} + p_{22}v_{22} + p_{23}v_{32} & p_{21}v_{13} + p_{22}v_{23} + p_{23}v_{33} \\
p_{31}v_{11} + p_{32}v_{21} + p_{33}v_{31} & p_{31}v_{12} + p_{32}v_{22} + p_{33}v_{32} & p_{31}v_{13} + p_{32}v_{23} + p_{33}v_{33}
\end{bmatrix} = \begin{bmatrix}
o_{11} & o_{12} & o_{13} \\
o_{21} & o_{22} & o_{23} \\
o_{31} & o_{32} & o_{33}
\end{bmatrix}\]

$O$ is the output of the attention.
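To make the three forward steps concrete, here is a minimal sketch in plain Python with nested lists, so that every entry can be compared against the matrices above. The helper names (`matmul`, `transpose`, `softmax_rows`, `attention_forward`) and the numeric values are my own illustrative assumptions, not something taken from the papers. The `scale` argument defaults to 1.0 to match the simplification in this post; passing `1 / math.sqrt(head_dim)` would restore the usual scaling.

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a[i][m] * b[m][j] for m in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    """Swap rows and columns."""
    return [list(col) for col in zip(*a)]


def softmax_rows(s):
    """Apply softmax independently to each row, as in P = softmax(S)."""
    out = []
    for row in s:
        exps = [math.exp(x) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention_forward(q, k, v, scale=1.0):
    """Return S, P and O for a single (batch, head) slice."""
    s = [[scale * x for x in row] for row in matmul(q, transpose(k))]
    p = softmax_rows(s)
    o = matmul(p, v)
    return s, p, o


# Arbitrary 3x3 inputs (seq_len = 3, head_dim = 3), chosen only for illustration.
Q = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
K = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.5], [-1.0, 2.0, 0.0]]
V = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]

S, P, O = attention_forward(Q, K, V)
```

Each row of `P` sums to 1, so each row of `O` is a weighted average of the rows of `V`. That property is worth keeping in mind when differentiating through the softmax.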

Backward Pass

In the backward pass, the input is the partial derivative of the loss $L$ with respect to $O$.

\[\frac{\partial L}{\partial O} = \begin{bmatrix}
\frac{\partial L}{\partial o_{11}} & \frac{\partial L}{\partial o_{12}} & \frac{\partial L}{\partial o_{13}} \\
\frac{\partial L}{\partial o_{21}} & \frac{\partial L}{\partial o_{22}} & \frac{\partial L}{\partial o_{23}} \\
\frac{\partial L}{\partial o_{31}} & \frac{\partial L}{\partial o_{32}} & \frac{\partial L}{\partial o_{33}}
\end{bmatrix}\]

When we use a deep learning framework such as PyTorch or JAX, this derivative is computed automatically and handed to the attention layer. We will use it to compute the gradients $\frac{\partial L}{\partial Q}$, $\frac{\partial L}{\partial K}$, and $\frac{\partial L}{\partial V}$.

Gradient of $V$ and $P$

This is the most straightforward part. Remember that $O = PV$,

\[O = PV = \begin{bmatrix}
p_{11}v_{11} + p_{12}v_{21} + p_{13}v_{31} & p_{11}v_{12} + p_{12}v_{22} + p_{13}v_{32} & p_{11}v_{13} + p_{12}v_{23} + p_{13}v_{33} \\
p_{21}v_{11} + p_{22}v_{21} + p_{23}v_{31} & p_{21}v_{12} + p_{22}v_{22} + p_{23}v_{32} & p_{21}v_{13} + p_{22}v_{23} + p_{23}v_{33} \\
p_{31}v_{11} + p_{32}v_{21} + p_{33}v_{31} & p_{31}v_{12} + p_{32}v_{22} + p_{33}v_{32} & p_{31}v_{13} + p_{32}v_{23} + p_{33}v_{33}
\end{bmatrix} = \begin{bmatrix}
o_{11} & o_{12} & o_{13} \\
o_{21} & o_{22} & o_{23} \\
o_{31} & o_{32} & o_{33}
\end{bmatrix}\]

Take $\frac{\partial L}{\partial v_{11}}$ as an example. $v_{11}$ appears only in the first column of $O$ (in $o_{11}$, $o_{21}$, and $o_{31}$), so by the chain rule

\[\frac{\partial L}{\partial v_{11}} = \frac{\partial L}{\partial o_{11}}\frac{\partial o_{11}}{\partial v_{11}} + \frac{\partial L}{\partial o_{21}}\frac{\partial o_{21}}{\partial v_{11}} + \frac{\partial L}{\partial o_{31}}\frac{\partial o_{31}}{\partial v_{11}}\]
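The chain-rule sum above can be sanity-checked numerically. The sketch below is only an illustration under my own assumptions: the values of `P`, `V`, and `G` are arbitrary, `G` stands in for $\frac{\partial L}{\partial O}$, and the loss is chosen as $L = \sum_{ij} g_{ij} o_{ij}$ so that $\frac{\partial L}{\partial o_{ij}} = g_{ij}$ exactly. It estimates each $\frac{\partial o_{i1}}{\partial v_{11}}$ by central finite differences, forms the three-term sum, and compares it with a direct finite difference of $L$ with respect to $v_{11}$.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a[i][m] * b[m][j] for m in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


# Arbitrary illustrative values; each row of P sums to 1 like a softmax output.
P = [[0.2, 0.3, 0.5], [0.6, 0.1, 0.3], [0.25, 0.25, 0.5]]
V = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
G = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5], [-2.0, 1.0, 1.0]]  # plays the role of dL/dO


def loss(v):
    """Hypothetical loss L = sum_ij G[i][j] * O[i][j], so that dL/dO = G."""
    o = matmul(P, v)
    return sum(G[i][j] * o[i][j] for i in range(3) for j in range(3))


def perturb_v11(eps):
    """Return a copy of V with v_11 shifted by eps."""
    v = [row[:] for row in V]
    v[0][0] += eps
    return v


eps = 1e-6
o_plus = matmul(P, perturb_v11(eps))
o_minus = matmul(P, perturb_v11(-eps))

# Finite-difference estimates of d o_i1 / d v_11 for i = 1, 2, 3.
do_dv11 = [(o_plus[i][0] - o_minus[i][0]) / (2 * eps) for i in range(3)]

# Chain rule: sum over the first column of O, the only column v_11 touches.
chain_rule = sum(G[i][0] * do_dv11[i] for i in range(3))

# Direct finite-difference estimate of dL / d v_11 for comparison.
direct = (loss(perturb_v11(eps)) - loss(perturb_v11(-eps))) / (2 * eps)

print(chain_rule, direct)
```

The two printed numbers should agree up to floating-point error. The second and third columns of `o_plus` and `o_minus` are identical, which is why only three terms appear in the sum.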

Since $o_{11} = p_{11}v_{11} + p_{12}v_{21} + p_{13}v_{31}$,

\[\frac{\partial o_{11}}{\partial v_{11}} =...
