LLM pruning & distillation: Retraining LLMs to SLMs with the Minitron approach

LLM providers often build a whole family of models from scratch, each differing in size to meet various needs. For example, Llama 3.1 comes in versions with 8B, 70B, or 405B parameters. This variety serves different deployment requirements and computing budgets. However, training multiple massive models from the ground up is incredibly time-consuming and resource-intensive. Recently, researchers have found that by combining weight pruning with knowledge distillation, we can significantly reduce the costs of training these model families. Instead of starting from scratch every time, they fully train only the largest model. Then, they create the smaller models by gradually trimming down (pruning) the big one and using knowledge distillation to regain the accuracy of the pruned versions.

In this article, we'll explore how model pruning and distillation work and how you can derive a small language model (SLM) from an LLM. We'll look at a recent research paper from Nvidia introducing the MINITRON models, which outperform many comparable models thanks to pruning and distillation techniques.

LLM knowledge distillation

Knowledge distillation for large language models (LLMs) is like teaching a smaller, more efficient student to mimic a larger, highly knowledgeable teacher. Essentially, you have a big, complex language model that knows a lot and can do many things. But this big model can be too slow or expensive to use all the time. So, you create a smaller model that's faster and cheaper to run.

The process involves training this smaller model to replicate the performance of the big model as closely as possible. The big model answers questions or performs tasks, and these answers or outputs are used as examples or "lessons" for the smaller model. The goal is for the smaller model to learn to give the same answers or perform tasks just as well but with less computing power and faster responses.
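
To make this concrete, here is a minimal PyTorch sketch of a single distillation step, assuming a frozen `teacher` and a trainable `student` that share a tokenizer and expose Hugging Face-style `.logits` outputs. The temperature value and the choice of KL divergence are illustrative defaults, not the paper's exact recipe.

```python
import torch
import torch.nn.functional as F

def distillation_step(teacher, student, input_ids, optimizer, temperature=2.0):
    """One KD step: the student matches the teacher's softened output distribution."""
    with torch.no_grad():  # the teacher is frozen; we only need its outputs
        teacher_logits = teacher(input_ids).logits
    student_logits = student(input_ids).logits

    # KL divergence between temperature-softened distributions,
    # scaled by T^2 to keep gradient magnitudes comparable across temperatures.
    loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature**2

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```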

LLM pruning

LLM pruning is the process of simplifying a large and complex model by selectively removing parts that are less important. Just like you trim a tree to shape it and enhance its growth, you prune an LLM to resize and optimize it. This makes the model smaller and faster to run without greatly affecting its performance.

The process starts with a full-sized model containing many layers of interconnected units, or "neurons," each playing a role in how the model processes information. By analyzing which neurons are activated less often or contribute less to the model's decisions, these can be removed, or "pruned." The aim is to strip the model down to its most essential parts, making it leaner.

As shown in the figure below, we start the pruning process by first computing the importance of each layer, neuron, head, and embedding dimension and then sorting these importance scores to compute a corresponding importance ranking.

High-level overview of the iterative pruning and distillation approach used to train a family of smaller LLMs. Source

This trimmed-down model requires less memory and computing power, which means it can operate faster and more efficiently.

Importance estimation

The research introduces an activation-based strategy to estimate the importance of various model components using a small calibration dataset and only forward propagation. This method assesses the importance of each head, neuron, and embedding channel by analyzing the activations from the multi-head attention (MHA), multilayer perceptron (MLP), and LayerNorm layers.
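
The sketch below illustrates the general idea for MLP neurons under simplified assumptions: a forward hook accumulates activation magnitudes over a small calibration set, and neurons are ranked by their mean activation. The real method also aggregates scores for attention heads and embedding channels, which this sketch omits.

```python
import torch

def rank_mlp_neurons(model, mlp_layer, calibration_loader, device="cuda"):
    """Rank MLP neurons by mean absolute activation over a calibration set.

    Only forward passes are required -- no gradients or backpropagation.
    """
    scores = None

    def hook(module, inputs, output):
        nonlocal scores
        # output shape: (batch, seq_len, num_neurons); average over batch and sequence
        act = output.detach().abs().mean(dim=(0, 1))
        scores = act if scores is None else scores + act

    handle = mlp_layer.register_forward_hook(hook)
    model.eval()
    with torch.no_grad():
        for batch in calibration_loader:
            model(batch["input_ids"].to(device))
    handle.remove()

    # Higher score = more important; prune the lowest-ranked neurons first.
    return torch.argsort(scores, descending=True)
```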

For depth pruning, the study evaluates layer importance using three metrics: LM validation loss, Block Importance (BI), and a loss-based ranking in which a single layer or a block of contiguous layers is removed to measure its effect on LM loss. BI measures the cosine distance between the input and output of a layer or block of layers. The findings suggest that while BI and LM loss metrics are closely related, they do not always yield the most accurate models for downstream tasks, as the paper's results show; layer importance is therefore also assessed on the Winogrande benchmark.
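
As a rough sketch, Block Importance for a single layer (or block of layers) can be computed as one minus the cosine similarity between the hidden states entering and leaving it, averaged over tokens; a low BI means the block barely changes its input, making it a natural pruning candidate. The tensor shapes and names here are illustrative.

```python
import torch.nn.functional as F

def block_importance(hidden_in, hidden_out):
    """BI = 1 - cosine similarity between a block's input and output states.

    hidden_in, hidden_out: (batch, seq_len, hidden_dim) tensors captured
    immediately before and after the block being scored.
    """
    cos = F.cosine_similarity(hidden_in, hidden_out, dim=-1)  # (batch, seq_len)
    return (1.0 - cos).mean().item()
```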

Retraining LLMs to SLMs with knowledge distillation and pruning

In the paper, the authors discuss two methods for recovering the accuracy of a pruned model, a step they refer to as retraining. The first involves conventional training with ground-truth data labels; the second is retraining with knowledge distillation. We'll explore the second, novel approach.

Retraining with knowledge distillation

As we discussed, knowledge distillation involves learning from a larger and more complex model, known as the teacher, by a smaller, simpler model, known as the student. The smaller model tries to imitate the larger model's behavior and its intermediate processes.

Knowledge distillation measures how well the student model is learning from the teacher by comparing the outputs and intermediate computations of both models at various stages and adjusting the student based on the differences found. The method also weighs multiple loss terms and settings to find the most effective way for the student to learn.

The output probability distribution of an LLM for a given token $x_i$ is computed as:

$$p(x_i, \tau) = \frac{\exp(x_i / \tau)}{\sum_{j=1}^{|V|} \exp(x_j / \tau)}$$

where τ is the softmax temperature and |V| is the vocabulary size. Logit-based KD loss across the sequence of all output tokens is represented as

$$L_{\text{logits}} = \frac{1}{l} \sum_{k=1}^{l} \mathrm{Loss}\big(p_k^t(x, \tau),\, p_k^s(x, \tau)\big)$$

Here, $p_k^t(x, \tau)$ and $p_k^s(x, \tau)$ are the teacher and student probability distributions for the k-th token, respectively, and $l$ is the sequence length.
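
Written out in code, the logit loss is a direct transcription of the formula above; this minimal sketch assumes KL divergence as the per-token loss measure.

```python
import torch.nn.functional as F

def logit_kd_loss(teacher_logits, student_logits, temperature=1.0):
    """Per-token KL divergence between softened teacher and student
    distributions, averaged over the sequence (matching L_logits above).

    teacher_logits, student_logits: (seq_len, vocab_size)
    """
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    # "batchmean" divides the summed KL by seq_len, giving the 1/l average
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean")
```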

In the distillation process, the authors test different ways to measure and minimize loss. They also try out various combinations of steps and adjustments within the Transformer model, each with its own pros and cons.

$$L_{is} = \frac{1}{l} \sum_{k \in H} \sum_{i=1}^{l} \mathrm{Loss}\big(h_{ki}^{t},\, h_{ki}^{s}\big)$$

Here, $h_{ki}^{t}$ and $h_{ki}^{s}$ stand for the hidden states of the teacher and student models for the i-th element, $l$ is the total number of elements considered, and $H$ is the set of chosen intermediate stages. Any mismatch in hidden-state size between the student and teacher is resolved by a learned linear transformation that upscales the student's hidden states to match the teacher's; both are normalized with LayerNorm before comparison.

The total loss $L$ is the sum of the student's cross-entropy loss $L_{CLM}$, the logit loss $L_{logits}$, and a weighted term $\alpha \cdot L_{is}$, i.e., $L = L_{CLM} + L_{logits} + \alpha \cdot L_{is}$. Here, $L_{CLM}$ is computed against the ground-truth labels, and $\alpha$ balances the different loss terms. Since the magnitudes of $L_{logits}$ and $L_{is}$ can vary considerably, $\alpha$ is adjusted dynamically based on these values, which yields better outcomes than a fixed $\alpha$.
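
The combined objective can be sketched as follows, assuming the student's hidden states are projected up to the teacher's width by a learned linear map (`proj` below) and that the dynamic weighting simply rescales the intermediate loss to the magnitude of the logit loss; the paper's exact balancing scheme may differ.

```python
import torch
import torch.nn.functional as F

def total_loss(lm_loss, logit_loss, teacher_states, student_states, proj, layer_norm):
    """L = L_CLM + L_logits + alpha * L_is, with alpha chosen dynamically.

    teacher_states / student_states: lists of (seq_len, hidden) tensors,
    one pair per intermediate stage in H.
    proj: nn.Linear mapping the student's width to the teacher's.
    layer_norm: LayerNorm applied to both sides before comparison.
    """
    l_is = sum(
        F.mse_loss(layer_norm(proj(h_s)), layer_norm(h_t))
        for h_t, h_s in zip(teacher_states, student_states)
    ) / len(teacher_states)

    # Dynamic alpha: match L_is to the scale of the logit loss so that
    # neither term dominates as their magnitudes drift during training.
    alpha = logit_loss.detach() / (l_is.detach() + 1e-8)
    return lm_loss + logit_loss + alpha * l_is
```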

The figure below shows how the distillation losses are computed at different stages within the Transformer model.

Overview of distillation. A student model with N layers is distilled from a teacher model with M layers. The student learns by minimizing a combination of embedding output loss, logit loss, and Transformer-encoder-specific losses mapped across student block S and teacher block T. Source

Minitron family of LLMs

The researchers reduced the size of the Nemotron-4 15B to produce two smaller versions, one with 8 billion and another with 4 billion parameters. These smaller models, referred to as MINITRON, required up to 40 times fewer training tokens compared to training from scratch. This resulted in significant compute cost savings and reduced the environmental impact associated with training such large models.

Performance comparison

The MINITRON models were benchmarked against similarly sized models and previous versions. MINITRON 8B showed performance competitive with or superior to well-known models like Llama-2 7B and Mistral 7B, despite the drastic reduction in training resources.

MINITRON 8B model compared to similarly-sized community models

Similarly, MINITRON 4B held its ground against smaller community models, proving that the pruning and retraining techniques did not sacrifice capability for efficiency.

MINITRON 4B model compared to similarly-sized community models

The MINITRON models were also compared with models produced by other pruning approaches, such as LLM-Pruner, SliceGPT, LaCo, ShortGPT, and Sheared LLaMA. These comparisons showed that the MINITRON models outperformed these alternatives.

Pruned MINITRON 8B model compared to multiple baselines: the original Nemotron-4 15B, the previous generation Nemotron-3 8B, and multiple community models.

Key takeaways

In this article, we looked at how pruning and distillation can turn LLMs into smaller, more manageable models (SLMs). Typically, creating different sizes of models to meet various needs requires a lot of resources because each model is built from scratch. However, the recent development of the MINITRON models has changed this. By fully training only the largest model and then deriving smaller versions through pruning and distillation, costs and resources are drastically reduced. Importantly, these MINITRON models have even outperformed many leading models, showing that this method is a promising approach for future model development.
