unslothai/unsloth

[Bug] Sharp increase in CUDA memory when using a customized trl trainer?

Open

#2397 opened on Apr 23, 2025

Labels: help wanted, inactive

Description

Setup: a 1.5B model trained with GRPO, using the following customized trainer:

```python
import torch
from trl import GRPOTrainer


class GRPOTrainer_noKL(GRPOTrainer):

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        print(model.config.use_cache)
        print(model.dtype)
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
        # dict_keys(['prompt_ids', 'prompt_mask', 'completion_ids', 'completion_mask', 'ref_per_token_logps', 'advantages'])
        # print(inputs.keys())

        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # Compute the KL divergence between the model and the reference model
        ref_per_token_logps = inputs["ref_per_token_logps"]
        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

        advantages = inputs["advantages"]
        # Original GRPO loss, commented out:
        # x - x.detach() allows for preserving gradients from x
        # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        # self.beta defaults to 0.04; it could be set here directly
        # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        # per_token_loss = -per_token_loss
        # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

        # Custom loss: no KL term, averaged over all completion tokens in the batch
        per_token_loss = -per_token_logps * advantages.unsqueeze(1)
        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        # The KL is still computed, but only logged as a metric
        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
        del inputs
        return loss
```
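Apart from dropping the KL penalty, the customized `compute_loss` changes how the loss is normalized. The commented-out version averages each sequence over its own completion tokens and then averages over sequences; the custom version divides the masked sum by the total number of completion tokens in the batch, so longer completions carry more weight. (Dropping `torch.exp(x - x.detach())` does not change the gradient of the policy term: that factor equals 1 in value and its gradient is the gradient of `x`.) A minimal pure-Python sketch of the two normalizations on a toy batch follows; the function names are hypothetical and only mirror the two tensor expressions.

```python
from typing import List


def sequence_mean_loss(per_token_loss: List[List[float]], mask: List[List[int]]) -> float:
    """Mean over each sequence's unmasked tokens, then mean over sequences.

    Mirrors: ((per_token_loss * mask).sum(dim=1) / mask.sum(dim=1)).mean()
    """
    per_seq = []
    for losses, m in zip(per_token_loss, mask):
        masked_sum = sum(x * w for x, w in zip(losses, m))
        per_seq.append(masked_sum / sum(m))
    return sum(per_seq) / len(per_seq)


def token_mean_loss(per_token_loss: List[List[float]], mask: List[List[int]]) -> float:
    """Sum over all unmasked tokens in the batch, divided by their count.

    Mirrors: (per_token_loss * mask).sum() / mask.sum()
    """
    masked_sum = sum(x * w for losses, m in zip(per_token_loss, mask) for x, w in zip(losses, m))
    return masked_sum / sum(sum(m) for m in mask)


# Toy batch: a short completion (2 valid tokens) and a long one (4 valid tokens).
losses = [[1.0, 1.0, 0.0, 0.0], [3.0, 3.0, 3.0, 3.0]]
mask = [[1, 1, 0, 0], [1, 1, 1, 1]]
print(sequence_mean_loss(losses, mask))  # 2.0
print(token_mean_loss(losses, mask))     # 2.3333333333333335
```

With per-sequence averaging both completions count equally; with token-level averaging the longer completion pulls the loss toward its value.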

With this trainer, GPU memory usage climbs to 80 GB.
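For a sense of scale: in upstream TRL, `_get_per_token_logps` runs a forward pass that returns logits for the completion positions, a dense tensor of roughly shape `(batch, logits_to_keep, vocab_size)`. One hypothesis, not confirmed in this issue, is that overriding `compute_loss` routes through this path rather than a more memory-efficient one provided by Unsloth. The sketch below only does the arithmetic for that single tensor; the batch size, completion length, vocabulary size (151,936, assumed here) and float32 dtype are all illustrative assumptions, and activations, the log-softmax output and gradients would come on top.

```python
def logits_bytes(batch_size: int, num_positions: int, vocab_size: int, bytes_per_element: int = 4) -> int:
    """Bytes needed for one dense logits tensor of shape (batch, positions, vocab)."""
    return batch_size * num_positions * vocab_size * bytes_per_element


# Assumed values, for illustration only: 8 completions of 1024 tokens,
# a 151,936-token vocabulary, float32 logits (4 bytes per element).
size = logits_bytes(batch_size=8, num_positions=1024, vocab_size=151_936)
print(f"{size / 2**30:.2f} GiB")  # 4.64 GiB
```

Under the same assumptions, bfloat16 logits (`bytes_per_element=2`) halve this figure. To see the real peak rather than an estimate, `torch.cuda.max_memory_allocated()` can be read after a training step with and without the override.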
