unslothai/unsloth

GRPOTrainer fails after SFTTrainer - Gemma-3-1b-it (matmul shape mismatch)

Issue #3069 (open), opened on Jul 30, 2025

Labels: help wanted, inactive, unsure bug?

Description

Hello! (First of all, thank you for Unsloth, it's been incredibly useful!)

Issue

I am trying to combine SFT fine-tuning with GRPO using unsloth/gemma-3-1b-it (I also tried the 4-bit version and google/gemma-3-1b-it, with the same result), but I keep getting a shape-mismatch error when running GRPOTrainer.train() after SFTTrainer. My pipeline consists of:

  1. Load base model with FastLanguageModel.from_pretrained()
  2. Add LoRA with FastLanguageModel.get_peft_model()
  3. Train with trl.SFTTrainer
  4. Train with GRPOTrainer(model=trainer.model, ..)

I tried both running everything in the same session and saving the SFT adapter and reloading it, but the issue remains.

The error is:

TorchRuntimeError: Failed running call_function <built-in method matmul of type object at 0x7d014f41ff00>(*(GradTrackingTensor(lvl=1, value= FakeTensor(..., device='cuda:0', size=(1, s1, s2), dtype=torch.float16, requires_grad=True) ), GradTrackingTensor(lvl=1, value= FakeTensor(..., device='cuda:0', size=(1152, 262144), dtype=torch.float16) )), **{}): a and b must have same reduction dim, but got [s1, s2] X [1152, 262144].

from user code:
  File "/content/drive/MyDrive/OIN_clustering/VM_PIPELINE/scripts/V2/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 234, in accumulate_chunk
    (chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl,)) = torch.func.grad_and_value(
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/apis.py", line 442, in wrapper
    return eager_transforms.grad_and_value_impl(
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/vmap.py", line 48, in fn
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/eager_transforms.py", line 1364, in grad_and_value_impl
    output = func(*args, **kwargs)
  File "/content/drive/MyDrive/OIN_clustering/VM_PIPELINE/scripts/V2/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 182, in compute_loss
    new_logits = torch.matmul(new_hidden_states, lm_head.t())

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
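The last line of the error suggests re-running with extra TorchDynamo logging. A minimal sketch for a Colab notebook, assuming the variables are set in the first cell of a fresh runtime, before `torch` (and therefore `unsloth`) is imported, since PyTorch reads `TORCH_LOGS` at import time:

```python
import os

# Assumption: this cell runs in a fresh runtime, before `import torch` or
# `import unsloth`, because PyTorch reads TORCH_LOGS when it is imported.
os.environ["TORCH_LOGS"] = "+dynamo"     # "+" raises the dynamo logs to DEBUG level
os.environ["TORCHDYNAMO_VERBOSE"] = "1"  # fuller tracebacks for Dynamo errors

# ...then import unsloth and re-run the SFT and GRPO cells unchanged.
print({k: os.environ[k] for k in ("TORCH_LOGS", "TORCHDYNAMO_VERBOSE")})
```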

Some checks:

  - hidden_states[-1].shape → torch.Size([1, 3, 1152])
  - lm_head.weight.shape → torch.Size([262144, 1152])
  - logits → torch.Size([1, 3, 262144])
  - type(model) → peft.peft_model.PeftModelForCausalLM
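The failing line is `torch.matmul(new_hidden_states, lm_head.t())`. Using the shapes from the checks above, the product only works when the last dimension of the first operand equals 1152 (the rows of `lm_head.t()`); in the traced call that dimension is the symbolic `s2`, which could not be matched to 1152. One possibility, which is an assumption not confirmed in this thread, is that the tensor reaching that line already has vocabulary width (262144), i.e. logits rather than hidden states. The pure-Python sketch below reproduces only that shape rule; the helper `matmul_shape` is hypothetical and not part of torch or Unsloth.

```python
import math


def matmul_shape(a_shape, b_shape):
    """Result shape of torch.matmul(a, b) for an N-D `a` (N >= 2) and a 2-D `b`.

    Sketch of the single rule that failed here: a's last dim must equal b's
    first dim. Leading dims of `a` are folded into rows, as in the error text.
    """
    if len(a_shape) < 2 or len(b_shape) != 2:
        raise ValueError("sketch only handles N-D @ 2-D")
    if a_shape[-1] != b_shape[0]:
        rows = math.prod(a_shape[:-1])
        raise ValueError(
            "a and b must have same reduction dim, but got "
            f"{[rows, a_shape[-1]]} X {list(b_shape)}"
        )
    return tuple(a_shape[:-1]) + (b_shape[1],)


HIDDEN, VOCAB = 1152, 262144
lm_head_t = (HIDDEN, VOCAB)  # lm_head.weight is (VOCAB, HIDDEN); .t() swaps it

# Hidden states (batch, seq, HIDDEN) project cleanly to logits.
print(matmul_shape((1, 3, HIDDEN), lm_head_t))  # (1, 3, 262144)

# A tensor that is already vocab-wide cannot be multiplied by lm_head.t() again.
try:
    matmul_shape((1, 3, VOCAB), lm_head_t)
except ValueError as err:
    print(err)
```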

These shapes look consistent (1152 is the hidden size and 262144 the vocabulary size), so I am not sure what is going on or what I am doing wrong (I am also a beginner).

=========

  1. Did you update? pip install --upgrade unsloth unsloth_zoo -- YES
  2. Colab or Kaggle or local / cloud -- COLAB T4 VM
  3. Number GPUs used, use nvidia-smi -- 1
  4. Which notebook? Please link!
  5. Which Unsloth version, TRL version, transformers version, PyTorch version?

     torch: 2.6.0+cu124, transformers: 4.54.0, trl: 0.20.0, unsloth: 2025.7.11, peft: 0.16.0

  6. Which trainer? SFTTrainer, GRPOTrainer, etc. -- SFTTrainer and GRPOTrainer
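The versions above can be collected directly from the Colab session, including `unsloth_zoo`, which step 1 upgrades but which is not listed. A small sketch using only the standard library's `importlib.metadata`; the helper name `installed_versions` is hypothetical:

```python
from importlib.metadata import PackageNotFoundError, version

PACKAGES = ["torch", "transformers", "trl", "unsloth", "unsloth_zoo", "peft"]


def installed_versions(packages=PACKAGES):
    """Return {package: version string or 'not installed'} for a bug report."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    return report


for name, ver in installed_versions().items():
    print(f"{name}: {ver}")
```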

=========


##### Base model ####

import unsloth
from unsloth import FastLanguageModel
from peft import LoraConfig

MAX_SEQ_LEN = 1024
LORA_RANK = 32
BASE = "unsloth/gemma-3-1b-it" 

model, tok = FastLanguageModel.from_pretrained(
    BASE,
    device_map        = "auto",
    max_seq_length    = MAX_SEQ_LEN,
    load_in_4bit      = True,
    full_finetuning   = False,
    fast_inference    = True,
    max_lora_rank     = LORA_RANK,
    trust_remote_code = True,
    # gpu_memory_utilization = 0.8,
)

tok.pad_token = tok.eos_token

model = FastLanguageModel.get_peft_model(
    model,
    r            = LORA_RANK,
    lora_alpha   = LORA_RANK*2,
    lora_dropout = 0,
    bias         = "none",
    target_modules = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj"
    ],
    use_gradient_checkpointing="unsloth",
    random_state = 3407,
)

##### SFTTrainer (test) #####

from transformers import TrainingArguments
from trl import SFTTrainer
import wandb

BATCH      = 4
GRAD_ACC   = 4
EPOCHS     = 1          
LR         = 2e-5
OUTPUT_DIR = XXXX


args = TrainingArguments(
    output_dir                  = OUTPUT_DIR,
    per_device_train_batch_size = BATCH,
    gradient_accumulation_steps = GRAD_ACC,
    num_train_epochs            = EPOCHS,
    learning_rate               = LR,
    lr_scheduler_type           = "cosine",
    warmup_ratio                = 0.03,
    fp16                        = True,   
    bf16                        = False,
    logging_steps               = 100,
    save_strategy               = "epoch",
    eval_strategy               = "epoch",
    save_total_limit            = 2,
    max_grad_norm               = 1.0,
    optim                       = "adamw_torch_fused",
    report_to                   = "none",
)

trainer = SFTTrainer(
    model           = model,
    args            = args,
    train_dataset   = train_ds,
    eval_dataset    = val_ds,
    tokenizer       = tok,
    packing         = False,
)

trainer.train()

adapter_dir = ADAPTER_DIR
trainer.model.save_pretrained(ADAPTER_DIR)
tok.save_pretrained(ADAPTER_DIR)


#### GRPO Trainer #####

from trl import GRPOConfig, GRPOTrainer

max_seq_length = 1024
max_prompt_length = 600

training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_torch_fused",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, 
    num_generations = 4,
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    max_steps = 2,
    save_steps = 2,
    max_grad_norm = 0.1,
    report_to = "none",
    output_dir = "outputs-v2",
)

# Either skip saving/reloading and pass the SFT-trained trainer.model to GRPO as it is:

grpo_trainer = GRPOTrainer(
    model            = trainer.model,  # !!! model returned by the SFT run
    processing_class = tok,
    reward_funcs     = [
        match_format_exactly,
        match_format_approximately,
        check_structured_tags,
        check_nested_tags,
        check_duplicate_or_bad_tags_format,
    ],
    args             = training_args,
    train_dataset    = rl_ds,
)

grpo_trainer.train()

# Or, alternatively, reload the base model and load the saved SFT adapter (same error):

new_model, new_tok = FastLanguageModel.from_pretrained(
    model_name         = BASE,  
    max_seq_length     = MAX_SEQ_LEN,
    load_in_4bit       = True,
    full_finetuning    = False, # Tried setting to True as well but same error
    trust_remote_code  = True,
    device_map         = "auto",
)

# Same configs
new_model = FastLanguageModel.get_peft_model(
    new_model,
    r                  = 32,
    lora_alpha         = 64,
    lora_dropout       = 0.05,
    bias               = "none",
    use_gradient_checkpointing = "unsloth",
    target_modules     = [  
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    random_state       = 3407,
)

new_model.load_adapter(ADAPTER_DIR, adapter_name="sft", is_trainable=False) # tried with is_trainable=True as well but same error

from peft import PeftModel

assert isinstance(new_model, PeftModel)

grpo_trainer = GRPOTrainer(
    model            = new_model,
    processing_class = new_tok,
    reward_funcs     = [
        match_format_exactly,
        match_format_approximately,
        check_structured_tags,
        check_nested_tags,
        check_duplicate_or_bad_tags_format,
    ],
    args             = training_args,
    train_dataset    = rl_ds,
)

grpo_trainer.train()
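Separately from the matmul error, the reload path (get_peft_model followed by load_adapter) may not train on top of the SFT weights at all: get_peft_model creates a fresh adapter named "default", and the PEFT documentation states that `load_adapter` does not make the newly loaded adapter active. A minimal duck-typed check, relying only on the PeftModel attributes `peft_config` and `active_adapter`; the helper `adapter_summary` and the `SimpleNamespace` stand-in are hypothetical:

```python
from types import SimpleNamespace


def adapter_summary(peft_model, expected_active="sft"):
    """Report loaded vs. active adapters on a PEFT model.

    Uses only `peft_model.peft_config` (dict: adapter name -> config) and
    `peft_model.active_adapter` (a name, or a list of names).
    """
    loaded = sorted(peft_model.peft_config)
    active = peft_model.active_adapter
    active_list = [active] if isinstance(active, str) else list(active)
    return {
        "loaded": loaded,
        "active": active_list,
        "sft_is_active": expected_active in active_list,
    }


# Hypothetical stand-in for the state one would expect right after
# new_model.load_adapter(ADAPTER_DIR, adapter_name="sft", ...).
fake = SimpleNamespace(peft_config={"default": None, "sft": None}, active_adapter="default")
print(adapter_summary(fake))
# {'loaded': ['default', 'sft'], 'active': ['default'], 'sft_is_active': False}
```

On the real model, `adapter_summary(new_model)` would show which adapter GRPO is actually updating. If "sft" is not active, PEFT's `set_adapter("sft")` is the documented way to switch; alternatively, Unsloth's examples load a saved LoRA directory directly with `FastLanguageModel.from_pretrained(model_name=ADAPTER_DIR, ...)`. Whether either change affects the matmul error is untested here.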

Questions

Let me know if there is any other info you might need! Thank you very much!
