Reinforcement learning (GSPO) training didn't yield any improvements
#3485 opened on Oct 20, 2025
Description
- Did you update? yes
- cloud environment
- 1 GPU used
- Name: transformers Version: 4.56.2 Name: trl Version: 0.23.0 Name: unsloth Version: 2025.10.3
GRPOTrainer
Hi, I was trying to apply reinforcement learning (full fine-tuning) to an information extraction task with a 0.6B model without a reasoning process (12 fields, about 43% overall accuracy before training), using 330 labeled samples. However, training did not produce any better results, and I am not sure where the problem is.
The reward function I designed compares field accuracy: this code implements reward functions for evaluating dimension extraction tasks, with special handling for event names using similarity matching, improved null value handling, and weighted scoring based on field counts.
def calculate_string_similarity(str1: str, str2: str) -> float:
    """Return a case-insensitive similarity ratio between two strings.

    Args:
        str1: First string to compare.
        str2: Second string to compare.

    Returns:
        The ``difflib.SequenceMatcher`` ratio (0.0-1.0) of the lower-cased
        inputs, or ``0.0`` when either input is empty or ``None``.
    """
    # An empty/None side cannot meaningfully match anything; score it as 0.
    if not str1 or not str2:
        return 0.0
    return difflib.SequenceMatcher(None, str1.lower(), str2.lower()).ratio()
def parse_json_response(response: str) -> dict:
    """Parse the model's JSON output into a dict.

    The whole (stripped) response is parsed first. If that fails, the span
    from the first ``{`` to the last ``}`` is tried, which tolerates extra
    text around the JSON object.

    Args:
        response: Raw text produced by the model.

    Returns:
        The parsed JSON object. When nothing parseable is found, when the
        parsed value is not a JSON object (list, string, number, ...), or when
        ``response`` is not a string, returns
        ``{"condition": None, "event_extraction": None}``.
    """
    import re

    # Built per call so callers can never mutate a shared default.
    fallback = {"condition": None, "event_extraction": None}
    if not isinstance(response, str):
        return fallback

    try:
        parsed = json.loads(response.strip())
    except json.JSONDecodeError:
        # Try to extract the JSON portion from surrounding text.
        json_match = re.search(r'\{.*\}', response, re.DOTALL)
        if json_match is None:
            return fallback
        try:
            parsed = json.loads(json_match.group())
        except json.JSONDecodeError:
            return fallback

    # Callers use ``.get`` on the result, so valid-but-non-object JSON must not
    # leak through the declared ``dict`` return type.
    return parsed if isinstance(parsed, dict) else fallback
def dimension_extraction_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Score each completion's extracted fields against its ground-truth label.

    For every completion, the "condition" and "event_extraction" sections of
    the parsed prediction are compared with the label via
    ``evaluate_section_match_improved``. Section scores and penalties are
    summed, the penalty is subtracted (floored at 0), and the result is
    divided by the summed maximum score and scaled to the 0-3.0 range.

    Args:
        prompts: Prompts for the batch; only used for the debug print of the
            first sample.
        completions: One entry per sample; ``completion[0]['content']`` is
            read as the model's text output.
        answer: Ground-truth labels aligned with ``completions``; each item is
            either a JSON string or an already-parsed dict.
        **kwargs: Ignored.

    Returns:
        Exactly one reward per completion, in order. A sample whose scoring
        raises gets ``0.0``.
    """
    responses = [completion[0]['content'] for completion in completions]
    rewards = []

    for i, response in enumerate(responses):
        try:
            # Parse model output and ground truth
            pred_data = parse_json_response(response)
            if isinstance(answer[i], str):
                label_data = json.loads(answer[i])
            else:
                label_data = answer[i]

            condition_result = evaluate_section_match_improved(
                pred_data.get("condition"), label_data.get("condition"), "condition"
            )
            event_result = evaluate_section_match_improved(
                pred_data.get("event_extraction"), label_data.get("event_extraction"), "event"
            )

            total_score = condition_result["score"] + event_result["score"]
            penalty = condition_result["penalty"] + event_result["penalty"]
            max_possible_score = condition_result["max_score"] + event_result["max_score"]

            # Final reward: deduct penalty and normalize by max possible score
            raw_score = max(0.0, total_score - penalty)
            if max_possible_score > 0:
                normalized_score = (raw_score / max_possible_score) * 3.0
            else:
                # Nothing was expected: full reward only if nothing was scored or penalized.
                normalized_score = 3.0 if raw_score == 0 and penalty == 0 else 0.0
        except Exception as e:
            # Per-sample boundary: one malformed sample must not abort the batch.
            print(f"Error processing sample {i}: {e}")
            rewards.append(0.0)
            continue

        rewards.append(normalized_score)

        # Debug output for the first sample. This lives outside the scoring
        # ``try`` so that a logging failure can never append a second reward
        # for the same sample and misalign the returned list.
        if i == 0:
            try:
                query = prompts[0][-1]['content'][:150]
            except (TypeError, KeyError, IndexError, AttributeError):
                # Prompts that are not chat-message lists (e.g. plain strings).
                query = str(prompts[0])[:150] if prompts else ""
            print('-'*70)
            print(f"Query: {query}...")
            print(f"Predicted: {pred_data}")
            print(f"Label: {label_data}")
            print(f"Condition: {condition_result['score']:.3f}/{condition_result['max_score']:.1f}")
            print(f"Event: {event_result['score']:.3f}/{event_result['max_score']:.1f}")
            print(f"Raw Score: {raw_score:.3f}, Max Possible: {max_possible_score:.1f}")
            print(f"Final Score: {normalized_score:.3f} (penalty: {penalty:.3f})")
            print('-'*70)

    return rewards
def evaluate_section_match_improved(pred_section, label_section, section_name: str) -> dict:
    """Score one predicted section against its labelled section.

    Args:
        pred_section: Predicted section (normally a dict of fields) or ``None``.
        label_section: Ground-truth section (normally a dict of fields) or ``None``.
        section_name: Name of the section; currently unused, kept for callers.

    Returns:
        A dict with three floats:
        ``score``     - sum of per-field match scores,
        ``max_score`` - number of labelled fields (at least 1.0),
        ``penalty``   - deduction for output that should not be there.
    """
    # Both null: perfect match.
    if pred_section is None and label_section is None:
        return {"score": 1.0, "max_score": 1.0, "penalty": 0.0}

    label_is_dict = isinstance(label_section, dict)
    # Weight a section by its labelled field count, never below 1 so an empty
    # label cannot produce a zero maximum.
    label_weight = float(len(label_section)) if label_is_dict and label_section else 1.0

    # Should have produced the section but did not: zero score at full weight.
    if pred_section is None:
        return {"score": 0.0, "max_score": label_weight, "penalty": 0.0}

    # Produced a section that should be null: zero score plus 0.2 penalty per field.
    if label_section is None:
        field_count = len(pred_section) if isinstance(pred_section, dict) else 1
        return {"score": 0.0, "max_score": 1.0, "penalty": field_count * 0.2}

    # Wrong type on either side: zero score, weighted like a missing section so
    # both failure modes normalize the same way.
    if not isinstance(pred_section, dict) or not label_is_dict:
        return {"score": 0.0, "max_score": label_weight, "penalty": 0.0}

    # Penalty for predicted fields that are not in the label.
    extra_fields = set(pred_section) - set(label_section)
    penalty = len(extra_fields) * 0.15

    # An empty label is matched exactly only by an empty prediction.
    if not label_section:
        return {"score": 1.0 if not pred_section else 0.0, "max_score": 1.0, "penalty": penalty}

    # Iterate the label dict itself so the summation order is deterministic
    # (set iteration order varies with string hash randomization).
    field_score = 0.0
    for field, label_value in label_section.items():
        field_score += evaluate_field_match_improved(pred_section.get(field), label_value, field)

    return {"score": field_score, "max_score": label_weight, "penalty": penalty}
def evaluate_field_match_improved(pred_value, label_value, field_name: str) -> float:
    """Score a single predicted field value against its label.

    Rules, applied in order:
      1. Both ``None`` -> 1.0; exactly one ``None`` -> 0.0.
      2. ``event_name``: string similarity mapped piecewise -
         >= 0.9 -> 1.0, 0.8-0.9 -> 0.8-1.0, 0.6-0.8 -> 0.4-0.8, else 0.0.
      3. Both lists: compared as sets of ``str(item)``; the score is
         ``0.3 * precision + 0.7 * recall``. An empty label list only matches
         an empty prediction.
      4. Otherwise: exact match of the stripped string forms.

    Args:
        pred_value: Predicted value for the field.
        label_value: Ground-truth value for the field.
        field_name: Field name; only ``"event_name"`` is treated specially.

    Returns:
        A score between 0.0 and 1.0.
    """
    if pred_value is None and label_value is None:
        return 1.0
    if pred_value is None or label_value is None:
        return 0.0

    # Special handling for event_name: similarity matching with a smoother reward curve
    if field_name == "event_name":
        similarity = calculate_string_similarity(str(pred_value), str(label_value))
        if similarity >= 0.9:
            return 1.0  # Perfect score for high similarity
        if similarity >= 0.8:
            return 0.8 + (similarity - 0.8) * 2  # Map 0.8-0.9 to 0.8-1.0
        if similarity >= 0.6:
            return 0.4 + (similarity - 0.6) * 2  # Map 0.6-0.8 to 0.4-0.8
        return 0.0  # 0 score for low similarity

    # List fields: partial credit for partial overlap
    if isinstance(pred_value, list) and isinstance(label_value, list):
        if not label_value:
            return 1.0 if not pred_value else 0.0

        pred_set = {str(x) for x in pred_value}
        label_set = {str(x) for x in label_value}
        overlap = len(pred_set & label_set)
        if overlap == 0:
            return 0.0

        # label_set is non-empty here; a non-zero overlap implies pred_set is too.
        precision = overlap / len(pred_set)
        recall = overlap / len(label_set)
        return 0.3 * precision + 0.7 * recall  # Emphasize recall

    # Scalar comparison
    return 1.0 if str(pred_value).strip() == str(label_value).strip() else 0.0
def json_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions for being well-formed JSON with the expected top level.

    Scoring per completion (total 0-1.0):
      - 0.5 if the whole stripped response parses as JSON;
      - if the parsed value is a dict: +0.3 when it contains both
        ``condition`` and ``event_extraction``, and +0.2 reduced by 0.1 per
        extra top-level key (never below 0);
      - if the whole response does not parse but the span from the first
        ``{`` to the last ``}`` does: 0.2;
      - otherwise 0.0 (also for content that is not a string).

    Args:
        completions: One entry per sample; ``completion[0]['content']`` is
            read as the model's text output.
        **kwargs: Ignored.

    Returns:
        One score per completion, in order.
    """
    import re

    expected_fields = {'condition', 'event_extraction'}
    rewards = []

    for completion in completions:
        response = completion[0]['content']

        # Non-text content cannot be valid JSON output; score it instead of crashing.
        if not isinstance(response, str):
            rewards.append(0.0)
            continue

        try:
            parsed = json.loads(response.strip())
        except json.JSONDecodeError:
            # Try partial extraction as last resort
            json_match = re.search(r'\{.*\}', response, re.DOTALL)
            score = 0.0
            if json_match:
                try:
                    json.loads(json_match.group())
                except json.JSONDecodeError:
                    score = 0.0
                else:
                    score = 0.2  # Partial extraction successful
        else:
            score = 0.5  # Basic JSON validity
            if isinstance(parsed, dict):
                actual_fields = set(parsed)
                # Reward for containing expected fields
                if expected_fields.issubset(actual_fields):
                    score += 0.3
                # Full 0.2 with no extra top-level keys, minus 0.1 per extra key.
                extra_top_fields = actual_fields - expected_fields
                score += max(0.0, 0.2 - len(extra_top_fields) * 0.1)

        rewards.append(score)

    return rewards
# GRPO/GSPO training configuration.
# One keyword argument per line: when the call is collapsed onto a single line,
# the first inline `#` turns every following argument into a comment.
training_args = GRPOConfig(
    # vllm_sampling_params = vllm_sampling_params,
    temperature=1.0,
    learning_rate=5e-6,
    weight_decay=0.01,
    warmup_ratio=0.1,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    lr_scheduler_type="linear",
    optim="adamw_8bit",
    logging_steps=1,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,  # Increase to 4 for smoother training
    num_generations=8,  # Decrease if out of memory
    # max_prompt_length = max_prompt_length,
    # max_completion_length = max_completion_length,
    num_train_epochs=10,  # Set to 1 for a full training run
    # max_steps = 1,
    save_steps=50,
    report_to="tensorboard",
    output_dir="outputs",
    max_prompt_length=4096,
    max_completion_length=512,
    logging_dir="./fine_tuning/query_parser_0.6/logs/",
    # epsilon = 3e-4,
    # epsilon_high = 4e-4,
    epsilon=0.2,
    epsilon_high=0.28,
    # For optional training + evaluation
    fp16_full_eval=True,
    per_device_eval_batch_size=16,
    # eval_accumulation_steps = 1,
    eval_strategy="steps",
    eval_steps=3,
    beta=0,
    # GSPO is below:
    importance_sampling_level="sequence",
    # Dr GRPO / GAPO etc
    # loss_type = "dr_grpo",
    # NOTE(review): the logged runs report frac_reward_zero_std of roughly
    # 0.62-0.64, i.e. most generation groups show no reward variance. That
    # presumably leaves little learning signal per step - confirm against the
    # TRL GRPO documentation before tuning these values.
    loss_type="dapo",
    mask_truncated_completions=True,
)
training result: {'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 3.0303030303030305e-08, 'num_tokens': 83968.0, 'completions/mean_length': 77.5, 'completions/min_length': 74.0, 'completions/max_length': 81.0, 'completions/clipped_ratio': 0.0, 'completions/mean_terminated_length': 77.5, 'completions/min_terminated_length': 74.0, 'completions/max_terminated_length': 81.0, 'rewards/dimension_extraction_reward_func/mean': 3.0, 'rewards/dimension_extraction_reward_func/std': 0.0, 'rewards/json_format_reward_func/mean': 1.0, 'rewards/json_format_reward_func/std': 0.0, 'reward': 4.0, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'completion_length': 81.0, 'kl': 0.0, 'epoch': 0.01}
[A{'eval_loss': 3.451258479003627e-08, 'eval_runtime': 62.4148, 'eval_samples_per_second': 1.362, 'eval_steps_per_second': 0.096, 'num_tokens': 1001646.0, 'completions/mean_length': 40.13690476190476, 'completions/min_length': 25.476190476190474, 'completions/max_length': 58.666666666666664, 'completions/clipped_ratio': 0.0, 'completions/mean_terminated_length': 40.13690476190476, 'completions/min_terminated_length': 25.476190476190474, 'completions/max_terminated_length': 58.666666666666664, 'rewards/dimension_extraction_reward_func/mean': 2.3732229755038308, 'rewards/dimension_extraction_reward_func/std': 0.6416230095284325, 'rewards/json_format_reward_func/mean': 0.9761904761904762, 'rewards/json_format_reward_func/std': 0.05239650039445786, 'reward': 3.349413451694307, 'reward_std': 0.16261665984278634, 'frac_reward_zero_std': 0.6190476190476191, 'epoch': 0.02}
[{'eval_loss': -5.441014749862916e-08, 'eval_runtime': 58.0178, 'eval_samples_per_second': 1.465, 'eval_steps_per_second': 0.103, 'num_tokens': 550893640.0, 'completions/mean_length': 41.18154761904762, 'completions/min_length': 25.904761904761905, 'completions/max_length': 59.80952380952381, 'completions/clipped_ratio': 0.0, 'completions/mean_terminated_length': 41.18154761904762, 'completions/min_terminated_length': 25.904761904761905, 'completions/max_terminated_length': 59.80952380952381, 'rewards/dimension_extraction_reward_func/mean': 2.3789829867226735, 'rewards/dimension_extraction_reward_func/std': 0.617440711529482, 'rewards/json_format_reward_func/mean': 0.9791666666666666, 'rewards/json_format_reward_func/std': 0.04946565060388474, 'reward': 3.3581496533893405, 'reward_std': 0.14253070666676476, 'frac_reward_zero_std': 0.6428571428571429, 'epoch': 10.0}