unslothai/unsloth

Reinforce (GSPO) training didn't yield any improvements

Open

#3485 opened on Oct 20, 2025

View on GitHub
 (1 comment) (0 reactions) (0 assignees)Python (64,271 stars) (5,658 forks)batch import
help wantedunsure bug?

Description

  1. Did you update? yes
  2. cloud environment
  3. 1 GPU used
  4. Name: transformers Version: 4.56.2 Name: trl Version: 0.23.0 Name: unsloth Version: 2025.10.3
  5. GRPOTrainer

Hi, I was trying to apply reinforcement learning (full fine-tuning) to an information extraction task with a 0.6B model, without a reasoning process (12 fields, about 43% overall accuracy before training), using 330 labeled samples. However, training did not produce any better results, and I am not sure where the problem is.

The reward function I designed to compare field accuracy: this code implements reward functions for evaluating dimension extraction tasks, with special handling for event names using similarity matching, improved null value handling, and weighted scoring based on field counts.

def calculate_string_similarity(str1: str, str2: str) -> float:
    """Return a case-insensitive similarity ratio between two strings.

    The ratio comes from ``difflib.SequenceMatcher`` applied to the
    lower-cased inputs. If either input is falsy (empty string or None),
    the result is 0.0.
    """
    if str1 and str2:
        matcher = difflib.SequenceMatcher(None, str1.lower(), str2.lower())
        return matcher.ratio()
    return 0.0



def parse_json_response(response: str) -> dict:
    """Parse the JSON object emitted by the model, always returning a dict.

    Parsing is attempted in two stages:
      1. The whole (stripped) response is decoded as JSON.
      2. Failing that, the span from the first ``{`` to the last ``}`` is
         decoded (handles responses that wrap the object in extra text).

    Only a JSON *object* is accepted. Valid JSON of another type (list,
    string, number, null) is treated like unparseable output, so callers
    can safely call ``.get`` on the result.

    Returns:
        The decoded dict, or ``{"condition": None, "event_extraction": None}``
        when no JSON object can be recovered.

    NOTE(review): the fallback is indistinguishable from a genuine all-null
    prediction, so unparseable output is scored as if the model had
    predicted null for both sections -- confirm this is intended.
    """
    import re

    try:
        parsed = json.loads(response.strip())
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # Try to extract the JSON portion (greedy: first "{" to last "}").
    json_match = re.search(r'\{.*\}', response, re.DOTALL)
    if json_match:
        try:
            extracted = json.loads(json_match.group())
        except json.JSONDecodeError:
            extracted = None
        if isinstance(extracted, dict):
            return extracted

    return {"condition": None, "event_extraction": None}

def _preview_first_prompt(prompts, limit=150):
    """Return up to `limit` characters of the first prompt for debug output."""
    try:
        first = prompts[0]
        # Conversational prompts are lists of message dicts; plain prompts are strings.
        text = first if isinstance(first, str) else first[-1]['content']
        return str(text)[:limit]
    except (TypeError, KeyError, IndexError):
        return "<unavailable>"


def dimension_extraction_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """
    Improved dimension extraction reward function
    - Weighted evaluation by field count to avoid imbalance
    - Improved null handling, partial matches receive partial scores
    - Adjustable similarity threshold and reward weights
    - Total score range 0-3.0 (larger optimization space)

    Args:
        prompts: Prompts for the batch; only the first one is read, for
            debug printing.
        completions: One entry per sample; the text is read from
            ``completion[0]['content']``.
        answer: Ground truth per sample, either a JSON string or an
            already-decoded dict.
        **kwargs: Ignored.

    Returns:
        Exactly one reward per completion. A sample whose scoring raises
        gets 0.0.
    """
    responses = [completion[0]['content'] for completion in completions]
    rewards = []

    for i, response in enumerate(responses):
        debug_info = None
        try:
            # Parse model output and ground truth
            pred_data = parse_json_response(response)
            label_data = json.loads(answer[i]) if isinstance(answer[i], str) else answer[i]

            # Evaluate both sections with the same scorer
            condition_result = evaluate_section_match_improved(
                pred_data.get("condition"), label_data.get("condition"), "condition"
            )
            event_result = evaluate_section_match_improved(
                pred_data.get("event_extraction"), label_data.get("event_extraction"), "event"
            )

            total_score = condition_result["score"] + event_result["score"]
            penalty = condition_result["penalty"] + event_result["penalty"]
            max_possible_score = condition_result["max_score"] + event_result["max_score"]

            # Final reward: deduct penalty and normalize by max possible score
            raw_score = max(0.0, total_score - penalty)
            # Normalize to 0-3.0 range, giving the model larger optimization space
            if max_possible_score > 0:
                reward = (raw_score / max_possible_score) * 3.0
            else:
                reward = 3.0 if raw_score == 0 and penalty == 0 else 0.0

            debug_info = (pred_data, label_data, condition_result, event_result,
                          raw_score, max_possible_score, penalty)
        except Exception as e:
            # Broad on purpose: one malformed sample must not abort the training step.
            print(f"Error processing sample {i}: {e}")
            reward = 0.0

        # Append exactly once per sample so len(rewards) == len(completions).
        rewards.append(reward)

        # Print detailed information for the first sample. This runs outside the
        # scoring try-block so a logging failure cannot add a second reward.
        if i == 0 and debug_info is not None:
            (pred_data, label_data, condition_result, event_result,
             raw_score, max_possible_score, penalty) = debug_info
            print('-'*70)
            print(f"Query: {_preview_first_prompt(prompts)}...")
            print(f"Predicted: {pred_data}")
            print(f"Label: {label_data}")
            print(f"Condition: {condition_result['score']:.3f}/{condition_result['max_score']:.1f}")
            print(f"Event: {event_result['score']:.3f}/{event_result['max_score']:.1f}")
            print(f"Raw Score: {raw_score:.3f}, Max Possible: {max_possible_score:.1f}")
            print(f"Final Score: {reward:.3f} (penalty: {penalty:.3f})")
            print('-'*70)

    return rewards

def evaluate_section_match_improved(pred_section, label_section, section_name: str) -> dict:
    """
    Improved section evaluation function
    - Allocate weights by field count
    - Improved null handling logic
    - Adjustable event_name similarity threshold

    Args:
        pred_section: Predicted section (expected dict or None).
        label_section: Ground-truth section (expected dict or None).
        section_name: Section label; not used by the scoring logic.

    Returns:
        Dict with keys "score" (earned points), "max_score" (points
        available) and "penalty" (to be deducted by the caller).
    """
    # Perfect match case
    if pred_section is None and label_section is None:
        return {"score": 1.0, "max_score": 1.0, "penalty": 0.0}

    # Should output but didn't: 0 score, weighted by the number of label fields
    if pred_section is None:
        field_count = len(label_section) if isinstance(label_section, dict) else 1
        return {"score": 0.0, "max_score": float(field_count), "penalty": 0.0}

    # Shouldn't output but did: 0 score plus a small penalty per predicted field
    if label_section is None:
        field_count = len(pred_section) if isinstance(pred_section, dict) else 1
        return {"score": 0.0, "max_score": 1.0, "penalty": field_count * 0.2}

    # Type check
    if not isinstance(pred_section, dict) or not isinstance(label_section, dict):
        return {"score": 0.0, "max_score": 1.0, "penalty": 0.0}

    label_fields = set(label_section.keys())

    # Penalty for extra fields
    extra_fields = set(pred_section.keys()) - label_fields
    penalty = len(extra_fields) * 0.15

    # An empty label has no fields to score. An empty prediction is then a
    # perfect match and earns the single available point, not 0.
    if not label_fields:
        score = 1.0 if not pred_section else 0.0
        return {"score": score, "max_score": 1.0, "penalty": penalty}

    # Calculate field matching score; each label field is worth one point
    field_score = 0.0
    for field in label_fields:
        field_score += evaluate_field_match_improved(
            pred_section.get(field), label_section[field], field
        )

    return {"score": field_score, "max_score": float(len(label_fields)), "penalty": penalty}

def evaluate_field_match_improved(pred_value, label_value, field_name: str) -> float:
    """
    Improved field matching function
    - Adjustable event_name similarity threshold and reward curve
    - Better list and type handling

    Args:
        pred_value: Predicted field value.
        label_value: Ground-truth field value.
        field_name: Field key; "event_name" selects fuzzy string matching.

    Returns:
        A match score between 0.0 and 1.0.
    """
    if pred_value is None and label_value is None:
        return 1.0

    if pred_value is None or label_value is None:
        return 0.0

    # Special handling for event_name: use similarity matching with smoother reward curve
    if field_name == "event_name":
        similarity = calculate_string_similarity(str(pred_value), str(label_value))
        if similarity >= 0.9:
            return 1.0  # Perfect score for high similarity
        elif similarity >= 0.8:
            return 0.8 + (similarity - 0.8) * 2  # Map 0.8-0.9 to 0.8-1.0
        elif similarity >= 0.6:
            return 0.4 + (similarity - 0.6) * 2  # Map 0.6-0.8 to 0.4-0.8
        else:
            return 0.0  # 0 score for low similarity

    # List field handling - consider partial matches
    if isinstance(pred_value, list) and isinstance(label_value, list):
        if not label_value:
            return 1.0 if not pred_value else 0.0

        # Compare by string form so unhashable or mixed-type items still work
        pred_set = {str(x) for x in pred_value}
        label_set = {str(x) for x in label_value}

        overlap = len(pred_set & label_set)
        if overlap == 0:
            # Covers an empty prediction as well as a disjoint one
            return 0.0

        precision = overlap / len(pred_set)
        recall = overlap / len(label_set)

        # Weighted blend of precision and recall, emphasizing recall
        return 0.3 * precision + 0.7 * recall

    # Scalar comparison
    return 1.0 if str(pred_value).strip() == str(label_value).strip() else 0.0

def json_format_reward_func(completions, **kwargs) -> list[float]:
    """
    Lightweight JSON format reward, complementing the main reward function.

    Scoring per completion (total range 0-1.0):
    - 0.5 when the whole stripped response is valid JSON
    - +0.3 when it is an object containing both expected top-level keys
    - +0.2 when it is an object with no unexpected top-level keys
      (reduced by 0.1 per extra key, never below 0)
    - 0.2 flat when the response is not valid JSON but a `{...}` span
      inside it is; 0.0 otherwise
    """
    import re

    expected_fields = {'condition', 'event_extraction'}
    texts = [completion[0]['content'] for completion in completions]
    rewards = []

    for text in texts:
        try:
            parsed = json.loads(text.strip())
        except json.JSONDecodeError:
            # Last resort: look for an embedded JSON object
            value = 0.0
            embedded = re.search(r'\{.*\}', text, re.DOTALL)
            if embedded:
                try:
                    json.loads(embedded.group())
                except json.JSONDecodeError:
                    pass
                else:
                    value = 0.2  # Partial extraction successful
            rewards.append(value)
            continue

        value = 0.5  # Basic JSON validity
        if isinstance(parsed, dict):
            actual_fields = set(parsed.keys())

            # Reward for containing expected fields
            if expected_fields.issubset(actual_fields):
                value += 0.3

            # Full 0.2 with no extra top-level keys, shrinking by 0.1 per extra key
            extra_count = len(actual_fields - expected_fields)
            value += max(0.0, 0.2 - extra_count * 0.1)

        rewards.append(value)

    return rewards

# GRPO/GSPO training configuration. One keyword argument per line: when
# the whole call is collapsed onto a single line, the first "#" comments
# out every argument that follows it.
training_args = GRPOConfig(
    # vllm_sampling_params = vllm_sampling_params,
    temperature = 1.0,
    learning_rate = 5e-6,
    weight_decay = 0.01,
    warmup_ratio = 0.1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    lr_scheduler_type = "linear",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 8,
    gradient_accumulation_steps = 2, # Increase to 4 for smoother training
    num_generations = 8, # Decrease if out of memory
    # max_prompt_length = max_prompt_length,
    # max_completion_length = max_completion_length,
    num_train_epochs = 10, # Set to 1 for a full training run
    # max_steps = 1,
    save_steps = 50,
    report_to = "tensorboard",
    output_dir = "outputs",
    max_prompt_length = 4096,
    max_completion_length = 512,
    logging_dir = "./fine_tuning/query_parser_0.6/logs/",
    # epsilon = 3e-4,
    # epsilon_high = 4e-4,
    epsilon = 0.2,
    epsilon_high = 0.28,

    # For optional training + evaluation
    fp16_full_eval = True,
    per_device_eval_batch_size = 16,
    # eval_accumulation_steps = 1,
    eval_strategy = "steps",
    eval_steps = 3,
    # NOTE(review): presumably disables the KL term -- confirm against the TRL docs.
    beta = 0,

    # GSPO is below:
    importance_sampling_level = "sequence",
    # Dr GRPO / GAPO etc
    # loss_type = "dr_grpo",
    loss_type = "dapo",
    mask_truncated_completions = True,
)

training result: {'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 3.0303030303030305e-08, 'num_tokens': 83968.0, 'completions/mean_length': 77.5, 'completions/min_length': 74.0, 'completions/max_length': 81.0, 'completions/clipped_ratio': 0.0, 'completions/mean_terminated_length': 77.5, 'completions/min_terminated_length': 74.0, 'completions/max_terminated_length': 81.0, 'rewards/dimension_extraction_reward_func/mean': 3.0, 'rewards/dimension_extraction_reward_func/std': 0.0, 'rewards/json_format_reward_func/mean': 1.0, 'rewards/json_format_reward_func/std': 0.0, 'reward': 4.0, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'completion_length': 81.0, 'kl': 0.0, 'epoch': 0.01}

{'eval_loss': 3.451258479003627e-08, 'eval_runtime': 62.4148, 'eval_samples_per_second': 1.362, 'eval_steps_per_second': 0.096, 'num_tokens': 1001646.0, 'completions/mean_length': 40.13690476190476, 'completions/min_length': 25.476190476190474, 'completions/max_length': 58.666666666666664, 'completions/clipped_ratio': 0.0, 'completions/mean_terminated_length': 40.13690476190476, 'completions/min_terminated_length': 25.476190476190474, 'completions/max_terminated_length': 58.666666666666664, 'rewards/dimension_extraction_reward_func/mean': 2.3732229755038308, 'rewards/dimension_extraction_reward_func/std': 0.6416230095284325, 'rewards/json_format_reward_func/mean': 0.9761904761904762, 'rewards/json_format_reward_func/std': 0.05239650039445786, 'reward': 3.349413451694307, 'reward_std': 0.16261665984278634, 'frac_reward_zero_std': 0.6190476190476191, 'epoch': 0.02}

[{'eval_loss': -5.441014749862916e-08, 'eval_runtime': 58.0178, 'eval_samples_per_second': 1.465, 'eval_steps_per_second': 0.103, 'num_tokens': 550893640.0, 'completions/mean_length': 41.18154761904762, 'completions/min_length': 25.904761904761905, 'completions/max_length': 59.80952380952381, 'completions/clipped_ratio': 0.0, 'completions/mean_terminated_length': 41.18154761904762, 'completions/min_terminated_length': 25.904761904761905, 'completions/max_terminated_length': 59.80952380952381, 'rewards/dimension_extraction_reward_func/mean': 2.3789829867226735, 'rewards/dimension_extraction_reward_func/std': 0.617440711529482, 'rewards/json_format_reward_func/mean': 0.9791666666666666, 'rewards/json_format_reward_func/std': 0.04946565060388474, 'reward': 3.3581496533893405, 'reward_std': 0.14253070666676476, 'frac_reward_zero_std': 0.6428571428571429, 'epoch': 10.0}

Contributor guide