unslothai/unsloth

[Bug] Forced coupling between num_generations and per_device_train_batch_size in GRPOTrainer resulting in OOM

Open

#3572 opened on Nov 8, 2025

 (3 comments) (1 reaction) (0 assignees) Python (5,658 forks)
help wanted


Description

  1. Did you update? pip install --upgrade unsloth unsloth_zoo: Yes
  2. Colab or Kaggle or local / cloud: Cloud
  3. Number of GPUs used (use nvidia-smi): One GPU
  4. Which notebook? Please link! https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(4B)-GRPO.ipynb
  5. Which Unsloth version, TRL version, transformers version, PyTorch version? Unsloth 2025.11.2, TRL 0.22.2, Transformers 4.56.2, PyTorch 2.8.0+cu128
  6. Which trainer? SFTTrainer, GRPOTrainer, etc.: GRPOTrainer

When per_device_train_batch_size in GRPOConfig is set to a value other than num_generations, the following warning appears:

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 32.
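The coupling the warning describes can be sketched as follows. This is a reconstruction from the warning text alone, not Unsloth's source: the function name coerce_batch_size is hypothetical, and the handling of batch sizes that are not a multiple of num_generations (replace them with num_generations) is an assumption extrapolated from the "1 to 32" case in the message.

```python
def coerce_batch_size(per_device_train_batch_size: int, num_generations: int) -> int:
    """Hypothetical reconstruction of the adjustment described by the Unsloth warning.

    A batch size that is already a multiple of num_generations is kept; any other
    value is replaced by num_generations (assumption based on the warning text).
    """
    if per_device_train_batch_size <= 0 or num_generations <= 0:
        raise ValueError("batch size and num_generations must be positive")
    if per_device_train_batch_size % num_generations == 0:
        return per_device_train_batch_size
    # Assumption: every non-multiple is forced up (or down) to num_generations.
    return num_generations


# The scenario from the warning: batch size 1, num_generations 32.
print(coerce_batch_size(1, 32))   # 32
print(coerce_batch_size(64, 32))  # 64, already a multiple, so it is kept
```

Under this reading, the per-device batch can never be smaller than num_generations, which is exactly the coupling this report objects to.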

However, num_generations is a critical parameter for GRPO convergence, and in your demo notebooks it is typically set to a small value. When the trainer automatically raises per_device_train_batch_size to match a large num_generations, training runs out of memory (OOM).

In other words, large num_generations values are necessary for stable training, but the enforced coupling makes GRPOTrainer practically unusable with such values.
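A rough illustration of why the forced batch size runs out of memory: if the per-device batch equals num_generations, the number of sequences in one forward/backward micro-batch grows linearly with num_generations, and so does the token count (a crude proxy for activation memory). The helper name and the prompt/completion lengths below are illustrative assumptions, not values from the notebook.

```python
def tokens_per_micro_batch(num_generations: int,
                           max_prompt_length: int,
                           max_completion_length: int) -> int:
    """Upper bound on tokens in one micro-batch when batch size == num_generations."""
    per_device_train_batch_size = num_generations  # the forced coupling
    return per_device_train_batch_size * (max_prompt_length + max_completion_length)


# Illustrative lengths (assumed, not taken from the report).
for g in (4, 8, 32):
    print(g, tokens_per_micro_batch(g, max_prompt_length=256, max_completion_length=1024))
# prints: 4 5120, then 8 10240, then 32 40960
```

Going from 4 to 32 generations multiplies the per-micro-batch token count by 8 under these assumptions.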

I’d like to understand the correct way to use a large num_generations value without running into out-of-memory (OOM) issues.

Note: Related to unslothai/unsloth#3149. In that closed issue, @mmathew23 commented:

“But if it does decrease num_generations to 6 and increase gradient_accumulation_steps to 4, you’ll still get the 12 generations per prompt per optimizer step.”

I don’t quite understand how this results in 12 generations for ONE prompt: there is no arithmetic relationship between 6 and 4 that gives 12, by either multiplication or division.
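The arithmetic in question can be made concrete with a small calculator. It assumes TRL-style accounting in which one optimizer step consumes per_device_train_batch_size × gradient_accumulation_steps × num_processes completions (TRL's default when steps_per_generation is left unset), and each prompt contributes num_generations of them. The helper step_budget and the example values are hypothetical, and this is not a claim about what the quoted comment intended.

```python
from dataclasses import dataclass


@dataclass
class StepBudget:
    completions_per_step: int    # completions consumed by one optimizer step
    prompts_per_step: int        # distinct prompts in that step
    generations_per_prompt: int  # group size used for GRPO advantages


def step_budget(per_device_train_batch_size: int,
                gradient_accumulation_steps: int,
                num_generations: int,
                num_processes: int = 1) -> StepBudget:
    """Count completions and prompts per optimizer step (assumed TRL-style accounting)."""
    completions = per_device_train_batch_size * gradient_accumulation_steps * num_processes
    if completions % num_generations != 0:
        raise ValueError(
            f"{completions} completions per step is not divisible by "
            f"num_generations={num_generations}"
        )
    return StepBudget(completions, completions // num_generations, num_generations)


# Batch 6 with 4 accumulation steps gives 24 completions per optimizer step.
print(step_budget(6, 4, num_generations=6))   # 4 prompts x 6 generations each
print(step_budget(6, 4, num_generations=12))  # 2 prompts x 12 generations each
```

Under these assumptions, 6 and 4 fix only the 24 completions per step; the number of generations per prompt is num_generations itself.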
