Fine Tuning Gemma-2b to Solve Math Problems

Rubens Zimbres
Google Developer Experts
6 min read · Apr 11, 2024


Mathematical word problem-solving has long been recognized as
a complex task for small language models (SLMs). To reach good performance, researchers often train SLMs to generate Python code or rely on ensembling techniques combined with consensus or majority voting. The challenge here is to use Google’s Gemma model, with about 2 billion parameters and without resorting to code generation, to solve these grade-school math problems.

Here I will use Microsoft’s Orca-Math dataset, a high-quality synthetic dataset of 200K math problems created through a multi-agent setup in which agents collaborate to generate the data. Details about the model and dataset the Microsoft researchers used can be found in their article from February 2024.

This article is quite straightforward: we will fine-tune Google’s open model Gemma-2b via HuggingFace and PyTorch. Here, I used Google Colab with a single T4 GPU.

Let’s code: first, we install the necessary libraries in the Colab environment:

pip install -q -U torch datasets sentence_transformers
pip install -q accelerate==0.27.2 peft==0.4.0 bitsandbytes==0.40.2 trl==0.4.7
pip install -q wandb

Let’s import the dependencies:

import os
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline
)
from datasets import load_dataset
from peft import LoraConfig, PeftModel

Now, get your HuggingFace token in the HF site Settings and paste it at login. You can also save it in Colab secrets (see the sketch after the login cell).

from huggingface_hub import notebook_login

notebook_login()
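
If you saved the token in Colab secrets instead, you can skip the interactive prompt with something like this (a minimal sketch, assuming the secret is named HF_TOKEN and notebook access is enabled for it):

# Read the token from Colab secrets and log in to the Hugging Face Hub.
# Assumes a Colab secret named HF_TOKEN exists (adjust the name to yours).
from google.colab import userdata
from huggingface_hub import login

login(token=userdata.get("HF_TOKEN"))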

Now we will load quantized Gemma-2b. In the context of PEFT (Parameter-Efficient Fine-Tuning) and LoRA (Low-Rank Adaptation), target_modules is a configuration option within the LoraConfig class. It specifies which modules in the pre-trained model you want to fine-tune using LoRA. A common practice involves fine-tuning linear layers, so you might see “q”, “v” (referring to query and value projections) or “all-linear” specified in target_modules.

The primary goal of the settings presented in the code ahead is to reduce the model’s memory footprint and potentially enhance computational speed by utilizing a lower precision format (4-bit quantization) for model weights and computations.

In LoraConfig, we have some parameters:

  • r (rank of decomposition): an integer that determines the size of the matrices used for the low-rank updates. Lower r values produce smaller update matrices, which means fewer trainable parameters and a faster, more memory-efficient fine-tuning process. However, a very low r might not capture the complexity needed for the adaptation task.
  • lora_alpha (the alpha parameter): scales the update matrices during training and is often tuned alongside the learning rate (a simplified sketch of this scaling follows the list).
  • bias: controls how bias parameters are treated during training. You can train them, keep them frozen, or train only the LoRA-specific biases.
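
As a simplified sketch of how r and lora_alpha interact (illustrative only, not the actual peft internals; the layer dimensions below are hypothetical): LoRA adds a trainable low-rank product B·A on top of the frozen weight W, scaled by lora_alpha / r.

import torch

# Illustrative LoRA update: W stays frozen, only A and B are trained.
d, k, r, lora_alpha = 2048, 2048, 8, 16  # hypothetical sizes
W = torch.randn(d, k)         # frozen pre-trained weight
A = torch.randn(r, k) * 0.01  # trainable low-rank factor
B = torch.zeros(d, r)         # trainable, starts at zero so the initial update is zero

# Effective weight used in the forward pass
W_effective = W + (lora_alpha / r) * (B @ A)
print(W_effective.shape)  # torch.Size([2048, 2048])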

In BitsAndBytesConfig we have the following:

1. Activating 4-Bit Precision:

use_4bit = True: This line enables loading the base model using a lower precision of 4 bits per parameter, as opposed to the typical 32 bits (float32). This can significantly reduce memory usage and potentially speed up computations.
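
As a back-of-envelope estimate of what this saves (weight storage only; the exact footprint also depends on quantization block size and overhead), assuming Gemma-2b has roughly 2.5 billion parameters:

# Rough weight-storage estimate; ignores activations, optimizer states
# and the small overhead of quantization constants.
params = 2.5e9  # approximate parameter count of Gemma-2b
print(f"float32: {params * 4 / 1e9:.2f} GB")    # ~10 GB
print(f"float16: {params * 2 / 1e9:.2f} GB")    # ~5 GB
print(f"4-bit:   {params * 0.5 / 1e9:.2f} GB")  # ~1.25 GB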

2. Choosing Compute Data Type:

bnb_4bit_compute_dtype = “float16”: This specifies the data type to be used for model computations (not storage) when using 4-bit precision. Here, it’s set to float16, which is a 16-bit floating-point type that offers a balance between precision and memory efficiency.

3. Specifying Quantization Method:

bnb_4bit_quant_type = “nf4”: This selects the quantization scheme used to convert model weights from their original precision to 4-bit format. “nf4” stands for “4-bit NormalFloat,” a data type introduced in the QLoRA paper that is designed for normally distributed weights; the alternative is plain “fp4.”

4. Enabling Nested Quantization:

use_nested_quant = True: This activates nested quantization (also called double quantization), which quantizes the quantization constants themselves after the weights have been quantized. This further reduces the memory footprint with little impact on quality.

5. Constructing Quantization Configuration:

The last part creates a BitsAndBytesConfig object that holds the specified options. It is passed to from_pretrained via quantization_config so that the weights are quantized as the model is loaded.


model_name = 'google/gemma-2b'

lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

# bitsandbytes parameters
#################################################################

# Activate 4-bit precision base model loading
use_4bit = True

# Compute dtype for 4-bit base models
bnb_4bit_compute_dtype = "float16"

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = True

# Set up quantization config
#################################################################
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

# Check GPU compatibility with bfloat16
if compute_dtype == torch.float16 and use_4bit:
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        print("=" * 80)
        print("Your GPU supports bfloat16: accelerate training with bf16=True")
        print("=" * 80)

# Load pre-trained tokenizer and model
#################################################################

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map={"":0})
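
If you want to check which module names are valid for target_modules, printing the model lists Gemma’s attention and MLP projection layers (q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj), matching the list used in lora_config above:

# Inspect the architecture to see the linear-layer names that can be
# passed to LoraConfig's target_modules.
print(model)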

Now, we are going to get the orca-math-word-problems-200k dataset from Microsoft, which is hosted on HuggingFace.

The paper, Orca-Math: Unlocking the potential of SLMs in Grade School Math, is available on arXiv.

from datasets import load_dataset

data = load_dataset("microsoft/orca-math-word-problems-200k")
data = data.map(lambda samples: tokenizer(samples["question"]), batched=True)

Let’s take a look at one example of the dataset:
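
One way to do that (each record has a “question” and an “answer” field):

# Print one question/answer pair from the training split
sample = data["train"][0]
print("Question:", sample["question"])
print("Answer:", sample["answer"])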

Finally, let’s format the data and get the Trainer prepared: here you may play with the hyperparameters, and then with the prompts, to improve the final result. This code sets up a trainer for fine-tuning the model, using gradient accumulation, mixed precision, and LoRA to optimize memory usage and potentially speed up training.

import transformers
from trl import SFTTrainer

def formatting_func(example):
    # Build "Question: ... / Answer: ..." training prompts for each sample in the batch
    output_texts = []
    for i in range(len(example["question"])):
        text = f"Question: {example['question'][i]}\nAnswer: {example['answer'][i]}"
        output_texts.append(text)
    return output_texts

trainer = SFTTrainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=500,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=10,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_config,
    formatting_func=formatting_func,
)

Let’s check how many trainable parameters we have: only a tiny fraction of the total.

def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

print(print_number_of_trainable_model_parameters(model))

Get your wandb API key and run:

trainer.train()
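
If you prefer to provide the key programmatically before the training call (or to skip W&B logging altogether), a minimal sketch:

# Option 1: log in to Weights & Biases before calling trainer.train()
import wandb
wandb.login()  # prompts for the API key

# Option 2: skip W&B logging by passing report_to="none" to TrainingArguments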

That’s it, simple and easy. If you have enough time and computational power, train for more epochs so the full dataset is covered and results improve. Here, the prompt is key.

Let’s briefly test the fine-tuned model:

Question: You subtracted 50 from a number and the result is 43. How much is this number?

Question: You have a circle with radius 5. How much is its perimeter?

Question: John and Mary have three kids. The sum of their ages is 22. The first son is 5 years old. The eldest has double the age of the first son. What is the age of the remaining son?
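
A minimal sketch of how such a test can be run (the prompt mirrors the “Question: … Answer:” format from formatting_func; max_new_tokens is an arbitrary choice):

# Generate an answer for the first test question with the fine-tuned model.
prompt = ("Question: You subtracted 50 from a number and the result is 43. "
          "How much is this number?\nAnswer:")

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = trainer.model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))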

Not bad for a tiny training run of 500 steps and a model with only about 2 billion parameters!

The Google ML Developer Programs team supported this work by providing Google Cloud credits.
