Implementing LoRA: Low-Rank Adaptation of Large Language Models from scratch (Annotated paper)
Overview
We know that fine-tuning large language models is challenging for two main reasons:
- It demands a significant amount of GPU compute, which is not always accessible or financially feasible.
- When trained on a narrow task, pre-trained models can suffer catastrophic forgetting, degrading the original model's performance on other tasks.
So, are there approaches that make fine-tuning more efficient?
- Yes: adapter methods add small, sequentially applied trainable modules to the network. However, because these sequential computations cannot be parallelized, they introduce extra inference latency.
- LoRA: Low-Rank Adaptation of Large Language Models introduces a novel, parameter-efficient approach: it adds a small set of trainable parameters for fine-tuning while avoiding any additional inference latency.
I have annotated the paper LoRA: Low-Rank Adaptation of Large Language Models, covering all its important aspects.
In this blog, we're going to build a complete implementation of LoRA in Python 🐍. There's a fantastic implementation by HuggingFace that you might find interesting, but here we'll focus on building everything from scratch.
Implementation
We'll use the bert-base-cased model from HuggingFace to apply the LoRA technique. Briefly, BERT is an encoder-only model, trained to predict masked words in input text, built on the transformer architecture. In this post, we'll fine-tune BERT for sentiment analysis on movie reviews by adding a classification head on top of bert-base-cased. We'll also introduce a mechanism to inject LoRA layers into its encoder modules.
Let's install a few libraries first:
pip install torch datasets transformers evaluate
Load IMDB dataset
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
from transformers import DataCollatorWithPadding
import torch
import torch.nn as nn
from torch.nn import Linear
import copy
import evaluate
import numpy as np
import math
model_type = "bert-base-cased"
imdb = load_dataset("imdb")
tokenizer = AutoTokenizer.from_pretrained(model_type)
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)
tokenized_imdb = imdb.map(preprocess_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}
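Each IMDB record contains a review text and a binary label. As a quick, optional peek at what we just loaded:
# Inspect one training example
print(imdb["train"][0]["label"])       # 0 = NEGATIVE, 1 = POSITIVE
print(imdb["train"][0]["text"][:100])  # first 100 characters of the review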
Load the pre-trained model
model = AutoModelForSequenceClassification.from_pretrained(model_type, num_labels=2, id2label=id2label, label2id=label2id)
print(model)
The model architecture:
BertForSequenceClassification(
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(28996, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-11): 12 x BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
(dropout): Dropout(p=0.1, inplace=False)
(classifier): Linear(in_features=768, out_features=2, bias=True)
)
Examining the encoder module, we see a total of 12 BertLayers. Each layer's self-attention mechanism contains weight matrices for the query (WQ), value (WV), and key (WK) projections.
(encoder): BertEncoder(
(layer): ModuleList(
(0-11): 12 x BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
...
)
)
)
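These are the modules we will later wrap with LoRA; each can be addressed directly by its path in the module tree, for example:
# Access the query projection of the first encoder layer
query = model.bert.encoder.layer[0].attention.self.query
print(query)               # Linear(in_features=768, out_features=768, bias=True)
print(query.weight.shape)  # torch.Size([768, 768])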
The LoRA Concept
According to the research paper, the training process focuses exclusively on weight matrices A and B, while keeping all other weights in a frozen state.
WQ_new * x = WQ_old * x + B1A1 * x
WV_new * x = WV_old * x + B2A2 * x
WK_new * x = WK_old * x + B3A3 * x
What should these weight matrices look like? (A quick shape check in code follows this list.)
- Matrix A has shape r × d, where d is the dimension of the weight matrix W and r is the new rank, typically a small value such as 8 or 16; you can treat r as a hyperparameter to be tuned. A is initialized with a random Gaussian (the official implementation, which we follow below, uses Kaiming uniform initialization).
- Matrix B has shape d × r, projecting the rank-r output of Matrix A back up to dimension d. All elements of B are initialized to 0, so that BA = 0 at the start and fine-tuning begins exactly at the pre-trained weights.
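To make the shapes concrete, here is a minimal standalone sketch (d and r chosen to match the setup below):
d, r = 768, 8                 # BERT-base hidden size, LoRA rank
A = torch.randn(r, d)         # Gaussian init, projects d -> r
B = torch.zeros(d, r)         # zero init, projects r -> d
delta_W = B @ A               # same shape as W, but rank at most 8
print(delta_W.shape)          # torch.Size([768, 768])
print(A.numel() + B.numel())  # 12288 trainable values vs. 768*768 = 589824 in W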
Implementing LoRA module
Below is the PyTorch code that implements the concepts just described.
class LoRAModule(nn.Module):
    def __init__(self, layer, r=8, alpha=16):
        super().__init__()
        # Store the original linear layer
        self.W = layer
        # Initialize the low-rank LoRA projections: d -> r -> d
        self.LoRA_A = Linear(in_features=layer.in_features, out_features=r, bias=False)
        self.LoRA_B = Linear(in_features=r, out_features=layer.out_features, bias=False)
        # Initialize the LoRA parameters' weights
        self.reset_params()
        # Store the scaling factor applied to the LoRA branch
        self.scaling_factor = alpha / r

    def reset_params(self):
        # Initialize LoRA_A weights using Kaiming uniform initialization
        nn.init.kaiming_uniform_(self.LoRA_A.weight, a=math.sqrt(5))
        # Initialize LoRA_B weights with zeros so that BA = 0 at the start of training
        nn.init.zeros_(self.LoRA_B.weight)

    def forward(self, x):
        # Apply the LoRA transformation: W(x) + LoRA_B(LoRA_A(x)) * scaling_factor
        return self.W(x) + self.LoRA_B(self.LoRA_A(x)) * self.scaling_factor
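A quick sanity check (a minimal sketch, not part of the training pipeline): because LoRA_B starts at zero, a freshly wrapped layer should behave exactly like the original one.
layer = Linear(768, 768)
lora_layer = LoRAModule(layer)
x = torch.randn(1, 768)
print(torch.allclose(layer(x), lora_layer(x)))  # True: the LoRA branch is zero at init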
Updating the original BERT model with LoRA modules
Before we add the LoRA modules, we make a copy of the original model. This copy lets us merge weights later to avoid inference latency, which I'll explain in more detail later in this article.
def update_model_with_lora_weights(model):
    # Loop through the 12 BertLayer instances in the encoder
    for i in range(12):
        # Wrap the query projection in a LoRAModule, retaining the original weights
        model.bert.encoder.layer[i].attention.self.query = LoRAModule(
            model.bert.encoder.layer[i].attention.self.query
        )
        # Wrap the value projection in a LoRAModule, retaining the original weights
        model.bert.encoder.layer[i].attention.self.value = LoRAModule(
            model.bert.encoder.layer[i].attention.self.value
        )
    # Return the updated model with LoRA modules in place
    return model
model_copy = copy.deepcopy(model)
model_with_lora = update_model_with_lora_weights(model_copy)
print(model_with_lora)
Now we can see that the LoRA modules have been integrated into the original BERT model.
BertForSequenceClassification(
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(28996, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-11): 12 x BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): LoRAModule(
(W): Linear(in_features=768, out_features=768, bias=True)
(LoRA_A): Linear(in_features=768, out_features=8, bias=False)
(LoRA_B): Linear(in_features=8, out_features=768, bias=False)
)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): LoRAModule(
(W): Linear(in_features=768, out_features=768, bias=True)
(LoRA_A): Linear(in_features=768, out_features=8, bias=False)
(LoRA_B): Linear(in_features=8, out_features=768, bias=False)
)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
(dropout): Dropout(p=0.1, inplace=False)
(classifier): Linear(in_features=768, out_features=2, bias=True)
)
Setting trainable parameters
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )
print_trainable_parameters(model_with_lora)
trainable params: 108606722 || all params: 108606722 || trainable%: 100.0
As we can see, all of the model's weights are currently trainable. However, we only want the LoRA and classification-head parameters to be trainable.
def set_trainable_params(model):
    for name, param in model.named_parameters():
        # Freeze everything except the LoRA matrices and the classification head
        if "LoRA" not in name and "classifier" not in name:
            param.requires_grad = False
set_trainable_params(model_with_lora)
print_trainable_parameters(model_with_lora)
trainable params: 296450 || all params: 108606722 || trainable%: 0.27295732210755796
We only need to train 0.27% of the total parameters.
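As a sanity check, this number is exactly what we would expect:
lora_params = 12 * 2 * 2 * 768 * 8     # 12 layers x (query, value) x (A, B) x (768 * 8) = 294912
classifier_params = 768 * 2 + 2        # classifier weight + bias = 1538
print(lora_params + classifier_params) # 296450 -- matches the count above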
Setup evaluation metrics
accuracy = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)
Setup training procedure
training_args = TrainingArguments(
    output_dir="bert-cased-lora-imdb",
    learning_rate=1e-3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
)
trainer = Trainer(
    model=model_with_lora,
    args=training_args,
    train_dataset=tokenized_imdb["train"],
    eval_dataset=tokenized_imdb["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
Now we're fully prepared to train the model!
Train the model
trainer.train()
Results
Epoch Training Loss Validation Loss Accuracy
0 0.719200 0.751393 0.500000
1 0.267300 0.209935 0.917600
2 0.213700 0.203721 0.925000
3 0.153000 0.209234 0.929640
The loss decreases after each epoch. I trained for 3 epochs, but you can train longer until the model converges.
We are not done yet …
Merging the weights to reduce inference latency
Let us look back at the equation,
WQ_new * x = WQ_old * x + B1A1 * x
After learning the parameters A1 and B1, we can fold both terms into a single weight matrix (remembering the scaling factor our module applies to the LoRA branch):
WQ_new = WQ_old + B1A1 * scaling_factor
This straightforward technique lets the LoRA model run inference without any additional latency.
Let's modify the code to achieve this. We will merge the weights into the original model, not model_with_lora.
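The identity behind the merge is easy to verify numerically (a standalone sketch, independent of the model):
W = torch.randn(768, 768)
B = torch.randn(768, 8)
A = torch.randn(8, 768)
x = torch.randn(768)
lhs = (W + B @ A) @ x                        # merged weight applied once
rhs = W @ x + B @ (A @ x)                    # original two-branch computation
print(torch.allclose(lhs, rhs, atol=1e-4))   # True, up to floating-point error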
def merge_lora_weights(model, model_with_lora):
    with torch.no_grad():
        for i in range(12):
            # Get the LoRA and original self-attention modules for the i-th layer
            lora_module = model_with_lora.bert.encoder.layer[i].attention.self
            org_module = model.bert.encoder.layer[i].attention.self
            # Merge the query weights: W + (B @ A) * scaling_factor
            lora_query_weights = torch.matmul(
                lora_module.query.LoRA_B.weight, lora_module.query.LoRA_A.weight
            )
            org_module.query.weight.copy_(
                lora_module.query.W.weight
                + lora_query_weights * lora_module.query.scaling_factor
            )
            # Merge the value weights the same way
            lora_value_weights = torch.matmul(
                lora_module.value.LoRA_B.weight, lora_module.value.LoRA_A.weight
            )
            org_module.value.weight.copy_(
                lora_module.value.W.weight
                + lora_value_weights * lora_module.value.scaling_factor
            )
        # The classification head was also fine-tuned, so copy it over as well
        model.classifier.load_state_dict(model_with_lora.classifier.state_dict())
merge_lora_weights(model, model_with_lora)
The weights in model have now been updated in place. You can use this merged model for your inference tasks.
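As a final sanity check (a minimal sketch; the input sentence is just an illustrative example), the merged model should produce the same logits as the LoRA model:
device = "cpu"  # make sure both models and the inputs share a device
model.to(device).eval()
model_with_lora.to(device).eval()
sample = tokenizer("A surprisingly heartfelt film.", return_tensors="pt").to(device)
with torch.no_grad():
    merged_logits = model(**sample).logits
    lora_logits = model_with_lora(**sample).logits
print(torch.allclose(merged_logits, lora_logits, atol=1e-5))  # expect True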
Congratulations! You have now successfully implemented LoRA from scratch.