LoRA and Weight Decay (2023)

irhum.github.io - LoRA and Weight Decay

LoRA (Hu et al., 2021) is a now-popular alternative to the full finetuning of Large Language Models (LLMs): instead of tuning the billions of weights of the full model, we add small “adapter” weight matrices that modify the original weight matrices, and tune those instead.

This blogpost dives deeper into a curious behavior: although LoRA is commonly seen as a drop-in replacement for full finetuning, its interaction with weight decay means it solves a different optimization problem than full finetuning. Namely, one where the solution weights are regularized towards the frozen base model \((W \rightarrow W_{\text{init}})\), instead of \(W \rightarrow 0\) as in full finetuning.

This means that, given increasingly more resources (even equalling those of full finetuning), LoRA does not increasingly better approximate full finetuning, because its objective function is implicitly different from that of full finetuning. Depending on the use case, this can be seen as either a bug or a feature, but it is something practitioners should explicitly account for.
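To make the two pull directions concrete, here is a minimal scalar sketch. It is an illustration only: it assumes weight decay is applied to the adapter factors (here the hypothetical scalars `a` and `b`) while `w_init` stays frozen, and it sets the data-loss gradient to zero so that only the decay term acts.

```python
# Toy scalar illustration (hypothetical setup, not taken from the post).
# Only the weight-decay part of the update is applied; the data-loss
# gradient is assumed to be zero so the pull of the regularizer is isolated.

def decay_full(w, lr, wd, steps):
    """Full finetuning: weight decay acts on the weight itself."""
    for _ in range(steps):
        w = w - lr * wd * w
    return w

def decay_lora(w_init, a, b, lr, wd, steps):
    """LoRA-style: w_init is frozen, decay acts only on the factors a and b."""
    for _ in range(steps):
        a = a - lr * wd * a
        b = b - lr * wd * b
    return w_init + b * a  # effective weight = frozen base + adapter product

w_init = 2.0
full = decay_full(w_init + 0.5, lr=0.1, wd=0.1, steps=2000)
lora = decay_lora(w_init, a=1.0, b=0.5, lr=0.1, wd=0.1, steps=2000)

print(f"full finetuning -> {full:.4f}")  # shrinks towards 0
print(f"LoRA            -> {lora:.4f}")  # shrinks towards w_init
```

In this sketch the fully finetuned weight drifts to the origin, while the LoRA-style effective weight drifts back to `w_init`, because only the adapter product is being shrunk.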

Recap: Finetuning

With LLMs, we typically finetune an initial model (that is “good” on a wide range of text-to-text tasks) to boost performance on a specific task of interest (e.g. generating database queries from natural language). We do this in a two-step process:

First, create a finetuning training dataset \({(x_i, y_i)_n}\), which contains pairs of inputs \(x\) and targets \(y\).

Second, optimize the weights of the initial model such that our finetuning training dataset \({(x_i, y_i)_n}\) becomes more “probable”. The idea here is that a model that is more likely to generate the correct answers \(y\) for the \(x\)’s in our training set will generalize, and also be more likely to generate the correct \(y\)’s for new \(x\)’s.
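One way to read “more probable” is through the negative log likelihood of the targets, which gets smaller as the model assigns them higher probability. The sketch below uses made-up probabilities (the lists `before` and `after` are hypothetical, not measurements from any model) purely to show the direction of the effect.

```python
import math

def dataset_nll(target_probs):
    """Negative log likelihood of a dataset, given the probability the
    model assigns to each correct target y_i (one number per example)."""
    return -sum(math.log(p) for p in target_probs)

# Hypothetical per-example probabilities P(y_i | x_i), for illustration only.
before = [0.20, 0.10, 0.25]  # initial model
after = [0.70, 0.60, 0.80]   # same examples after finetuning

print(f"NLL before: {dataset_nll(before):.3f}")
print(f"NLL after:  {dataset_nll(after):.3f}")
```

A lower value after finetuning means the training targets have become more probable under the model.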

Full Finetuning

Full finetuning means we tune all the weights of the model. For a model such as GPT-3 175B (Brown et al., 2020), this means giving our optimization algorithm 175 billion numbers it can “dial” up and down as needed to make our finetuning training data more “probable”. Let’s dig a bit deeper, and more concretely define what we mean by weights here.

Each layer in a Transformer is primarily made of two components: a multihead attention network, followed by a feedforward network. This means the bulk of the “weights” that make up each layer is stored in six matrices, as shown. \(\theta\), then, is used as shorthand to refer to all the weights, stored in all the matrices across all the layers of the model.
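As a rough sanity check on the scale involved, the count below adds up six matrices per layer. The breakdown is an assumption: four attention matrices (query, key, value, output) of shape `d_model x d_model`, and two feedforward matrices with hidden width `d_ff = 4 * d_model`. Biases, embeddings and layer norms are ignored. The layer count and width are those reported for GPT-3 175B in Brown et al. (2020).

```python
# Rough weight count for the large matrices in a Transformer (a sketch).
# Assumed matrices per layer: W_q, W_k, W_v, W_o (attention), W_1, W_2 (feedforward).
# Assumed shapes: attention matrices are d_model x d_model; feedforward matrices
# are d_model x d_ff and d_ff x d_model. Biases, embeddings, layer norms ignored.

def matrix_weights_per_layer(d_model, d_ff):
    attention = 4 * d_model * d_model
    feedforward = 2 * d_model * d_ff
    return attention + feedforward

n_layers, d_model = 96, 12288  # GPT-3 175B configuration
total = n_layers * matrix_weights_per_layer(d_model, 4 * d_model)

print(f"{total:,}")  # 173,946,175,488
```

Under these assumptions the six matrices account for roughly 174 billion of the 175 billion weights, which is consistent with the claim that they hold the bulk of each layer.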

In full finetuning, every single weight in \(\theta\) is opened up for updating. Our aim is to produce updated weights that minimize the negative log likelihood (NLL), as shown on the left. There’s no closed-form way to get the “optimal” weights, so we solve the optimization problem by repeatedly applying many steps of gradient descent, as shown on the right.
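The loop of “compute the gradient of the NLL, take a small step against it” can be sketched on a model with a single weight. Everything here is a stand-in: the one-parameter logistic model, the tiny dataset `data`, and the step size `alpha` are hypothetical choices for illustration, not an LLM training setup.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical binary dataset: pairs (x, y) with y in {0, 1}.
data = [(-2.0, 0), (-1.0, 0), (1.0, 1), (2.0, 1), (0.5, 0)]

def nll(w):
    """Negative log likelihood of the targets under P(y=1 | x) = sigmoid(w * x)."""
    total = 0.0
    for x, y in data:
        p = sigmoid(w * x)
        total -= math.log(p if y == 1 else 1.0 - p)
    return total

def grad(w):
    """Derivative of the NLL with respect to w: sum of (p - y) * x."""
    return sum((sigmoid(w * x) - y) * x for x, y in data)

w, alpha = 0.0, 0.1  # initial weight and step size
history = [nll(w)]
for _ in range(100):
    w = w - alpha * grad(w)  # one step of gradient descent
    history.append(nll(w))

print(f"w = {w:.3f}, NLL went from {history[0]:.3f} to {history[-1]:.3f}")
```

Each step lowers the NLL a little; full finetuning does the same thing with billions of weights instead of one.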

Now, directly doing gradient descent this way would quickly lead to overfitting, so we usually regularize the problem. With LLMs, the regularization tool of choice is usually weight decay. Specifically, when using vanilla SGD, weight decay is equivalent to having a term in the loss equal to the sum of the squared weights:

\[R(\theta)=\sum_i \sum_j[W_{{\color{RoyalBlue}q}}^{\color{PineGreen}{1}}]_{ij}^2+\cdots\]

Hence, the overall objective now is as follows (where \(\lambda\) is a hyperparameter controlling the strength of the weight decay):

\[\min_{\color{YellowOrange}{\theta}} \biggl[\underbrace{-\log P_{\color{YellowOrange}{\theta}}({\color{PineGreen}{y}} \mid {\color{RoyalBlue}{x}})}_{\color{BrickRed}{L}} + \frac{\lambda}{2} R({\color{YellowOrange}{\theta}})\biggr]\]

Differentiating this objective to get the gradient, we notice the gradient update has two distinct terms: the first corresponds to minimizing the negative log likelihood as before, and a new second term \(-\alpha\lambda w\) pushes the weight towards the origin \(0\).

\[
% https://tex.stackexchange.com/a/9477
\def\mathunderline#1#2{\color{#1}\underline{{\color{black}#2}}\color{black}}
\begin{align*}
&{\color{YellowOrange}{w}} \leftarrow {\color{YellowOrange}{w}} - \alpha \left(\mathunderline{BrickRed}{\frac{\partial \color{BrickRed}{L}}{\partial \color{YellowOrange}{w}}} + \mathunderline{LimeGreen}{\frac{\lambda}{2} \frac{\partial R}{\partial \color{YellowOrange}{w}}} \right)\\
\Rightarrow &{\color{YellowOrange}{w}} \leftarrow {\color{YellowOrange}{w}} - \alpha \left(\mathunderline{BrickRed}{\frac{\partial \color{BrickRed}{L}}{\partial \color{YellowOrange}{w}}} + \mathunderline{LimeGreen}{\lambda {\color{YellowOrange}{w}}} \right)\\
\Rightarrow &{\color{YellowOrange}{w}} \leftarrow {\color{YellowOrange}{w}} - \alpha \mathunderline{BrickRed}{\frac{\partial \color{BrickRed}{L}}{\partial \color{YellowOrange}{w}}} - \alpha \mathunderline{LimeGreen}{\lambda {\color{YellowOrange}{w}}}
\end{align*}\]

Which means the regularized problem now looks like:

In summary, adding a squared sum of...
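The equivalence between the squared-weight penalty and the decay term in the update can be checked numerically for a single weight. This is a sketch under stated assumptions: `grad_L` is a made-up value standing in for the NLL gradient, and the helper names are hypothetical.

```python
def reg_term(w, lam):
    """The penalty (lambda / 2) * w**2 contributed by a single weight."""
    return 0.5 * lam * w * w

def numerical_grad(f, w, eps=1e-6):
    """Central finite-difference estimate of df/dw."""
    return (f(w + eps) - f(w - eps)) / (2 * eps)

lam, alpha = 0.1, 0.01  # decay strength and learning rate
w, grad_L = 3.0, -0.4   # grad_L is a hypothetical NLL gradient

# The gradient of the penalty should equal lambda * w.
reg_grad = numerical_grad(lambda v: reg_term(v, lam), w)

# (1) One SGD step on the regularized loss L + (lambda / 2) * R.
fused = w - alpha * (grad_L + reg_grad)
# (2) One SGD step on L alone, followed by the explicit decay term.
split = w - alpha * grad_L - alpha * lam * w

print(f"penalty gradient: {reg_grad:.6f} (lambda * w = {lam * w:.6f})")
print(f"fused update: {fused:.6f}, split update: {split:.6f}")
```

Both updates land on the same value, which matches the derivation: with vanilla SGD, adding the penalty to the loss and subtracting \(\alpha\lambda w\) at each step are the same operation.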
