unslothai/unsloth

[Bug] Sharp increase in CUDA memory when using a customized `trl` trainer

Open

#2397 opened on Apr 23, 2025

2 comments · 0 reactions · 0 assignees · Python · 5,658 forks
Labels: help wanted, inactive

Repository metrics

Stars: 64,271
PR merge metrics: average merge time 3d 15h; 525 PRs merged in 30 days

Description

Setup: a 1.5B model trained with GRPO. The code is:

```python
import torch
from trl import GRPOTrainer


class GRPOTrainer_noKL(GRPOTrainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        print(model.config.use_cache)
        print(model.dtype)
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
        # dict_keys(['prompt_ids', 'prompt_mask', 'completion_ids', 'completion_mask', 'ref_per_token_logps', 'advantages'])
        # print(inputs.keys())

        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # Compute the KL divergence between the model and the reference model
        # (only logged as a metric below; it is not part of this loss)
        ref_per_token_logps = inputs["ref_per_token_logps"]
        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

        # x - x.detach() allows for preserving gradients from x
        advantages = inputs["advantages"]
        # Previous loss variants, commented out:
        # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        # self.beta defaults to 0.04; it could be set directly here
        # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        # per_token_loss = -per_token_loss
        # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

        # No-KL loss: policy-gradient term only, averaged over all completion tokens in the batch
        per_token_loss = -per_token_logps * advantages.unsqueeze(1)
        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
        del inputs
        return loss
```
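In the code above, the active loss differs from the commented-out `exp(logp - logp.detach())` variant in two ways besides dropping the KL term. First, `exp(logp - logp.detach()) * A` always evaluates to `A`, but its gradient with respect to `logp` is `A`, the same as the gradient of `logp * A`; this is what the `x - x.detach()` comment refers to. Second, the normalization changed: the commented-out version averages over tokens within each sequence and then over sequences, while the active version averages over all completion tokens in the batch, so longer completions carry more weight. The minimal pure-Python sketch below illustrates both points on toy numbers. Plain lists stand in for tensors, a finite difference stands in for autograd, the detached value is modeled as a constant, and all helper names are hypothetical (none of this is TRL or Unsloth API).

```python
import math


def surrogate_grad(logp, adv, eps=1e-6):
    """Finite-difference gradient of -exp(x - c) * adv at x = logp,
    where c = logp is held constant (it plays the role of logp.detach())."""
    c = logp

    def surrogate(x):
        return -math.exp(x - c) * adv

    return (surrogate(logp + eps) - surrogate(logp - eps)) / (2 * eps)


def token_level_mean(per_token_loss, mask):
    """Active loss: masked sum over the whole batch / number of unmasked tokens."""
    total = sum(
        x * m
        for loss_row, mask_row in zip(per_token_loss, mask)
        for x, m in zip(loss_row, mask_row)
    )
    return total / sum(m for mask_row in mask for m in mask_row)


def sequence_level_mean(per_token_loss, mask):
    """Commented-out loss: masked mean per sequence, then mean over sequences."""
    per_seq = [
        sum(x * m for x, m in zip(loss_row, mask_row)) / sum(mask_row)
        for loss_row, mask_row in zip(per_token_loss, mask)
    ]
    return sum(per_seq) / len(per_seq)


def k3_kl(ref_logp, logp):
    """Per-token KL estimate logged by the code above: exp(r) - r - 1, r = ref_logp - logp."""
    r = ref_logp - logp
    return math.exp(r) - r - 1


# Gradient of the surrogate matches d/dlogp of -logp * adv, i.e. -adv.
print(round(surrogate_grad(-1.3, 0.7), 6))  # -0.7

# Two completions: 4 tokens with loss 1.0 each, and 1 token with loss 3.0.
LOSS = [[1.0, 1.0, 1.0, 1.0], [3.0, 0.0, 0.0, 0.0]]
MASK = [[1, 1, 1, 1], [1, 0, 0, 0]]
print(token_level_mean(LOSS, MASK))     # 1.4
print(sequence_level_mean(LOSS, MASK))  # 2.0
print(k3_kl(-2.0, -2.0))                # 0.0
```

With token-level averaging, the long completion contributes 4 of the 5 tokens, which pulls the mean toward its per-token loss; with sequence-level averaging, each completion counts once regardless of length.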

With this trainer, GPU memory usage increases to 80 GB.
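For scale, here is a rough back-of-envelope sketch, not a diagnosis of this issue: computing per-token log-probabilities directly from model outputs materializes a dense logits tensor of shape `[batch, tokens, vocab]`. The sizes used below (8 sequences, 1,024 completion tokens, a 151,936-entry vocabulary) are hypothetical and not taken from the issue, and `logits_memory_gib` is an illustrative helper, not part of any library.

```python
def logits_memory_gib(batch_size: int, num_tokens: int, vocab_size: int, bytes_per_elem: int = 4) -> float:
    """Size in GiB of one dense [batch_size, num_tokens, vocab_size] tensor."""
    return batch_size * num_tokens * vocab_size * bytes_per_elem / 1024**3


# Hypothetical sizes (assumptions, not from the issue): 8 sequences,
# 1024 completion tokens, 151,936-entry vocabulary.
fp32 = logits_memory_gib(8, 1024, 151_936)                     # float32: 4 bytes/element
bf16 = logits_memory_gib(8, 1024, 151_936, bytes_per_elem=2)   # bfloat16: 2 bytes/element
print(f"fp32 logits: {fp32:.2f} GiB, bf16 logits: {bf16:.2f} GiB")
# fp32 logits: 4.64 GiB, bf16 logits: 2.32 GiB
```

Autograd can keep further tensors of the same shape alive for the backward pass (for example the output of a log-softmax), so the total grows roughly linearly with batch size, completion length, and vocabulary size.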
