In this tutorial, we will learn how to use QLoRA [Dettmers23] to fine-tune an LLM for question answering (QA).
The notebook shows an example with Falcon-7B. In practice, you can also try larger LLMs, e.g., GPT-NeoX-20B.
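Whether a larger model fits on your GPU is largely a question of weight memory. The rough arithmetic below is a sketch based on the QLoRA paper's accounting: 4-bit weights quantized in blocks of 64, with one 32-bit constant per block, or, with double quantization, 8-bit constants plus one 32-bit constant per 256 constants. The parameter counts (7B, 20B) are nominal assumptions, and the estimate ignores activations, the LoRA adapters, optimizer state, and any layers kept in higher precision.

```python
def constant_bits_per_param(block_size=64, double_quant=True, dq_block_size=256):
    """Average extra bits per weight spent on per-block quantization constants."""
    if not double_quant:
        # One fp32 absmax constant per block of `block_size` weights.
        return 32 / block_size
    # Constants stored in 8 bits, plus one fp32 constant per `dq_block_size` constants.
    return 8 / block_size + 32 / (block_size * dq_block_size)


def weight_memory_gb(n_params, bits_per_param):
    """Weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


if __name__ == "__main__":
    for name, n_params in [("7B", 7e9), ("20B", 20e9)]:
        bf16 = weight_memory_gb(n_params, 16)
        nf4 = weight_memory_gb(n_params, 4 + constant_bits_per_param())
        print(f"{name}: 16-bit ~{bf16:.1f} GB, 4-bit + double quant ~{nf4:.1f} GB")
```

Under these assumptions, the 4-bit weights take roughly 3.6 GB for a 7B model and 10.3 GB for a 20B model, versus 14 GB and 40 GB in 16-bit.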
Step 0: Prepare a Colab environment to run this tutorial on GPUs
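Make sure to “Enable GPU Runtime” by following this URL. A GPU is effectively required here: the 4-bit model in Step 2 is loaded with bitsandbytes, which targets CUDA GPUs, and fine-tuning a 7B model on a CPU would be impractically slow.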
Step 1: Install the necessary packages with pip
```python
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install scipy
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q datasets
!pip install einops  # needed for Falcon
```
Step 2: Do the necessary imports and instantiate a model from the Hugging Face Hub
```python
import torch
import bitsandbytes  # imported only to check that bitsandbytes is installed
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
```
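```python
# model_id = "EleutherAI/gpt-neox-20b"  # a larger alternative
model_id = "ybelkada/falcon-7b-sharded-bf16"

# Store the base weights in 4-bit NormalFloat (NF4), also quantize the
# quantization constants (double quantization), and compute in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",       # let accelerate place the layers on the available GPU(s)
    trust_remote_code=True,  # Falcon checkpoints may rely on custom modeling code from the Hub
)
```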
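To make the `BitsAndBytesConfig` settings above more concrete, here is a toy sketch of blockwise absmax quantization, the idea that 4-bit loading builds on: the weights are split into blocks, each block is scaled by its largest absolute value, and the scaled values are rounded to a few discrete levels. It is a simplification: it uses evenly spaced integer levels in [-7, 7] rather than the NF4 code book, it does not implement double quantization, and the function names are hypothetical, not bitsandbytes APIs.

```python
def quantize_block(values, n_bits=4):
    """Symmetric absmax quantization of one block to signed integer codes."""
    levels = 2 ** (n_bits - 1) - 1                # 7 for 4 bits: codes in [-7, 7]
    absmax = max(abs(v) for v in values) or 1.0   # avoid dividing by zero on all-zero blocks
    codes = [round(v / absmax * levels) for v in values]
    return codes, absmax


def dequantize_block(codes, absmax, n_bits=4):
    """Map integer codes back to approximate floating-point weights."""
    levels = 2 ** (n_bits - 1) - 1
    return [c / levels * absmax for c in codes]


def blockwise_roundtrip(weights, block_size=64, n_bits=4):
    """Quantize and dequantize `weights` block by block; return the reconstruction."""
    out = []
    for start in range(0, len(weights), block_size):
        codes, absmax = quantize_block(weights[start:start + block_size], n_bits)
        out.extend(dequantize_block(codes, absmax, n_bits))
    return out


if __name__ == "__main__":
    weights = [0.5, -1.0, 0.25, 0.0, 0.01, -0.02]
    print(blockwise_roundtrip(weights, block_size=2))
```

Because each block has its own scale, a single large weight only degrades the precision of its own block, which is why small block sizes (64 in QLoRA) are used.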
Step 3: Initialize PEFT-based QLoRA training
```python
from peft import prepare_model_for_kbit_training

model.gradient_checkpointing_enable()  # trade extra compute for lower activation memory
model = prepare_model_for_kbit_training(model)

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=8,                                 # rank of the LoRA update matrices
    lora_alpha=32,                       # scaling factor, applied as lora_alpha / r
    target_modules=["query_key_value"],  # fused attention projection in Falcon / GPT-NeoX
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
print_trainable_parameters(model)
```
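The `LoraConfig` above keeps the (quantized) weight matrix W frozen and learns a low-rank update: for an input x, an adapted layer computes W x + (lora_alpha / r) · B A x, where A is r × d_in and B is d_out × r. B starts at zero (as in the LoRA paper and PEFT's default initialization), so training begins from the unmodified model. The plain-Python sketch below uses hypothetical helper names and lists instead of tensors to show that forward pass and the parameter count r · (d_in + d_out) per adapted matrix.

```python
def matvec(matrix, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in matrix]


def lora_forward(W, A, B, x, r, lora_alpha):
    """Frozen path W x plus the scaled low-rank update B (A x)."""
    base = matvec(W, x)               # frozen base weight (4-bit in QLoRA)
    update = matvec(B, matvec(A, x))  # A: r x d_in, B: d_out x r
    scale = lora_alpha / r
    return [b + scale * u for b, u in zip(base, update)]


def lora_param_count(d_in, d_out, r, n_matrices=1):
    """Trainable parameters added by LoRA: A has r * d_in entries, B has d_out * r."""
    return n_matrices * r * (d_in + d_out)


if __name__ == "__main__":
    W = [[1.0, 0.0], [0.0, 1.0]]
    A = [[1.0, 1.0]]          # rank r = 1
    B_init = [[0.0], [0.0]]   # zero-initialized B: output equals W x
    print(lora_forward(W, A, B_init, [1.0, 2.0], r=1, lora_alpha=2))
```

For example, assuming Falcon-7B's config values (hidden size 4544, 32 decoder layers, and a fused `query_key_value` output of 4544 + 2 × 64 = 4672 due to multi-query attention), `lora_param_count(4544, 4672, r=8, n_matrices=32)` gives 2,359,296, which you can compare with the trainable count printed by `print_trainable_parameters(model)`. With r=8 and lora_alpha=32, the update is scaled by 4.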
```python
print(model)
```
Step 4: Get a QA dataset
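Here we load SQuAD v2, which contains both answerable questions and unanswerable ones (whose answer is not in the context); Step 5 teaches the model to reply "Cannot Find Answer" to the latter.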
```python
from datasets import load_dataset

qa_dataset = load_dataset("squad_v2")
```
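In SQuAD v2, each example's `answers` field is a dict with parallel `text` and `answer_start` lists, and unanswerable questions have both lists empty. The stdlib sketch below applies that rule to a few made-up records in the same shape (they are not real dataset rows); `is_answerable` is a hypothetical helper.

```python
def is_answerable(example):
    """True if a SQuAD v2-style example has at least one gold answer."""
    return len(example["answers"]["text"]) > 0


def split_by_answerability(examples):
    """Partition examples into (answerable, unanswerable) lists."""
    answerable = [ex for ex in examples if is_answerable(ex)]
    unanswerable = [ex for ex in examples if not is_answerable(ex)]
    return answerable, unanswerable


if __name__ == "__main__":
    # Made-up records that mimic the SQuAD v2 schema.
    toy = [
        {"question": "What is the best food?",
         "context": "Cheese is the best food.",
         "answers": {"text": ["Cheese"], "answer_start": [0]}},
        {"question": "How far away is the Moon?",
         "context": "Cheese is the best food.",
         "answers": {"text": [], "answer_start": []}},
    ]
    answerable, unanswerable = split_by_answerability(toy)
    print(len(answerable), len(unanswerable))
```

If you want to train only on answerable questions, the same predicate can be passed to the `datasets` API, e.g. `qa_dataset.filter(is_answerable)`.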
Step 5: Create a prompt
```python
def create_prompt(context, question, answer):
    # SQuAD v2 marks unanswerable questions with an empty answer list
    if len(answer["text"]) < 1:
        answer = "Cannot Find Answer"
    else:
        answer = answer["text"][0]
    # End with the tokenizer's own EOS token (Falcon uses "<|endoftext|>", not "</s>")
    prompt_template = f"### CONTEXT\n{context}\n\n### QUESTION\n{question}\n\n### ANSWER\n{answer}{tokenizer.eos_token}"
    return prompt_template

mapped_qa_dataset = qa_dataset.map(
    lambda samples: tokenizer(create_prompt(samples["context"], samples["question"], samples["answers"]))
)
```
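The training prompt and the inference prompt in Step 7 must share exactly the same prefix; only the training version appends the answer and an end-of-sequence token. The standalone sketch below makes that relationship explicit with a hypothetical `build_prompt` helper; the EOS string is passed in rather than read from a tokenizer (for Falcon it would be `tokenizer.eos_token`).

```python
def build_prompt(context, question, answer=None, eos_token=""):
    """Build the ### CONTEXT / ### QUESTION / ### ANSWER prompt.

    With answer=None this returns the inference prefix; otherwise it appends
    the first answer text (or "Cannot Find Answer" for an empty SQuAD v2
    answer list) followed by eos_token.
    """
    prefix = f"### CONTEXT\n{context}\n\n### QUESTION\n{question}\n\n### ANSWER\n"
    if answer is None:
        return prefix
    text = answer["text"][0] if answer["text"] else "Cannot Find Answer"
    return f"{prefix}{text}{eos_token}"


if __name__ == "__main__":
    gold = {"text": ["Cheese"], "answer_start": [0]}
    print(build_prompt("Cheese is the best food.", "What is the best food?", gold, eos_token="<|endoftext|>"))
```

Keeping both prompts in one function avoids the two templates silently drifting apart, which would put the model off-distribution at inference time.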
Step 6: Start QLoRA fine-tuning
Here we train for only 100 steps as a demonstration; train for longer to get a stable QA model.
```python
import transformers

# Falcon and GPT-NeoX tokenizers have no pad token by default; reuse EOS so the collator can pad
tokenizer.pad_token = tokenizer.eos_token

trainer = transformers.Trainer(
    model=model,
    train_dataset=mapped_qa_dataset["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=100,
        max_steps=100,
        learning_rate=1e-3,
        fp16=True,
        logging_steps=1,
        output_dir="~/path_to_some_output/qlora/outputs",
        optim="paged_adamw_8bit",  # paged 8-bit AdamW from bitsandbytes
        report_to="none",          # turns off wandb
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()
```
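Two quantities are worth checking in the `TrainingArguments` above. The effective batch size is `per_device_train_batch_size × gradient_accumulation_steps` (times the number of GPUs), so 100 steps on one GPU see 400 training sequences, a small fraction of the SQuAD v2 training split. And since `Trainer` defaults to a linear schedule with warmup, `warmup_steps=100` together with `max_steps=100` means the learning rate is still ramping up at the last step and never quite reaches `1e-3`. The sketch below reproduces the shape of that linear warmup/decay multiplier in plain Python; it is written from the schedule's definition, not imported from transformers.

```python
def linear_warmup_multiplier(step, warmup_steps, total_steps):
    """LR multiplier: linear ramp 0 -> 1 over warmup, then linear decay to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


def effective_batch_size(per_device_batch, grad_accum_steps, n_devices=1):
    """Sequences contributing to each optimizer step."""
    return per_device_batch * grad_accum_steps * n_devices


if __name__ == "__main__":
    peak_lr = 1e-3
    # The notebook's settings: warmup_steps == max_steps == 100
    lrs = [peak_lr * linear_warmup_multiplier(s, 100, 100) for s in range(100)]
    print(f"max LR reached: {max(lrs):.2e}")
    print("sequences seen:", 100 * effective_batch_size(1, 4))
```

For longer runs, a warmup that covers only a small share of `max_steps` lets the schedule reach the peak learning rate and then decay.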
Step 7: Run inference with the trained model
```python
from IPython.display import display, Markdown

def make_inference(context, question):
    model.eval()                   # disable LoRA dropout
    model.config.use_cache = True  # re-enable the KV cache that was turned off for training
    batch = tokenizer(
        f"### CONTEXT\n{context}\n\n### QUESTION\n{question}\n\n### ANSWER\n",
        return_tensors="pt",
        return_token_type_ids=False,
    ).to(model.device)             # move the inputs to the model's device
    with torch.autocast("cuda"):
        output_tokens = model.generate(**batch, max_new_tokens=50)
    display(Markdown(tokenizer.decode(output_tokens[0], skip_special_tokens=True)))
```
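`make_inference` displays the whole decoded sequence, prompt included, and the model may keep generating past its answer (for example, starting a new `###` section) until `max_new_tokens` is reached. If you only want the answer text, a small post-processing step can cut it out. The helper below is hypothetical (not part of transformers or PrimeQA) and assumes the `### ANSWER` template used in this tutorial.

```python
def extract_answer(decoded_text, marker="### ANSWER\n", stop="###"):
    """Return the text after the first answer marker, cut at the next section header."""
    _, sep, tail = decoded_text.partition(marker)
    if not sep:
        return ""  # marker missing: nothing to extract
    answer, _, _ = tail.partition(stop)
    return answer.strip()


if __name__ == "__main__":
    sample = (
        "### CONTEXT\nCheese is the best food.\n\n"
        "### QUESTION\nWhat is the best food?\n\n"
        "### ANSWER\nCheese\n\n### QUESTION\nWhat else"
    )
    print(extract_answer(sample))
```

In the notebook, it would be applied to `tokenizer.decode(output_tokens[0], skip_special_tokens=True)`.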
```python
# An "answerable" question: the answer is in the context
context = "Cheese is the best food."
question = "What is the best food?"
make_inference(context, question)
```
```python
# An "unanswerable" question: the context does not contain the answer
context = "Cheese is the best food."
question = "How far away is the Moon from the Earth?"
make_inference(context, question)
```
Congratulations 🎉✨🎊🥳! You can now fine-tune an LLM with PrimeQA and QLoRA.