LoRA: Reducing Trainable Parameters

LoRA: 0 to 100

Alright, we've ascertained that performing full fine-tuning on an LLM involves updating a massive number of parameters and requires a lot of compute... more compute than most of us have access to or are willing to pay for. So that's a non-starter.

But what if there was a way to reduce the number of parameters that need to be updated at each training step (we call these trainable parameters)? What if we could take a model with 7 billion parameters and change its behavior by fine-tuning only 30 million parameters? This is precisely what LoRA aims to accomplish.

Brief Background

Believe it or not, the idea of reducing the number of trainable parameters in a model is not new. For many years, we've employed techniques like pruning to strip out unused parameters from the model and parameter sharing to group parameters together.

In the world of LLM fine-tuning, it is not uncommon to see techniques like layer freezing, where we only update parameters in the final few layers of the model, and prefix tuning, where we train special "prefix" tokens that are prepended to the actual input tokens and guide the model's predictions.

However, the common denominator between these techniques is that they don't modify the model holistically. This is a significant limitation, as the specific circuitry that is most relevant to our fine-tuning task might be located anywhere within the model and typically involves a complex interplay between many layers.

LoRA offers an alternative — using clever linear algebra tricks, we can actually update an entire model with a fraction of the parameters. Crazy, right? To understand how this is possible, it is necessary to first understand LoRA's mathematical underpinning: low-rank factorization.

🔎 The next section is quite math-heavy. If you're new to linear algebra, take your time to understand the concepts and pull in additional resources. It's worth it!

Low-Rank Factorization

Matrix Rank

The rank of a matrix, put simply, is a measure of its information density (this is a crude analogy so don't get mad at me). If a matrix has a rank of 1, it means that all of the information in the matrix can be encoded in a single vector. If the matrix has a very high rank, it means that it encodes a lot more information. Take the following matrix for example:

$$A = \begin{bmatrix} 1 & 2 & 3 \\ 2 & 4 & 6 \\ 3 & 6 & 9 \end{bmatrix}$$

This matrix has a rank of 1 since the second and third columns are scalar multiples of the first column! The second column is just the first multiplied by 2, and the third column is just the first multiplied by 3. This effectively means that all the information in the matrix can be encoded in a single vector equal to the first column: $[1, 2, 3]$. Now consider a second matrix:

$$B = \begin{bmatrix} 1 & 1 & 1 \\ 1 & 2 & 3 \\ 2 & 4 & 6 \end{bmatrix}$$

What do you think the rank of this matrix is? It's 2! Notice that it is impossible to create the first row using a combination of the second and third rows. Put formally, there is no linear combination of the rows that can form the first row, meaning the first row is linearly independent of the other two. On the other hand, the second and third rows are linearly dependent since the third row is just the second row multiplied by 2. We can therefore encode the matrix's information using two vectors, $[1, 1, 1]$ and $[1, 2, 3]$, giving us a rank of 2.

This brings us to an actual definition: the rank of a matrix is the maximum number of linearly independent rows or columns in the matrix. If the matrix has a rank of $r$, it means that its contents can be expressed using only $r$ linearly independent vectors, as we did above.

We say that a matrix $A$ is full-rank if it has the maximum possible rank. In other words, the rank satisfies $\text{rank}(A) = \min(m, n)$, where $m$ is the number of rows and $n$ is the number of columns. In this case, every row or column is linearly independent of the others, and the matrix encodes the maximum amount of information.

Conversely, we call a matrix low-rank (or rank-deficient) if it has a rank significantly less than the maximum possible rank: $\text{rank}(A) \ll \min(m, n)$. For example, the first matrix we looked at would be considered low-rank since its rank of 1 is less than its maximum possible rank of 3!
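
If you want to verify these ranks yourself, here's a quick NumPy sketch (using `np.linalg.matrix_rank`, which estimates rank numerically via the SVD):

```python
import numpy as np

# The two example matrices from above
A = np.array([[1, 2, 3],
              [2, 4, 6],
              [3, 6, 9]])

B = np.array([[1, 1, 1],
              [1, 2, 3],
              [2, 4, 6]])

print(np.linalg.matrix_rank(A))  # 1 -- every column is a multiple of the first
print(np.linalg.matrix_rank(B))  # 2 -- only two linearly independent rows
print(min(A.shape))              # 3 -- the maximum possible rank, so both matrices are low-rank
```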

Properties of Matrix Rank

Very quickly before moving on, there are two key properties of matrix rank that we need to know. We will lean on these properties later to understand exactly how LoRA works:

  1. The rank of a matrix $A$ is constrained by the minimum of its number of rows $m$ and columns $n$. So if you have a matrix with dimensions $(3, 2)$, its maximum possible rank would be 2. This is fairly self-evident from the definition of a full-rank matrix, but it's worth repeating:
$$\text{rank}(A) \leq \min(m, n)$$
  2. For two matrices $A$ and $B$, the rank of their product $AB$ is constrained by their individual ranks. Intuitively, when we combine two matrices, the resulting matrix will encode only as much information as the least informative of the two matrices:
$$\text{rank}(AB) \leq \min(\text{rank}(A), \text{rank}(B))$$
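
As a quick sanity check, here's a small NumPy sketch that demonstrates both properties on randomly generated matrices (the specific shapes are arbitrary, chosen only for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# Property 1: rank(A) <= min(m, n)
A = rng.standard_normal((3, 2))
print(np.linalg.matrix_rank(A), "<=", min(A.shape))  # 2 <= 2

# Property 2: rank(AB) <= min(rank(A), rank(B))
A = rng.standard_normal((100, 5))  # rank 5 (with probability ~1)
B = rng.standard_normal((5, 100))  # rank 5 (with probability ~1)
AB = A @ B                         # a (100, 100) matrix...
print(np.linalg.matrix_rank(AB))   # ...whose rank is still only 5
```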

Rank Factorization

Now for the fun part. The rank factorization of a matrix $A$ is a way to factor it into two smaller matrices, $B$ and $C$. Specifically, if $A$ has rank $r$ and dimensions $(m, n)$, then there always exists a matrix $B$ with dimensions $(m, r)$ and a matrix $C$ with dimensions $(r, n)$ such that the following matrix product holds:

$$A = BC$$

Using this rule, the rank factorization of our first matrix $A$ with dimensions $(3, 3)$ and rank 1 would result in a $(3, 1)$ matrix $B$ and a $(1, 3)$ matrix $C$:

A=[123246369]=[123][123]A = \begin{bmatrix} 1 & 2 & 3 \\ 2 & 4 & 6 \\ 3 & 6 & 9 \\ \end{bmatrix} = \begin{bmatrix} 1 \\ 2 \\ 3 \\ \end{bmatrix} \begin{bmatrix} 1 & 2 & 3 \\ \end{bmatrix}

That was a lot at once, so let's break it down step by step using a concrete example. Imagine we have a matrix $X$ with dimensions $(100, 1000)$ that happens to have a matrix rank $r_X$ of 10. From the first property that we learned above, we know that the maximum possible rank of $X$ is $\min(m, n)$. In this case, this corresponds to the number of rows, or 100. This means that $X$ is very low-rank ($10 \ll 100$), encoding only a fraction of the information that it theoretically could.

So let's factorize $X$ into two smaller matrices $Y$ and $Z$ such that $X = YZ$. This is typically done using an algorithm called singular value decomposition (SVD), but that's a topic for another day. The key point is that $Y$ will have dimensions $(m, r_X) = (100, 10)$ and $Z$ will have dimensions $(r_X, n) = (10, 1000)$, meaning that the number of elements has decreased from 100,000 in $X$ to 11,000 in $YZ$. Therefore, we have successfully represented the same amount of information as in $X$ but with only 11% of the original data volume (number of matrix elements)!

Note that $Y$ and $Z$ are always full-rank matrices. From our first property, we know that their maximum possible rank is $r_X$ due to their dimensions. But from the second property, we know that the ranks of $Y$ and $Z$ set a ceiling on the rank of their product $YZ$. Since we know $YZ$'s rank is $r_X$ (because it is equal to $X$), both $Y$ and $Z$ must have a rank of at least $r_X$. Putting both properties together, $Y$ and $Z$ have a rank of exactly $r_X$ and, since this equals their smaller dimension, they are full-rank.

Let's quickly intuit why all this is even possible. The $X$ matrix is very low-rank — it encodes far less information than it theoretically could at its size. This suggests that there are many repeated patterns in the data (like our first low-rank matrix from the previous section) that we can factor out using clever algorithms like SVD. In other words, we replace the noisy $X$ matrix with two smaller, high-signal matrices. We call this technique of rank factorization on low-rank matrices... low-rank factorization. We'll return to this momentarily.
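
Here's what that concrete example could look like in code. This is a minimal sketch: the matrix X is deliberately constructed to have rank 10 (by multiplying two random rank-10 factors), and the factorization is read straight off NumPy's SVD:

```python
import numpy as np

rng = np.random.default_rng(0)

# Construct a (100, 1000) matrix X that genuinely has rank 10
X = rng.standard_normal((100, 10)) @ rng.standard_normal((10, 1000))
r = np.linalg.matrix_rank(X)      # 10

# Rank factorization via SVD: X = U @ diag(S) @ Vt, then keep the top r components
U, S, Vt = np.linalg.svd(X, full_matrices=False)
Y = U[:, :r] * S[:r]              # (100, 10) -- singular values folded into Y
Z = Vt[:r, :]                     # (10, 1000)

print(np.allclose(X, Y @ Z))      # True: the same information...
print(X.size, Y.size + Z.size)    # 100000 vs 11000: ...stored in 11% of the elements
```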

Parameter Matrices

Now back to the world of deep learning. Before we can tie this all together, we need to understand how the parameters of a neural network (and therefore LLMs) are typically represented. If you guessed "as matrices", you're exactly right.

🔎 You're about to witness the world's fastest crash course on neural networks. This detour probably doesn't do justice to the complexity behind neural networks, so if you get lost, feel free to skip to the takeaway in the last paragraph.

The image below depicts a traditional 3-layer neural network. The input layer has 3 neurons, the hidden layer has 6 neurons, and the output layer has 1 neuron. The input layer doesn't actually perform any computation, it just describes the dimensionality (or number of features) of the input data. The second and third layers are where the magic happens: we call these linear (or dense or fully connected) layers, and they're responsible for performing the actual computations.

Neural Network

In a linear layer, each neuron is connected to every neuron in the preceding layer, using a simple linear equation. Consider the first neuron in the second layer (the first green circle). It takes the three values from the input layer, $x_1$, $x_2$, and $x_3$, and multiplies each by its corresponding weight, $w_1$, $w_2$, and $w_3$. Here's what this looks like, where $b$ is an additional trainable parameter called the bias and $y$ is the output of that first neuron:

$$y = w_1 \cdot x_1 + w_2 \cdot x_2 + w_3 \cdot x_3 + b$$
🔎 You might be confused about the interplay between the terms "parameters" and "weights". Typically, parameters refer to any trainable value in the model, which includes both weights and biases. In practice, however, the two terms are often used interchangeably.

Guess what! We can represent this equation using a vector dot product, where $\vec{w}$ is a vector of the weights and $\vec{x}$ is a vector of the inputs:

$$y = \vec{w} \cdot \vec{x} + b$$
$$\vec{w} = \begin{bmatrix} w_1 \\ w_2 \\ w_3 \end{bmatrix} \quad \vec{x} = \begin{bmatrix} x_1 \\ x_2 \\ x_3 \end{bmatrix}$$

So far, we've only modeled the output of a single neuron in the second layer. Can we instead model the outputs of all 6 neurons in the second layer? Yes, using matrix multiplications:

$$Y = XW + b$$
  1. Our input vector $\vec{x}$ becomes a $(1, 3)$ input matrix $X$. If we wanted to pass multiple inputs to the layer at once, $X$ would become a $(B, 3)$ input matrix, where $B$ is the batch size.
  2. Our weight vector $\vec{w}$ becomes the first column of our $(3, 6)$ weight matrix $W$, which contains the weights of all 6 neurons in the second layer. Each neuron has its own column. The dots in the matrix represent the weights of the other neurons.
  3. Our bias $b$ becomes the first column of a $(1, 6)$ bias matrix $b$. Each neuron has a single bias. Similarly to the weight matrix, the dots in the matrix represent the biases of the other neurons.
  4. The output of the second layer is then a $(1, 6)$ output matrix $Y$: each column contains the output of a single neuron. Notice that the first column of $Y$ contains the same output that we derived above for the first neuron. If we passed multiple inputs to the layer at once, $Y$ would become a $(B, 6)$ output matrix.
$$Y = \begin{bmatrix} x_1 & x_2 & x_3 \end{bmatrix} \begin{bmatrix} w_1 & \dots \\ w_2 & \dots \\ w_3 & \dots \end{bmatrix} + \begin{bmatrix} b & \dots \end{bmatrix}$$
$$Y = \begin{bmatrix} w_1 \cdot x_1 + w_2 \cdot x_2 + w_3 \cdot x_3 + b & \dots \end{bmatrix}$$

Congratulations — if you've read this far, we've just devised a clean formula to represent the outputs of an entire linear layer! Following this approach, the third layer would be represented as a $(6, 1)$ weight matrix $W_2$ and a $(1, 1)$ bias matrix $b_2$. In practice, we would then wrap this output in a non-linear activation function like ReLU or tanh before passing it to the next layer:

$$Y = \tanh(XW + b)$$
🔎 See if you can figure out the pattern in the dimensions of these matrices. Hint: the first dimension of $W$ is the number of neurons in the previous layer and the second dimension is the number of neurons in its own layer.

The primary takeaway here is that we can represent each layer as a set of parameter matrices. In our simple 3-layer network, the parameters are primarily contained in the weight matrices $W_1$ and $W_2$. And this is also true for LLMs! All the parameters of LLMs, for both attention and feedforward layers, are represented as parameter matrices.
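
To see this takeaway in code, here's a short PyTorch sketch of the second layer of our toy network. One caveat: `nn.Linear` stores its weight matrix transposed, as `(out_features, in_features)`, so its forward pass computes `X @ W.T + b` with that stored matrix, but the idea is identical:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Second layer of the toy network: 3 input features -> 6 neurons
layer = nn.Linear(in_features=3, out_features=6)
print(layer.weight.shape)  # torch.Size([6, 3]) -- one row of weights per neuron
print(layer.bias.shape)    # torch.Size([6])    -- one bias per neuron

X = torch.randn(1, 3)      # a single input, i.e. a (1, 3) matrix

# The layer's output is exactly the matrix formula from above (up to the transpose)
Y_manual = X @ layer.weight.T + layer.bias
Y_layer = layer(X)
print(torch.allclose(Y_manual, Y_layer))  # True
print(Y_layer.shape)                      # torch.Size([1, 6])
```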

Putting LoRA Together

Key Insight

Consider an arbitrary parameter matrix $W$ of an LLM that we're fine-tuning. Since full fine-tuning will update every parameter in the model, we can represent the final fine-tuned $W'$ matrix as the element-wise sum of our original $W$ matrix and a new matrix $\Delta W$ that represents the updates for each parameter. In this reframing, $W$ is frozen and $\Delta W$ contains the deltas from all the newly fine-tuned parameters:

$$W' = W + \Delta W$$

The key insight behind LoRA is that the $\Delta W$ matrix is low-rank. In other words, $\Delta W$ has a low information density. The fine-tuning process isn't actually leveraging the flexibility of all the available trainable parameters and is instead re-learning the same patterns over and over again in $\Delta W$. Here's how Hu et al. explained this finding in the original LoRA paper:

We take inspiration from Li et al. (2018a); Aghajanyan et al. (2020) which show that the learned over-parametrized models in fact reside on a low intrinsic dimension. We hypothesize that the change in weights during model adaptation also has a low “intrinsic rank”, leading to our proposed Low-Rank Adaptation (LoRA) approach.

Why does this make sense? During pre-training, we assume that $W$ will be maximally expressive (i.e. a full-rank matrix) to properly capture the complexity of language from a massive dataset. However, when we fine-tune the pre-trained model on a new task, not all of the model's parameters need to be updated — the complexity of a single task pales in comparison to that of learning an entire language via pre-training. In fact, only a fraction of the parameters in $\Delta W$ will include meaningful updates.

Similar to our example at the end of the rank factorization section, we can therefore factorize the contents of the $\Delta W$ matrix into two smaller matrices $M$ and $N$ that capture the same information with fewer parameters:

$$W' = W + \Delta W = W + MN$$

If we plug this into the linear layer equation from our neural network earlier, we get an updated version of the layer output equation that uses LoRA:

$$Y = XW + b$$
$$Y = X(W + \Delta W) + b$$
$$Y = X(W + MN) + b$$

Now, instead of fine-tuning all the parameters in $\Delta W$, we only need to fine-tune those of $M$ and $N$! We call this low-rank adaptation (LoRA) since we're performing model "adaptation" (i.e. fine-tuning) using low-rank matrices.

$$Y = XW + XMN + b$$

Note also that we can distribute $X$ due to the distributive property of matrix multiplication, as seen above. This is an important feature of LoRA since it means that we don't have to modify the pre-trained model (in $W$) at all. We can store the LoRA matrices separately from the base model and simply apply them during the forward pass. We'll see why this is important later.
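
To make this concrete, here's a minimal sketch of what a LoRA-wrapped linear layer could look like in PyTorch. This is a simplified illustration rather than a reference implementation (libraries like Hugging Face's peft add details such as a scaling factor and dropout), and the class and attribute names here are my own:

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """A frozen pre-trained linear layer plus a trainable low-rank update MN."""

    def __init__(self, base_layer: nn.Linear, rank: int):
        super().__init__()
        self.base = base_layer
        for param in self.base.parameters():
            param.requires_grad_(False)  # freeze W (and b) -- they are never updated

        m = base_layer.in_features
        n = base_layer.out_features

        # M: (m, r) and N: (r, n) -- the only trainable parameters.
        # N starts at zero so MN = 0 and the wrapped layer initially behaves like the base layer.
        self.M = nn.Parameter(torch.randn(m, rank) * 0.01)
        self.N = nn.Parameter(torch.zeros(rank, n))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Y = XW + b (frozen base) + XMN (low-rank update), i.e. X distributed over W + MN
        return self.base(X) + X @ self.M @ self.N


# Wrap a "pre-trained" 4096 -> 4096 projection with a rank-8 adapter
base = nn.Linear(4096, 4096)
lora = LoRALinear(base, rank=8)

trainable = sum(p.numel() for p in lora.parameters() if p.requires_grad)
total = sum(p.numel() for p in lora.parameters())
print(f"{trainable:,} trainable / {total:,} total")  # 65,536 trainable / 16,846,848 total
```

Because M and N live alongside the untouched base layer, the adapter can be saved, shared, or swapped independently of the pre-trained weights, which is exactly the property described above.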

Choosing the Dimensions

You might now be asking, "how do we pick the dimensions of $M$ and $N$?" The answer is that there's a trade-off. Let's try to understand where this trade-off comes from.

Let $\Delta W_{\text{opt}}$ be the matrix, with rank $r_{\text{opt}}$, that optimally learns the task under full fine-tuning conditions. Similarly, let $\Delta W_{\text{LoRA}}$ be the matrix that we factorize into $M$ and $N$ under LoRA fine-tuning conditions. We give $M$ and $N$ the dimensions $(m, r_{\text{LoRA}})$ and $(r_{\text{LoRA}}, n)$ respectively, where $r_{\text{LoRA}}$ is the rank that we choose for our $M$ and $N$ matrices.

From our first and second matrix rank properties, we know that the product $\Delta W_{\text{LoRA}} = MN$ will have a rank of at most $r_{\text{LoRA}}$. In other words, $MN$ can never encode more information than a rank-$r_{\text{LoRA}}$ matrix allows. Therefore, if $r_{\text{LoRA}} < r_{\text{opt}}$, the rank of $\Delta W_{\text{LoRA}}$ will be less than the rank of $\Delta W_{\text{opt}}$. This means that $M$ and $N$ won't have the size (and therefore parameter count) to sufficiently capture the information that full fine-tuning optimally learns.

Conversely, if $r_{\text{LoRA}} > r_{\text{opt}}$, we will actually have more parameters than necessary. In other words, the full fine-tuned $\Delta W_{\text{opt}}$ matrix would possess a rank less than $\Delta W_{\text{LoRA}}$, so we would be using more representation power in $M$ and $N$ than the task requires.

LoRA Trade-Off

We therefore face a trade-off. If you care about significantly reducing the number of trainable parameters (i.e. parameter efficiency), you might choose a low $r_{\text{LoRA}}$ at the expense of compromising the model's ability to learn. If you care about preserving the model's learning capacity, you might choose a high $r_{\text{LoRA}}$ at the expense of the computational burden of training more parameters.

In practice, we want to aim for $r_{\text{LoRA}} \approx r_{\text{opt}}$. To get there, you will need to pick a point on the trade-off curve that is relevant to your use case and then iteratively adjust $r_{\text{LoRA}}$ based on the fine-tuned model's performance.
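
To get a feel for the parameter-efficiency side of this trade-off, here's a small sketch that counts the trainable parameters of a single $(4096, 4096)$ weight matrix at a few candidate ranks (4096 is just an illustrative size, roughly that of an attention projection in a 7B-parameter model):

```python
# Trainable parameters for one (m, n) weight matrix: full fine-tuning vs. LoRA
m, n = 4096, 4096
full = m * n  # 16,777,216 parameters in delta-W under full fine-tuning

for r in (1, 4, 8, 16, 64, 256):
    lora = m * r + r * n  # elements of M (m, r) plus N (r, n)
    print(f"r={r:>3}: {lora:>9,} trainable params ({lora / full:.2%} of full fine-tuning)")
```

Even at a relatively generous rank like 64, the low-rank update costs only a few percent of the parameters of the full update, which is where LoRA's parameter savings come from.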