unslothai/unsloth
View on GitHub
[Bug] Why is there a sharp increase in CUDA memory when using a customized trl.trainer?
Open
#2,397 opened on Apr 23, 2025
help wanted, inactive
Description
Setup: a 1.5B model trained with GRPO. The code is:
`class GRPOTrainer_noKL(GRPOTrainer):
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    """Compute a GRPO-style policy loss with the KL penalty left out.

    The loss is the advantage-weighted negative log-probability of the
    completion tokens, averaged over all unmasked completion tokens in the
    batch. The KL estimate against the reference log-probs is still
    computed, but only for logging; it does not contribute to the loss.

    Args:
        model: The policy model, forwarded to ``self._get_per_token_logps``.
        inputs: Mapping with the keys ``prompt_ids``, ``prompt_mask``,
            ``completion_ids``, ``completion_mask``, ``ref_per_token_logps``
            and ``advantages``.
        return_outputs: Must be falsy; returning outputs is not supported.
        num_items_in_batch: Accepted for interface compatibility; unused.

    Returns:
        The scalar loss tensor.

    Raises:
        ValueError: If ``return_outputs`` is truthy.

    NOTE(review): if this class is used with a framework that patches the
    parent trainer's loss for memory efficiency, overriding ``compute_loss``
    presumably bypasses that path -- TODO confirm; this could explain a
    jump in GPU memory.
    """
    if return_outputs:
        raise ValueError("The GRPOTrainer does not support returning outputs")

    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
    completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)

    # Only the logits of the completion tokens are needed.
    logits_to_keep = completion_ids.size(1)
    per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

    # Clamp the token counts so a fully masked batch/row divides by 1
    # instead of 0 (which would turn the loss or the metric into NaN).
    tokens_per_sequence = completion_mask.sum(dim=1)
    total_tokens = completion_mask.sum().clamp(min=1)

    # The KL penalty (beta * KL) is intentionally omitted from the loss:
    # the objective is just the advantage-weighted negative log-likelihood.
    advantages = inputs["advantages"]
    per_token_loss = -per_token_logps * advantages.unsqueeze(1)
    loss = (per_token_loss * completion_mask).sum() / total_tokens

    # Log the metrics.
    completion_length = self.accelerator.gather_for_metrics(tokens_per_sequence).float().mean().item()
    self._metrics["completion_length"].append(completion_length)

    # The KL estimate is reported but never backpropagated, so compute it
    # without autograd tracking to avoid retaining extra intermediates.
    with torch.no_grad():
        log_ratio = inputs["ref_per_token_logps"] - per_token_logps
        per_token_kl = torch.exp(log_ratio) - log_ratio - 1
        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / tokens_per_sequence.clamp(min=1)).mean()
    self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

    return loss
`
The GPU memory usage increases to 80 GB.