A Deep-Dive Into Fine-Tuning

How Fine-Tuning Works

Inputs and Outputs

Now that we've covered the basics of fine-tuning, let's dive deeper. Unlike pre-training, which uses a training technique called self-supervised learning, fine-tuning uses a more traditional machine learning approach called supervised learning. As such, we often call this process supervised fine-tuning, or SFT for short.

In supervised learning, we provide the model with a dataset of examples, each consisting of an input and an output. Each output should demonstrate the behavior we'd like the model to learn for its corresponding input prompt. As we add more (and more diverse) examples to the dataset, the model picks up the underlying pattern well enough to generalize to inputs it has never seen.

For example, if we wanted to teach an LLM a one-sentence summarization task, here is what the input and output might look like:

JSON
{
	"input": "Penguins are fascinating flightless birds that are primarily found in the Southern Hemisphere, especially in Antarctica. Characterized by their unique black and white plumage, resembling a formal tuxedo, these birds are not capable of flight. Instead, they are excellent swimmers, using their wings as flippers to navigate underwater. Their diet predominantly consists of fish, squid, and a variety of other small sea creatures, which they hunt underwater. The way they adapt to marine life while living in some of the coldest environments on Earth is remarkable.",
	"output": "Penguins, with black and white plumage, are flightless, adept swimmers from the Southern Hemisphere, feeding mainly on fish and squid."
}

Or maybe we'd like to fine-tune the model to answer questions in Shakespearean English using a dataset of examples like the following:

JSON
{
	"input": "What is the capital of France?",
	"output": "The fair capital of France doth be Paris."
}

How about one more example to bring the point home? Imagine we're building a customer service chatbot that should classify customer requests into one of three categories ("technical support", "billing", and "general inquiries") and do so in a strict JSON format so that it can be easily parsed by downstream services. Here's what a single example in our dataset might look like:

JSON
{
	"input": "I'm having trouble adding a new payment method to my account.",
	"output": "{\"category\": \"billing\"}"
}
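Because downstream services will parse the model's output, it helps to check that every example in the dataset already follows the target format before training. Below is a minimal sketch, assuming the examples are stored as JSON Lines (one object per line with "input" and "output" keys, as in the examples above). The file format, the `ALLOWED_CATEGORIES` set, and the `validate_example` / `validate_jsonl` helpers are illustrative assumptions, not part of any particular fine-tuning service.

```python
import json

# Hypothetical label set for the customer-service classifier described above.
ALLOWED_CATEGORIES = {"technical support", "billing", "general inquiries"}


def validate_example(example: dict) -> list[str]:
    """Return the problems found in one input/output example (empty list if valid)."""
    problems = []
    if not isinstance(example.get("input"), str) or not example["input"].strip():
        problems.append("missing or empty 'input'")
    output = example.get("output")
    if not isinstance(output, str):
        return problems + ["'output' must be a string"]
    try:
        parsed = json.loads(output)  # the output itself must be valid JSON
    except json.JSONDecodeError:
        return problems + ["'output' is not valid JSON"]
    if not isinstance(parsed, dict) or parsed.get("category") not in ALLOWED_CATEGORIES:
        problems.append("'output' must look like {\"category\": <allowed category>}")
    return problems


def validate_jsonl(text: str) -> dict[int, list[str]]:
    """Validate a JSON Lines dataset; map 1-based line numbers to their problems."""
    report = {}
    for lineno, line in enumerate(text.splitlines(), start=1):
        if not line.strip():
            continue  # tolerate blank lines
        problems = validate_example(json.loads(line))
        if problems:
            report[lineno] = problems
    return report


if __name__ == "__main__":
    dataset = "\n".join([
        json.dumps({"input": "I'm having trouble adding a new payment method to my account.",
                    "output": "{\"category\": \"billing\"}"}),
        json.dumps({"input": "My app crashes on startup.",
                    "output": "technical support"}),  # plain text, not JSON: gets flagged
    ])
    print(validate_jsonl(dataset))
```

Only the second line is reported, because its output is a bare string rather than the strict JSON object the downstream services expect.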

Next-Token Prediction

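Once we've finished collecting our fine-tuning dataset, we train the model using a task called next-token prediction (NTP). As the name suggests, we feed the model the sequence of input tokens and then predict each token of the output, one at a time, given all the tokens before it (this step-by-step prediction is called autoregression). During training, the correct output is already known, so each prediction is conditioned on the true preceding tokens rather than on the model's own guesses (a technique called teacher forcing), which also lets all of the predictions be computed in parallel. We then calculate a loss function, typically cross-entropy, that quantitatively measures how far off the model's prediction was for each output token. Lastly, we use an algorithm called backpropagation to compute how much each of the model's parameters contributed to that loss, and update the parameters (with gradient descent or a variant of it) to reduce the loss, i.e., to steer the model towards the correct output.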
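In code, the input and output are usually concatenated into a single token sequence, and the loss is restricted to the output tokens. A common convention, assumed in the minimal sketch below, is to give every input (prompt) position the label -100, which is the default `ignore_index` of PyTorch's `torch.nn.CrossEntropyLoss` and the value Hugging Face Transformers uses for ignored labels. The word-level `tokenize` function and its vocabulary are hypothetical stand-ins for a real tokenizer, and the explicit shift in `next_token_targets` makes visible the one-position offset that causal language models typically apply internally.

```python
# Sketch: turning one input/output pair into model inputs and masked labels.
IGNORE_INDEX = -100  # label value skipped by the loss (PyTorch CrossEntropyLoss default)


def tokenize(text: str, vocab: dict[str, int]) -> list[int]:
    """Hypothetical word-level tokenizer: unseen words get the next free id."""
    ids = []
    for word in text.split():
        if word not in vocab:
            vocab[word] = len(vocab)
        ids.append(vocab[word])
    return ids


def build_sft_example(prompt: str, output: str,
                      vocab: dict[str, int]) -> tuple[list[int], list[int]]:
    """Concatenate prompt and output tokens; mask prompt labels so only the output is scored."""
    prompt_ids = tokenize(prompt, vocab)
    output_ids = tokenize(output, vocab)
    input_ids = prompt_ids + output_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + output_ids
    return input_ids, labels


def next_token_targets(labels: list[int]) -> list[int]:
    """Shift labels left by one: the prediction at position i is scored against token i + 1."""
    return labels[1:] + [IGNORE_INDEX]


if __name__ == "__main__":
    vocab: dict[str, int] = {}
    input_ids, labels = build_sft_example("What is 1 + 1?", "The answer is: 2", vocab)
    targets = next_token_targets(labels)
    id_to_word = {i: w for w, i in vocab.items()}
    for pos, target in enumerate(targets):
        if target != IGNORE_INDEX:  # only these positions contribute to the loss
            print(f"position {pos} ({id_to_word[input_ids[pos]]!r}) -> predict {id_to_word[target]!r}")
```

Only four positions are scored here, one per output token; the prompt tokens are visible to the model as context but never graded.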

[Figure: Next-Token Prediction]

Consider a scenario in which we fine-tune an LLM to always begin its answers with "The answer is:". In one example from our fine-tuning dataset, the input is "What is 1 + 1?" and the output is "The answer is: 2". To train the model, we feed the input tokens into the model and then autoregressively try to predict each output token (treating each word as one token for simplicity). First, we try to predict the token "The" given the prompt "What is 1 + 1?". Next, we try to predict the token "answer" given "What is 1 + 1? The". And so on, until we reach the end of the output.
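The sequence of prediction steps for this example can be listed mechanically. In this small sketch, words stand in for tokens (a simplifying assumption, since real tokenizers split text differently), and `prediction_steps` is a hypothetical helper that pairs each target token with the context the model sees at that step.

```python
def prediction_steps(prompt: str, output: str) -> list[tuple[str, str]]:
    """List each (context, target) pair seen while training on one example.

    The context is the prompt plus the *correct* output words that precede
    the target, mirroring how training conditions on the true prefix.
    """
    prompt_tokens = prompt.split()
    output_tokens = output.split()
    steps = []
    for i, target in enumerate(output_tokens):
        context = prompt_tokens + output_tokens[:i]
        steps.append((" ".join(context), target))
    return steps


if __name__ == "__main__":
    for context, target in prediction_steps("What is 1 + 1?", "The answer is: 2"):
        print(f"{context!r:35} -> {target!r}")
```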

Look at the second prediction step. If the model assigned most of its probability to the token "solution" instead of the correct token "answer", it would incur a large loss at that step. To reduce this loss, we'd update the model's parameters to steer future predictions back towards "answer". By repeating this process for every token in the example's output, and across thousands of diverse examples whose outputs all start with "The answer is:", the model adopts this behavior and learns to apply it to new questions.
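To make "a large loss" concrete: the usual next-token loss is cross-entropy, which for a single step is the negative log of the probability the model gave the correct token. The toy sketch below assumes the model's output at the step "What is 1 + 1? The" is reduced to three hypothetical logits over "answer", "solution", and "result". It shows the loss while the model favors "solution", and how gradient-descent steps on those logits shift probability back toward "answer". For softmax followed by cross-entropy, the gradient with respect to the logits is the predicted probabilities minus a one-hot vector for the correct token; in a real LLM, backpropagation carries that gradient further back into all of the model's weights. The vocabulary, logits, and learning rate are illustrative assumptions.

```python
import math

VOCAB = ["answer", "solution", "result"]  # hypothetical candidate tokens


def softmax(logits: list[float]) -> list[float]:
    """Convert logits to probabilities (shifted by the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits: list[float], target: int) -> float:
    """Loss for one prediction step: -log p(correct token)."""
    return -math.log(softmax(logits)[target])


def gradient_step(logits: list[float], target: int, lr: float) -> list[float]:
    """One gradient-descent update using d(loss)/d(logit_k) = p_k - 1[k == target]."""
    probs = softmax(logits)
    grads = [p - (1.0 if k == target else 0.0) for k, p in enumerate(probs)]
    return [z - lr * g for z, g in zip(logits, grads)]


if __name__ == "__main__":
    target = VOCAB.index("answer")
    logits = [0.5, 2.0, 0.1]  # toy model currently favors "solution"
    for step in range(3):
        p = softmax(logits)[target]
        print(f"step {step}: p(answer)={p:.3f} loss={cross_entropy(logits, target):.3f}")
        logits = gradient_step(logits, target, lr=1.0)
```

Each update raises the logit of "answer" and lowers the others, so the printed probability climbs and the loss falls step by step.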

Hyperparameters and Splits

When fine-tuning LLMs, there are three primary hyperparameters that we need to tune. We will explore each of them in more detail later, when we build our own fine-tuning pipelines:

🔎 A hyperparameter is a fancy name for a setting that we can adjust to control the model's training process.
  • Epochs: The number of complete passes the model makes over the entire training dataset. If we are learning a more complex task on a larger dataset, we'll typically opt for more epochs than for a simpler task on a smaller dataset. However, more epochs also mean more compute time and a greater risk of overfitting.
  • Batch Size: The number of examples fed into the model in a single training iteration, i.e., per parameter update. Larger batches mean fewer updates per epoch and better use of parallel hardware, so each epoch typically finishes faster, but they also use more memory and, past a certain point, may generalize less well.
  • Learning Rate: How large a step the model takes when adjusting its weights at each update. A learning rate that is too high can make training unstable or cause the model to settle too quickly on a suboptimal solution, while one that is too low can make the model take too long to converge or leave it stuck in a poor local minimum.
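The epoch count and batch size combine through simple arithmetic: the dataset size divided by the batch size fixes how many parameter updates happen per epoch, and the number of epochs multiplies that. The minimal sketch below uses a hypothetical `FineTuneConfig` container with made-up values to compute those counts and to show how one epoch is split into mini-batches. If you later use Hugging Face Transformers, the corresponding `TrainingArguments` fields are `num_train_epochs`, `per_device_train_batch_size`, and `learning_rate`.

```python
import math
from dataclasses import dataclass


@dataclass
class FineTuneConfig:
    """Hypothetical container for the three hyperparameters discussed above."""
    epochs: int
    batch_size: int
    learning_rate: float


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Parameter updates in one pass over the data (the last batch may be partial)."""
    return math.ceil(num_examples / batch_size)


def total_steps(num_examples: int, config: FineTuneConfig) -> int:
    """Total parameter updates across all epochs."""
    return config.epochs * steps_per_epoch(num_examples, config.batch_size)


def batches(examples: list, batch_size: int):
    """Yield consecutive mini-batches; one epoch iterates over all of them once."""
    for start in range(0, len(examples), batch_size):
        yield examples[start:start + batch_size]


if __name__ == "__main__":
    config = FineTuneConfig(epochs=3, batch_size=8, learning_rate=2e-5)  # illustrative values
    print(steps_per_epoch(1000, config.batch_size))  # 125
    print(total_steps(1000, config))  # 375
```

Doubling the batch size roughly halves the number of updates per epoch, which is one reason batch size and learning rate are usually tuned together.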

All of these hyperparameters are interdependent, so they must be tuned together. One strategy for doing so is to divide the dataset into separate train and validation splits. The train split contains the majority of the dataset and is used to fine-tune the model, while the validation split is held out from training and used to get an unbiased estimate of the model's performance throughout the training process. By tracking how the loss on both splits changes from epoch to epoch, we can tell whether the model is underfitting (both losses are still high) or overfitting (the training loss keeps falling while the validation loss starts to rise), and adjust the hyperparameters accordingly.
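The split-and-monitor strategy fits in a few lines. This is a minimal illustration, assuming an in-memory list of examples and one average loss per split per epoch; the 10% validation fraction, the helper names, and the loss values in the usage are all hypothetical, and `looks_overfit` is a rough heuristic rather than a standard test.

```python
import random


def train_val_split(examples: list, val_fraction: float = 0.1,
                    seed: int = 0) -> tuple[list, list]:
    """Shuffle a copy of the dataset and hold out a fraction for validation."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)  # fixed seed -> reproducible split
    n_val = max(1, int(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]


def best_epoch(val_losses: list[float]) -> int:
    """1-based epoch with the lowest validation loss (a common checkpoint choice)."""
    return min(range(len(val_losses)), key=val_losses.__getitem__) + 1


def looks_overfit(train_losses: list[float], val_losses: list[float]) -> bool:
    """Heuristic: training loss still falling while validation loss sits above its best value."""
    if len(train_losses) < 2:
        return False
    return train_losses[-1] < train_losses[-2] and val_losses[-1] > min(val_losses)


if __name__ == "__main__":
    train, val = train_val_split(list(range(1000)))
    print(len(train), len(val))  # 900 100
    # Made-up per-epoch losses, for illustration only.
    train_losses = [1.20, 0.80, 0.55, 0.35]
    val_losses = [1.25, 0.95, 0.90, 1.02]
    print(best_epoch(val_losses), looks_overfit(train_losses, val_losses))  # 3 True
```

In the made-up curves, validation loss bottoms out at epoch 3 and then rises while training loss keeps dropping, the classic overfitting pattern; fewer epochs or a lower learning rate would be natural next experiments.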

Hosted or Local

Last but not least, let's consider how you might actually go about fine-tuning an LLM. In practice, there are two main strategies you can follow, each with its own trade-offs between complexity and flexibility:

  1. Low-code, hosted services: These are services that provide a simple, user-friendly interface for fine-tuning LLMs. They abstract away the training routine and compute resources, allowing you to fine-tune a model with just a few console clicks or API calls. Examples of such services include OpenAI for fine-tuning its proprietary models and Together AI for fine-tuning open-source models. These services make fine-tuning super easy and relatively inexpensive, but they also limit how much control you have over the process.
  2. Open-source frameworks: The alternative is to fine-tune the model yourself using an open-source framework like Hugging Face Transformers. In this case, you will need to BYOC (Bring Your Own Compute), either by purchasing physical GPUs for your local machine or, more realistically, by renting cloud GPUs from services like RunPod or Paperspace. This approach is far more flexible and gives you fine-grained control over the fine-tuning process, but it also requires a deeper understanding of how LLMs are implemented and trained (plus the headache of managing your own compute resources).
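Whichever route you take, the input/output pairs usually have to be reshaped into the schema the tool expects. As one example, OpenAI's fine-tuning for chat models expects a JSON Lines file in which each line holds a `messages` list of `system`, `user`, and `assistant` turns. The sketch below converts the pairs used earlier into that shape; the optional system prompt and the helper names are assumptions for illustration, and since schemas differ between providers and change over time, check the provider's current documentation before uploading.

```python
import json
from typing import Optional


def to_chat_record(example: dict, system_prompt: Optional[str] = None) -> dict:
    """Map one {"input", "output"} pair to a chat-style record with a 'messages' list."""
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": example["input"]})
    messages.append({"role": "assistant", "content": example["output"]})
    return {"messages": messages}


def to_jsonl(examples: list[dict], system_prompt: Optional[str] = None) -> str:
    """Serialize records as JSON Lines: one JSON object per line."""
    return "\n".join(json.dumps(to_chat_record(ex, system_prompt)) for ex in examples) + "\n"


if __name__ == "__main__":
    data = [{"input": "What is the capital of France?",
             "output": "The fair capital of France doth be Paris."}]
    print(to_jsonl(data, system_prompt="Answer in Shakespearean English."), end="")
```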
🔎 In the subsequent lessons, we'll fine-tune our own LLMs using both hosted fine-tuning services and open-source frameworks. Stay tuned!