# -*- coding: utf-8 -*-
"""
Created: 17 March 2025

Complete example code: building on the original version, this demonstrates how
to dynamically adjust/combine the individual reward functions during training,
including logic that penalizes or rewards specific expressions such as
"fear of the unknown".

The example uses a single `dynamic_reward_func` that, based on the task type
("math" or "lyrics") and other custom criteria (such as the "fear_of_unknown"
check), dynamically dispatches to and combines different sub-reward functions.
Weights, logic, etc. can be adjusted freely in real use.
"""
import torch
import re
import random
import langid
from typing import List
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk import edit_distance
from datasets import load_dataset, Dataset, concatenate_datasets
from trl import GRPOConfig, GRPOTrainer
import wandb
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from unsloth import FastLanguageModel, PatchFastRL
from unsloth import is_bfloat16_supported
SUPPORTED_LANGUAGES = {
"en_to_zh": ("英语", "中文"),
"zh_to_en": ("中文", "英语"),
"jp_to_zh": ("日语", "中文"),
"zh_to_azh": ("中文", "中文(文言文)"),
"en_to_azh": ("英语", "中文(文言文)"),
"de_to_zh": ("德语", "中文"),
"italy_to_zh": ("意大利语", "中文"),
"rus_to_zh": ("俄语", "中文"),
"fr_to_zh": ("法语", "中文"),
"mix": ("中文", "multilingual")
}
SYSTEM_PROMPT = """
You are a versatile assistant. Follow these rules:
XML_COT_FORMAT = """
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
def get_lyric_datasets(path: str, train_test_ratio=0.8) -> dict:
"""
Load a lyrics translation dataset from JSON and split into train and test sets (e.g. 80%/20%).
Each data item has a 'type' (e.g. 'en_to_zh'), 'lyric', and 'target_lyric'.
Returns a DatasetDict with 'train' and 'test' splits.
"""
data = Dataset.from_json(path)
    def map_fn(x):
        lang_src = SUPPORTED_LANGUAGES[x['type']][0]
        lang_tgt = SUPPORTED_LANGUAGES[x['type']][1]
        # Insert the correct source->target languages into the system prompt:
        system_plus = SYSTEM_PROMPT.format(lang_src, lang_tgt)
        # Mark the "task" field as "lyrics", so we know it is a lyric translation request
        return {
            'prompt': [
                {'role': 'system', 'content': system_plus},
                {'role': 'user', 'content': x['lyric']}
            ],
            'answer': x['target_lyric'],
            'task': "lyrics"
        }
    data = data.map(map_fn)
    return data.train_test_split(train_size=train_test_ratio, seed=42)
def extract_hash_answer(text: str):
"""
Utility for GSM8K data: The official GSM8K solutions often contain '#### <numerical answer>'.
This function extracts the substring after '####' as the final numeric/string answer.
Return None if not found.
"""
if "####" not in text:
return None
return text.split("####")[1].strip()
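# Example behavior (illustrative strings, not from the dataset):
#   extract_hash_answer("3 + 4 = 7 apples in total.\n#### 7")  -> "7"
#   extract_hash_answer("no '####' marker here")               -> None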
def get_gsm8k_questions(split="train", sample_size=3000) -> Dataset:
"""
Example: load and unify multiple GSM8K parquet shards into one dataset (just a demonstration).
We'll just show how you might do so. Replace it with your real data loading path or method.
"""
    # The path below is for demonstration only; replace it with your real path
data = Dataset.from_parquet(f'../data_pack/openai/gsm8k/{split}-*.parquet') # type: ignore
data = data.select(range(sample_size))
    def map_fn(x):
        user_prompt = x['question']
        final_answer = extract_hash_answer(x['answer'])
        return {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT.format("N/A", "N/A")},
                {'role': 'user', 'content': user_prompt}
            ],
            'answer': final_answer if final_answer else "",
            'task': "math"
        }
    data = data.map(map_fn)
    return data
def extract_xml_answer(text: str) -> str:
"""
Extract the text between <answer> and </answer>.
Returns an empty string if not found.
"""
pattern = r"<answer>\s*(.?)\s</answer>"
match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
if match:
return match.group(1).strip()
return ""
def detect_language(text: str) -> str:
"""Use langid to detect language code, e.g. 'en', 'zh', 'fr'."""
return langid.classify(text)[0]
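# For reference: langid.classify("hello world") returns a (code, score) tuple
# such as ('en', -54.4); only the language code is used here.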
def compute_length_acc(
preds: List[str],
refs: List[str],
tokenizer,
max_tolerance: float = 0.5
) -> List[float]:
"""
返回一组与 refs 同长度的分数列表,每个分数表示对应 pred 与 ref 的长度相似度奖励。
如果 length_ratio <= 0.1 => +1
length_ratio <= 0.2 => +0.7
length_ratio <= 0.3 => +0.5
length_ratio <= 0.4 => +0.3
else => 0
"""
rewards = []
for pred, ref in zip(preds, refs):
pred_tokens = tokenizer.tokenize(pred)
ref_tokens = tokenizer.tokenize(ref)
if len(ref_tokens) == 0:
rewards.append(0.0)
continue
length_ratio = abs(len(pred_tokens) - len(ref_tokens)) / len(ref_tokens)
        if length_ratio <= 0.1:
            score = 1.0
        elif length_ratio <= 0.2:
            score = 0.7
        elif length_ratio <= 0.3:
            score = 0.5
        elif length_ratio <= 0.4:
            score = 0.3
        else:
            score = 0.0
        rewards.append(score)
    return rewards
def compute_bleu(
preds: List[str],
refs: List[str],
tokenizer
) -> List[float]:
"""
分别对每个 pred/ref pair 计算 BLEU 分数 (sentence-level)。
这里使用4-gram平均权重 (0.25, 0.25, 0.25, 0.25)。
"""
smoothie = SmoothingFunction().method1
weights = (0.25, 0.25, 0.25, 0.25)
scores = []
for pred, ref in zip(preds, refs):
pred_tokens = tokenizer.tokenize(pred)
ref_tokens = tokenizer.tokenize(ref)
if not pred_tokens or not ref_tokens:
scores.append(0.0)
continue
bleu = sentence_bleu(
[ref_tokens],
pred_tokens,
weights=weights,
smoothing_function=smoothie
)
scores.append(bleu)
return scores
def compute_ter(
preds: List[str],
refs: List[str],
tokenizer
) -> List[float]:
"""
逐样本计算 TER (Translation Edit Rate):
TER = (编辑距离 / 参考的词数) * 100
"""
ter_scores = []
for pred, ref in zip(preds, refs):
pred_tokens = tokenizer.tokenize(pred)
ref_tokens = tokenizer.tokenize(ref)
if len(ref_tokens) == 0:
if len(pred_tokens) > 0:
ter_scores.append(100.0)
else:
ter_scores.append(0.0)
continue
dist = edit_distance(pred_tokens, ref_tokens)
ter = (dist / len(ref_tokens)) * 100
ter_scores.append(ter)
return ter_scores
def strict_format_reward_func(completions, **kwargs) -> List[float]:
    """
    Reward +1 if the completion fully matches:
        <reasoning>...</reasoning>
        <answer>...</answer>
    Otherwise 0.
    """
    pattern = r"<reasoning>[\s\S]+</reasoning>\s*<answer>[\s\S]+</answer>"
responses = [completion[0]["content"] for completion in completions]
scores = []
for r in responses:
if re.search(pattern, r):
scores.append(1.0)
else:
scores.append(0.0)
return scores
def soft_format_reward_func(completions, **kwargs) -> List[float]:
"""
如果包含最基本的 XML 标签 <reasoning>...</reasoning> <answer>...</answer> 就给0.5,否则0
"""
pattern = r"<reasoning>.?</reasoning>\s<answer>.*?</answer>"
responses = [completion[0]["content"] for completion in completions]
matches = [re.search(pattern, r, flags=re.DOTALL) for r in responses]
return [0.5 if m else 0.0 for m in matches]
def xmlcount_reward_func(completions, **kwargs) -> List[float]:
"""
简单演示:根据出现的tag次数给分。示例规则:
- <reasoning>\n + \n</reasoning>\n => 0.25
- <answer>\n + \n</answer> => 0.25
- 任何标签之后的多余字符每个扣 0.001 分
"""
def count_xml(text) -> float:
count = 0.0
# <reasoning> 行数判断
if text.count("<reasoning>\n") == 1:
count += 0.125
if text.count("\n</reasoning>\n") == 1:
count += 0.125
leftover = text.split("\n</reasoning>\n")[-1]
count -= len(leftover)*0.001
text# <answer> 行数判断 if text.count("\n<answer>\n") == 1: count += 0.125 if text.count("\n</answer>") == 1: count += 0.125 leftover = text.split("\n</answer>")[-1] # 对多余字符扣分 count -= (len(leftover) - 1)*0.001 return count responses = [c[0]["content"] for c in completions] return [count_xml(c) for c in responses]
def correctness_reward_func(prompts, completions, answers, **kwargs) -> List[float]:
"""
对于 math 任务:如果提取到的 final answer 与正确答案完全匹配 => +2.0,否则0.0
"""
responses = [c[0]['content'] for c in completions]
extracted_responses = [extract_xml_answer(r) for r in responses]
return [2.0 if r.strip() == a.strip() else 0.0 for r, a in zip(extracted_responses, answers)]
def int_reward_func(completions, **kwargs) -> List[float]:
"""
对于 math 任务:如果最终答案是纯数字(可带符号),则给+0.5
"""
responses = [c[0]['content'] for c in completions]
extracted_responses = [extract_xml_answer(r) for r in responses]
scores = []
for r in extracted_responses:
if re.match(r"^-?\d+$", r):
scores.append(0.5)
else:
scores.append(0.0)
return scores
def length_acc_reward_func(prompts, completions, answers, tokenizer=None, **kwargs) -> List[float]:
"""
检查生成与参考在长度上的相似度
"""
responses = [completion[0]['content'] for completion in completions]
extracted = [extract_xml_answer(r) for r in responses]
return compute_length_acc(preds=extracted, refs=answers, tokenizer=tokenizer)
def bleu_reward_func(prompts, completions, answers, tokenizer=None, **kwargs) -> List[float]:
"""
计算 BLEU 分数并映射到 0~5 的区间上
"""
responses = [c[0]["content"] for c in completions]
extracted = [extract_xml_answer(r) for r in responses]
bleu_scores = compute_bleu(preds=extracted, refs=answers, tokenizer=tokenizer)
rewards = []
for s in bleu_scores:
if s >= 0.9:
rewards.append(5.0)
elif s >= 0.8:
rewards.append(4.5)
elif s >= 0.7:
rewards.append(4.0)
elif s >= 0.6:
rewards.append(3.5)
elif s >= 0.5:
rewards.append(2.5)
elif s >= 0.4:
rewards.append(2.0)
elif s >= 0.3:
rewards.append(1.5)
elif s >= 0.2:
rewards.append(1.0)
elif s >= 0.1:
rewards.append(0.5)
else:
rewards.append(0.0)
return rewards
def ter_reward_func(prompts, completions, answers, tokenizer=None, **kwargs) -> List[float]:
"""
根据 TER (Translation Edit Rate) 分数映射奖励:
if ter >= 80 => 0
elif ter >= 60 => 0.5
elif ter >= 40 => 1.0
elif ter >= 20 => 1.5
else => 2.0
"""
responses = [c[0]["content"] for c in completions]
extracted = [extract_xml_answer(r) for r in responses]
ter_scores = compute_ter(preds=extracted, refs=answers, tokenizer=tokenizer)
rewards = []
for t in ter_scores:
if t >= 80:
rewards.append(0.0)
elif t >= 60:
rewards.append(0.5)
elif t >= 40:
rewards.append(1.0)
elif t >= 20:
rewards.append(1.5)
else:
rewards.append(2.0)
return rewards
def language_recognition(prompts, completions, answers, **kwargs) -> List[float]:
"""
简单判断completions中的语言是否与answers语言一致,如果一致奖励+1,否则0。
"""
responses = [c[0]["content"] for c in completions]
extracted = [extract_xml_answer(r) for r in responses]
rewards = []
for pred, ref in zip(extracted, answers):
if not pred.strip():
rewards.append(0.0)
continue
pred_lang = detect_language(pred)
ref_lang = detect_language(ref)
rewards.append(1.0 if pred_lang == ref_lang else 0.0)
return rewards
def reduce_repetition_reward_func(prompts, completions, answers, **kwargs) -> List[float]:
"""
如果歌词行重复率过高则惩罚;重复率较低则给奖励。
"""
responses = [c[0]["content"] for c in completions]
rewards = []
for resp in responses:
answer_text = extract_xml_answer(resp)
lines = [l.strip() for l in answer_text.split("\n") if l.strip()]
if len(lines) <= 1:
rewards.append(0.0)
continue
unique_lines = len(set(lines))
ratio = unique_lines / len(lines)
        # Simple rule: +1 if at least 80% of the lines are unique, otherwise 0
reward_val = 1.0 if ratio >= 0.8 else 0.0
rewards.append(reward_val)
return rewards
def fear_of_unknown_reward_func(prompts, completions, answers, **kwargs) -> List[float]:
"""
如果输出中出现“不知道”、“???”、“不确定”等表达,则给负奖励(惩罚)。
"""
responses = [c[0]["content"] for c in completions]
triggers = ["i don't know", "i do not know", "not sure", "???", "no idea", "不确定", "不知道"]
scores = []
for resp in responses:
lower_resp = resp.lower()
penalty = 0.0
for t in triggers:
if t in lower_resp:
                penalty -= 1.0  # subtract 1 point for each trigger phrase found
scores.append(penalty)
return scores
def contradiction_penalty_func(prompts, completions, answers, **kwargs) -> List[float]:
"""
如果同一段回答中同时出现相互矛盾的表述 (如 "Yes" 和 "No"),则惩罚。
"""
contradictory_pairs = [
("yes", "no"),
("true", "false"),
("can", "cannot")
]
responses = [c[0]["content"].lower() for c in completions]
scores = []
for r in responses:
penalty = 0.0
for pair in contradictory_pairs:
if pair[0] in r and pair[1] in r:
penalty -= 2.0
scores.append(penalty)
return scores
def curiosity_question_reward_func(prompts, completions, answers, **kwargs) -> List[float]:
"""
如果模型在 <reasoning> 里包含了对自身问题的反问或好奇心表达 (如带 '?' 句子),则给正向奖励。
"""
responses = [c[0]["content"] for c in completions]
pattern = r"<reasoning>(.*?)</reasoning>"
scores = []
for resp in responses:
match = re.search(pattern, resp, flags=re.DOTALL | re.IGNORECASE)
if not match:
scores.append(0.0)
continue
reasoning_text = match.group(1)
question_count = reasoning_text.count('?')
        reward = min(question_count * 0.2, 1.0)  # +0.2 per question mark, capped at 1.0
scores.append(reward)
return scores
def novelty_reward_func(prompts, completions, answers, tokenizer=None, **kwargs) -> List[float]:
"""
计算模型输出中“新”词的比例(相对于用户 prompt),鼓励一定程度的新意。
"""
user_texts = []
for p in prompts:
        user_text = p[-1]['content']  # the last role=user message
user_texts.append(user_text)
    responses = [c[0]["content"] for c in completions]
    scores = []
    for user_t, resp in zip(user_texts, responses):
        user_tokens = set(tokenizer.tokenize(user_t.lower()))
        model_answer_text = extract_xml_answer(resp).lower()
        model_tokens = tokenizer.tokenize(model_answer_text)
        if not model_tokens:
            scores.append(0.0)
            continue
        new_words = [w for w in model_tokens if w not in user_tokens]
        ratio = len(new_words) / len(model_tokens)
        if ratio > 0.8:
            scores.append(1.0)
        elif ratio > 0.5:
            scores.append(0.5)
        elif ratio > 0.2:
            scores.append(0.2)
        else:
            scores.append(0.0)
    return scores
def dynamic_reward_func(
prompts,
completions,
answers,
tokenizer=None,
    # optional metadata list (if the Trainer passes the batch's raw data, task info etc. can be read from it)
batch_metadata=None,
**kwargs
) -> List[float]:
"""
演示如何基于任务类型、恐惧检测结果等动态调用或组合各子奖励函数。
需要保证返回与 batch_size 数量对应的一组分数。
text说明: 1. 如果 batch_metadata 不可用(即 Trainer 不传递),可从 prompts 中自己解析/判断任务类型。 2. 这里示例:先根据任务类型(math 或 lyrics)调用对应子奖励函数,再统一叠加“恐惧/好奇/矛盾”等函数。 3. 对于同一个batch中多条样本,若任务类型不同,我们需要逐条区分计算并叠加分数。 4. 这里只是演示了一个可能的实现,实际可根据需求调整各项子函数的权重、加总方式等。 """ # 首先获取 batch_size batch_size = len(completions) # 准备空列表,存放最终 reward final_rewards = [0.0] * batch_size # 准备容器:先行收集每个子函数的得分,以便做加权/融合 # 例如: { # 'correctness': [...], # 'int_only': [...], # 'bleu': [...], # ... # } collected_scores = {} # 辅助函数:将某个单体奖励函数的输出合并到collected_scores def collect_score(key, scores): collected_scores[key] = scores return scores # 尝试从 prompts / batch_metadata 推断 task # 如果在 Trainer 里没有传递 batch_metadata,可从 example 的 prompt[0] 或 prompt[-1] 找自定义标记 # 这里简化假设:如果 batch_metadata 存在,就直接用 batch_metadata[i]['task'] tasks = [] if batch_metadata is not None: tasks = [meta.get('task', None) for meta in batch_metadata] else: # 如果不可用,就简单从 prompts 推断 # 这里演示:如果 user 提示里含 "math" 则判定为 math,否则当 lyrics for i in range(batch_size): user_text = prompts[i][-1]['content'].lower() if "math" in user_text or "question" in user_text: tasks.append("math") else: tasks.append("lyrics") # 分别将属于 math 的样本索引和 lyrics 的样本索引分开 math_indices = [i for i, t in enumerate(tasks) if t == "math"] lyrics_indices = [i for i, t in enumerate(tasks) if t == "lyrics"] # 计算 math 相关的奖励函数 # 只有对 math_indices 计算,其余位置填0 if math_indices: # 先挑出子batch sub_prompts = [prompts[i] for i in math_indices] sub_completions = [completions[i] for i in math_indices] sub_answers = [answers[i] for i in math_indices] correctness_scores = correctness_reward_func(sub_prompts, sub_completions, sub_answers) int_scores = int_reward_func(sub_completions) strict_scores = strict_format_reward_func(sub_completions) soft_scores = soft_format_reward_func(sub_completions) # 也可以再组合别的,比如 xmlcount_reward_func(sub_completions) # 组装: # 这里示例先简单将这几项加起来 sub_final = [ c + i + s + s2 for c, i, s, s2 in zip(correctness_scores, int_scores, strict_scores, soft_scores) ] # 写回 final_rewards for idx, score_val in zip(math_indices, sub_final): final_rewards[idx] += score_val # 亦可收集 collect_score("math_correctness", correctness_scores) collect_score("math_int", int_scores) collect_score("math_strict", strict_scores) collect_score("math_soft", soft_scores) # 计算 lyrics 相关的奖励函数 if lyrics_indices: sub_prompts = [prompts[i] for i in lyrics_indices] sub_completions = [completions[i] for i in lyrics_indices] sub_answers = [answers[i] for i in lyrics_indices] # 下面几个需要 tokenizer if tokenizer is None: raise ValueError("Tokenizer is required for lyrics reward functions.") xmlcount_scores = xmlcount_reward_func(sub_completions) soft_scores = soft_format_reward_func(sub_completions) strict_scores = strict_format_reward_func(sub_completions) length_scores = length_acc_reward_func(sub_prompts, sub_completions, sub_answers, tokenizer=tokenizer) bleu_scores = bleu_reward_func(sub_prompts, sub_completions, sub_answers, tokenizer=tokenizer) ter_scores = ter_reward_func(sub_prompts, sub_completions, sub_answers, tokenizer=tokenizer) lang_scores = language_recognition(sub_prompts, sub_completions, sub_answers) repetition_scores = reduce_repetition_reward_func(sub_prompts, sub_completions, sub_answers) # 简单加和示例 sub_final = [ x + s + st + l + b + t + lg + rr for x, s, st, l, b, t, lg, rr in zip( xmlcount_scores, soft_scores, strict_scores, length_scores, bleu_scores, ter_scores, lang_scores, repetition_scores ) ] for idx, score_val in zip(lyrics_indices, sub_final): final_rewards[idx] += score_val # 收集 collect_score("lyrics_xmlcount", xmlcount_scores) collect_score("lyrics_soft", 
soft_scores) collect_score("lyrics_strict", strict_scores) collect_score("lyrics_length", length_scores) collect_score("lyrics_bleu", bleu_scores) collect_score("lyrics_ter", ter_scores) collect_score("lyrics_lang", lang_scores) collect_score("lyrics_rep", repetition_scores) # 额外:对所有样本都计算“恐惧/矛盾/好奇心/新颖度”等通用函数 fear_scores = fear_of_unknown_reward_func(prompts, completions, answers) contradiction_scores = contradiction_penalty_func(prompts, completions, answers) curiosity_scores = curiosity_question_reward_func(prompts, completions, answers) if tokenizer is None: novelty_scores = [0.0]*batch_size else: novelty_scores = novelty_reward_func(prompts, completions, answers, tokenizer=tokenizer) # 将它们加到 final_rewards for i in range(batch_size): final_rewards[i] += fear_scores[i] final_rewards[i] += contradiction_scores[i] final_rewards[i] += curiosity_scores[i] final_rewards[i] += novelty_scores[i] # 同样可收集 collect_score("fear_of_unknown", fear_scores) collect_score("contradiction", contradiction_scores) collect_score("curiosity", curiosity_scores) collect_score("novelty", novelty_scores) # 额外可根据某些特征做进一步动态调整,比如说如果“恐惧”函数分数特别低,可降低别的奖励等等 # 这里只做简单演示,略。 return final_rewards
def evaluate_model(trainer, test_dataset, tokenizer):
"""Evaluate the model on test_dataset using BLEU, TER, and length_acc (示例)"""
print("\nBegin Evaluation on Test Dataset...")
    model = trainer.model
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    predictions = []
    references = []
    for example in test_dataset:
        prompt = example["prompt"]
        true_answer = example["answer"]
        # Concatenate the role-based prompt into a single string before tokenizing
        text_input = ""
        for pm in prompt:
            text_input += f"{pm['role']}:{pm['content']}\n"
        inputs = tokenizer(text_input, return_tensors="pt", truncation=True, padding=True)
        input_ids = inputs.input_ids.to(device)
        with torch.no_grad():
            generated_ids = model.generate(input_ids, max_length=512)
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        predictions.append(generated_text)
        references.append(true_answer)

    # Extract the generated answers
    pred_answers = [extract_xml_answer(pred) for pred in predictions]

    # Compute the evaluation metrics (example only: BLEU, TER, length_acc)
    bleu_scores = compute_bleu(pred_answers, references, tokenizer=tokenizer)
    ter_scores = compute_ter(pred_answers, references, tokenizer=tokenizer)
    length_accs = compute_length_acc(pred_answers, references, tokenizer=tokenizer)

    # Log the results
    log_results({
        "BLEU": bleu_scores,
        "TER": ter_scores,
        "LENGTH_ACCURACY": length_accs,
    }, trainer=trainer)
    print("Test evaluation completed.")
def log_results(results: dict, trainer=None):
"""Log results to console (以及WandB等,可自行扩展)"""
for metric, scores in results.items():
avg_score = sum(scores) / len(scores) if scores else 0.0
print(f"[Test] {metric}: Avg {avg_score:.4f}")
if trainer and trainer.args.report_to:
trainer.log({f"{metric}/avg": avg_score})
print("-" * 40)
def build_model_and_tokenizer(model_path="../model/Qwen2.5-3B-Instruct", max_seq_length=1536, lora_rank=16):
"""
Load a Qwen (or other) model via FastLanguageModel with LoRA.
Return (model, tokenizer).
"""
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_path,
max_seq_length=max_seq_length,
load_in_4bit=True,
fast_inference=True,
max_lora_rank=lora_rank,
gpu_memory_utilization=0.7
)
    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_rank,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=lora_rank,
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
    return model, tokenizer
PatchFastRL("GRPO", FastLanguageModel)
print("Detected GPU count:", torch.cuda.device_count())
model, tokenizer = build_model_and_tokenizer()
gsm8k_dataset = get_gsm8k_questions(split="train")
training_args_math = GRPOConfig(
use_vllm=True,
learning_rate=5e-6,
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type="cosine",
optim="paged_adamw_8bit",
logging_steps=1,
bf16=is_bfloat16_supported(),
fp16=not is_bfloat16_supported(),
per_device_train_batch_size=16,
gradient_accumulation_steps=2,
num_generations=16,
max_prompt_length=768,
max_completion_length=768,
num_train_epochs=1,
report_to = "none",
output_dir = "outputs_math_phase"
)
trainer_math = GRPOTrainer(
model=model,
processing_class=tokenizer,
    reward_funcs=[dynamic_reward_func],  # just pass in our custom dynamic_reward_func
args=training_args_math,
train_dataset=gsm8k_dataset,
)
print("=== Starting Phase 1: Math Training (GSM8K) ===")
trainer_math.train()
print("=== Finished Phase 1 ===\n")
dataset_dict = get_lyric_datasets("../data_pack/multi_lyric.json", train_test_ratio=0.8)
train_dataset_lyrics = dataset_dict['train']
test_dataset_lyrics = dataset_dict['test']
training_args_lyrics = GRPOConfig(
use_vllm = False,
learning_rate = 5e-6,
adam_beta1 = 0.9,
adam_beta2 = 0.99,
weight_decay = 0.1,
warmup_ratio = 0.1,
lr_scheduler_type = "cosine",
optim = "paged_adamw_8bit",
logging_steps = 1,
bf16 = is_bfloat16_supported(),
fp16 = not is_bfloat16_supported(),
per_device_train_batch_size = 16,
gradient_accumulation_steps = 2,
num_generations = 8,
max_prompt_length = 768,
max_completion_length = 768,
num_train_epochs = 2,
save_steps = 250,
max_grad_norm = 0.1,
report_to = "none",
output_dir = "outputs_lyrics_phase",
)
trainer_lyrics = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[dynamic_reward_func],
args=training_args_lyrics,
train_dataset=train_dataset_lyrics,
)
print("=== Starting Phase 2: Lyrics Training ===")
trainer_lyrics.train()
print("=== Finished Phase 2 ===\n")
evaluate_model(trainer_lyrics, test_dataset_lyrics, tokenizer=tokenizer)
print("All done.")
The reward stage of reinforcement learning should not rely on just a single scalar value for a simple reward; there can be other, more diverse interactions. Combining the ideas above, design a more sensible reinforcement-learning interaction strategy for me, including rewards, actions, advantage-function evaluation, and whatever else is used to update the network model! Provide complete code with nothing omitted.
Below is a more complete reinforcement-learning example with a diversified interaction strategy. In contrast to the earlier approach of assigning only a single scalar to the final output, this example shows how, during training, `dynamic_reward_func` provides the overall quality score and is combined with additional interactions (such as human feedback, interactive scoring, or sampling other information from the environment) to form the total reward. Note: structurally this example still follows the earlier approach (the PPO/GRPO workflow in the TRL library); it simply adds a more explicit environment/agent split, rollout caching, and an advantage-computation walkthrough to help you understand and extend it. In practice, trim or modify it to fit your needs.

The code below is one complete example, covering data preparation, the reward functions, the environment/agent abstractions, the PPO training loop, and evaluation from start to finish.

Please note that because this example contains all the reward functions covered earlier plus the extra environment/agent/advantage logic, it is fairly long. But in the spirit of **"nothing omitted"**, here it is in full:
```python
# -*- coding: utf-8 -*-
import torch
import re
import random
import langid
from typing import List, Dict, Any
import math
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk import edit_distance
# If nltk or langid is not installed yet:
# pip install nltk langid
# If datasets is not installed yet:
# pip install datasets
from datasets import load_dataset, Dataset, concatenate_datasets
# If trl is not installed yet (make sure the version is >= 0.4):
# pip install trl
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer
# Or, if you are using GRPO:
# from trl import GRPOConfig, GRPOTrainer
# If you have wandb, it can be imported and used during training:
import wandb
import os

# ==============
# (Optional) pin the GPU. Comment this out if your environment does not need it.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

#######################
# Custom configuration constants
#######################
SUPPORTED_LANGUAGES = {
    "en_to_zh": ("英语", "中文"),
    "zh_to_en": ("中文", "英语"),
    "jp_to_zh": ("日语", "中文"),
    "zh_to_azh": ("中文", "中文(文言文)"),
    "en_to_azh": ("英语", "中文(文言文)"),
    "de_to_zh": ("德语", "中文"),
    "italy_to_zh": ("意大利语", "中文"),
    "rus_to_zh": ("俄语", "中文"),
    "fr_to_zh": ("法语", "中文"),
    "mix": ("中文", "multilingual")
}

SYSTEM_PROMPT = """\
You are a versatile assistant. Follow these rules:
1. **Always** respond in this exact XML format:
<reasoning>
[Your reasoning here]
</reasoning>
<answer>
[Final answer or translation here]
</answer>
2. If given lyrics, translate them from {} to {} only, without extra explanations.
3. If given a math question, solve it and provide the result in the same XML format.
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

#######################
# Data preparation
#######################
def get_lyric_datasets(path: str, train_test_ratio=0.8) -> dict:
    """
    Load a lyrics translation dataset from JSON and split it into train / test
    (e.g. 80% / 20%). Each item should contain the fields 'type', 'lyric',
    and 'target_lyric'. map_fn inserts the matching system prompt + user
    content and returns (prompt, answer, task).
    """
    data = Dataset.from_json(path)

    def map_fn(x):
        lang_src = SUPPORTED_LANGUAGES[x['type']][0]
        lang_tgt = SUPPORTED_LANGUAGES[x['type']][1]
        # Insert the correct source->target languages into the system prompt:
        system_plus = SYSTEM_PROMPT.format(lang_src, lang_tgt)
        return {
            'prompt': [
                {'role': 'system', 'content': system_plus},
                {'role': 'user', 'content': x['lyric']}
            ],
            'answer': x['target_lyric'],
            'task': "lyrics"
        }

    data = data.map(map_fn)
    return data.train_test_split(train_size=train_test_ratio, seed=42)

def extract_hash_answer(text: str):
    """
    For the common GSM8K format: extract the answer portion after '####'.
    """
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def get_gsm8k_questions(split="train", sample_size=3000) -> Dataset:
    """
    Load GSM8K and normalize prompt/answer. Example only; adapt it to your
    real paths or loading method.
    """
    # Assumes gsm8k parquet shards at: ../data_pack/openai/gsm8k/train-*.parquet
    # Demonstration only
    data = Dataset.from_parquet(f'../data_pack/openai/gsm8k/{split}-*.parquet')
    data = data.select(range(min(sample_size, len(data))))

    def map_fn(x):
        user_prompt = x['question']
        final_answer = extract_hash_answer(x['answer'])
        return {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT.format("N/A", "N/A")},
                {'role': 'user', 'content': user_prompt}
            ],
            'answer': final_answer if final_answer else "",
            'task': "math"
        }

    data = data.map(map_fn)
    return data

#######################
# Small text-extraction helpers
#######################
def extract_xml_answer(text: str) -> str:
    """
    Extract the content between <answer>...</answer> in the model output.
    """
    pattern = r"<answer>\s*(.*?)\s*</answer>"
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()
    return ""

def detect_language(text: str) -> str:
    """
    Simple language detection via langid; returns an ISO language code such
    as 'en', 'zh', or 'fr'.
    """
    return langid.classify(text)[0]

#######################
# Reward function examples
#######################
def compute_length_acc(
    preds: List[str],
    refs: List[str],
    tokenizer,
    max_tolerance: float = 0.5
) -> List[float]:
    """
    Length-similarity reward for each pred/ref pair.
    Example only; adjust it to your needs.
    """
    rewards = []
    for pred, ref in zip(preds, refs):
        pred_tokens = tokenizer.tokenize(pred)
        ref_tokens = tokenizer.tokenize(ref)
        if len(ref_tokens) == 0:
            rewards.append(0.0)
            continue
        length_ratio = abs(len(pred_tokens) - len(ref_tokens)) / len(ref_tokens)
        if length_ratio <= 0.1:
            score = 1.0
        elif length_ratio <= 0.2:
            score = 0.7
        elif length_ratio <= 0.3:
            score = 0.5
        elif length_ratio <= 0.4:
            score = 0.3
        else:
            score = 0.0
        rewards.append(score)
    return rewards

def compute_bleu(
    preds: List[str],
    refs: List[str],
    tokenizer
) -> List[float]:
    """
    Sentence-level BLEU (uniform 4-gram weights) per pair.
    """
    smoothie = SmoothingFunction().method1
    weights = (0.25, 0.25, 0.25, 0.25)
    scores = []
    for pred, ref in zip(preds, refs):
        pred_tokens = tokenizer.tokenize(pred)
        ref_tokens = tokenizer.tokenize(ref)
        if not pred_tokens or not ref_tokens:
            scores.append(0.0)
            continue
        bleu = sentence_bleu(
            [ref_tokens],
            pred_tokens,
            weights=weights,
            smoothing_function=smoothie
        )
        scores.append(bleu)
    return scores

def compute_ter(
    preds: List[str],
    refs: List[str],
    tokenizer
) -> List[float]:
    """
    Translation Edit Rate = edit distance / reference token count * 100
    """
    ter_scores = []
    for pred, ref in zip(preds, refs):
        pred_tokens = tokenizer.tokenize(pred)
        ref_tokens = tokenizer.tokenize(ref)
        if len(ref_tokens) == 0:
            if len(pred_tokens) > 0:
                ter_scores.append(100.0)
            else:
                ter_scores.append(0.0)
            continue
        dist = edit_distance(pred_tokens, ref_tokens)
        ter = (dist / len(ref_tokens)) * 100
        ter_scores.append(ter)
    return ter_scores

# ===== Basic reward functions =====
def strict_format_reward_func(completions, **kwargs) -> List[float]:
    """
    +1 if the text fully matches the
    <reasoning>...</reasoning><answer>...</answer> format, otherwise 0.
    """
    pattern = r"<reasoning>[\s\S]+</reasoning>\s*<answer>[\s\S]+</answer>"
    responses = [completion[0]["content"] for completion in completions]
    scores = []
    for r in responses:
        if re.search(pattern, r):
            scores.append(1.0)
        else:
            scores.append(0.0)
    return scores

def soft_format_reward_func(completions, **kwargs) -> List[float]:
    """
    0.5 if the basic XML tags <reasoning>...</reasoning> <answer>...</answer>
    are present, otherwise 0.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.search(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if m else 0.0 for m in matches]

def xmlcount_reward_func(completions, **kwargs) -> List[float]:
    """
    Score by tag occurrence counts; for demonstration.
    """
    def count_xml(text) -> float:
        count = 0.0
        if text.count("<reasoning>\n") == 1:
            count += 0.125
        if text.count("\n</reasoning>\n") == 1:
            count += 0.125
        if text.count("\n<answer>\n") == 1:
            count += 0.125
        if text.count("\n</answer>") == 1:
            count += 0.125
        return count
    responses = [c[0]["content"] for c in completions]
    return [count_xml(c) for c in responses]

# ===== Reward functions for the math task =====
def correctness_reward_func(prompts, completions, answers, **kwargs) -> List[float]:
    """
    +2 if the extracted <answer> matches the reference answer, otherwise 0.
    """
    responses = [c[0]['content'] for c in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [2.0 if r.strip() == a.strip() else 0.0 for r, a in zip(extracted_responses, answers)]

def int_reward_func(completions, **kwargs) -> List[float]:
    """
    +0.5 if the <answer> is a pure (optionally signed) integer, otherwise 0.
    """
    responses = [c[0]['content'] for c in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    scores = []
    for r in extracted_responses:
        if re.match(r"^-?\d+$", r.strip()):
            scores.append(0.5)
        else:
            scores.append(0.0)
    return scores

# ===== Reward functions for the lyrics translation task =====
def length_acc_reward_func(prompts, completions, answers, tokenizer=None, **kwargs) -> List[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted = [extract_xml_answer(r) for r in responses]
    return compute_length_acc(preds=extracted, refs=answers, tokenizer=tokenizer)

def bleu_reward_func(prompts, completions, answers, tokenizer=None, **kwargs) -> List[float]:
    responses = [c[0]["content"] for c in completions]
    extracted = [extract_xml_answer(r) for r in responses]
    bleu_scores = compute_bleu(preds=extracted, refs=answers, tokenizer=tokenizer)
    rewards = []
    for s in bleu_scores:
        if s >= 0.9:
            rewards.append(5.0)
        elif s >= 0.8:
            rewards.append(4.5)
        elif s >= 0.7:
            rewards.append(4.0)
        elif s >= 0.6:
            rewards.append(3.5)
        elif s >= 0.5:
            rewards.append(2.5)
        elif s >= 0.4:
            rewards.append(2.0)
        elif s >= 0.3:
            rewards.append(1.5)
        elif s >= 0.2:
            rewards.append(1.0)
        elif s >= 0.1:
            rewards.append(0.5)
        else:
            rewards.append(0.0)
    return rewards

def ter_reward_func(prompts, completions, answers, tokenizer=None, **kwargs) -> List[float]:
    responses = [c[0]["content"] for c in completions]
    extracted = [extract_xml_answer(r) for r in responses]
    ter_scores = compute_ter(preds=extracted, refs=answers, tokenizer=tokenizer)
    rewards = []
    for t in ter_scores:
        if t >= 80:
            rewards.append(0.0)
        elif t >= 60:
            rewards.append(0.5)
        elif t >= 40:
            rewards.append(1.0)
        elif t >= 20:
            rewards.append(1.5)
        else:
            rewards.append(2.0)
    return rewards

def language_recognition(prompts, completions, answers, **kwargs) -> List[float]:
    """
    +1 if the generated text's language matches the reference language,
    otherwise 0.
    """
    responses = [c[0]["content"] for c in completions]
    extracted = [extract_xml_answer(r) for r in responses]
    rewards = []
    for pred, ref in zip(extracted, answers):
        if not pred.strip():
            rewards.append(0.0)
            continue
        pred_lang = detect_language(pred)
        ref_lang = detect_language(ref)
        rewards.append(1.0 if pred_lang == ref_lang else 0.0)
    return rewards

def reduce_repetition_reward_func(prompts, completions, answers, **kwargs) -> List[float]:
    """
    Simple example: too many repeated lyric lines => penalty;
    few repeats => reward.
    """
    responses = [c[0]["content"] for c in completions]
    rewards = []
    for resp in responses:
        answer_text = extract_xml_answer(resp)
        lines = [l.strip() for l in answer_text.split("\n") if l.strip()]
        if len(lines) <= 1:
            rewards.append(0.0)
            continue
        unique_lines = len(set(lines))
        ratio = unique_lines / len(lines)
        reward_val = 1.0 if ratio >= 0.8 else 0.0
        rewards.append(reward_val)
    return rewards

# ===== Other custom functions =====
def fear_of_unknown_reward_func(prompts, completions, answers, **kwargs) -> List[float]:
    """
    Negative reward whenever the response contains "not sure",
    "I don't know", "???", etc.
    """
    responses = [c[0]["content"] for c in completions]
    triggers = ["i don't know", "i do not know", "not sure", "???", "no idea", "不确定", "不知道"]
    scores = []
    for resp in responses:
        lower_resp = resp.lower()
        penalty = 0.0
        for t in triggers:
            if t in lower_resp:
                penalty -= 1.0  # -1 per trigger phrase found
        scores.append(penalty)
    return scores

def contradiction_penalty_func(prompts, completions, answers, **kwargs) -> List[float]:
    """
    Penalize obvious contradictions within one response
    (e.g. both "yes" and "no").
    """
    contradictory_pairs = [
        ("yes", "no"),
        ("true", "false"),
        ("can", "cannot")
    ]
    responses = [c[0]["content"].lower() for c in completions]
    scores = []
    for r in responses:
        penalty = 0.0
        for pair in contradictory_pairs:
            if pair[0] in r and pair[1] in r:
                penalty -= 2.0
        scores.append(penalty)
    return scores

def curiosity_question_reward_func(prompts, completions, answers, **kwargs) -> List[float]:
    """
    Positive reward if <reasoning> shows curiosity about the problem itself
    (sentences containing '?').
    """
    responses = [c[0]["content"] for c in completions]
    pattern = r"<reasoning>(.*?)</reasoning>"
    scores = []
    for resp in responses:
        match = re.search(pattern, resp, flags=re.DOTALL | re.IGNORECASE)
        if not match:
            scores.append(0.0)
            continue
        reasoning_text = match.group(1)
        question_count = reasoning_text.count('?')
        reward = min(question_count * 0.2, 1.0)
        scores.append(reward)
    return scores

def novelty_reward_func(prompts, completions, answers, tokenizer=None, **kwargs) -> List[float]:
    """
    Count tokens overlapping with the user prompt, to encourage a degree
    of novelty.
    """
    user_texts = []
    for p in prompts:
        user_texts.append(p[-1]['content'])
    responses = [c[0]["content"] for c in completions]
    scores = []
    for user_t, resp in zip(user_texts, responses):
        user_tokens = set(tokenizer.tokenize(user_t.lower()))
        model_answer_text = extract_xml_answer(resp).lower()
        model_tokens = tokenizer.tokenize(model_answer_text)
        if not model_tokens:
            scores.append(0.0)
            continue
        new_words = [w for w in model_tokens if w not in user_tokens]
        ratio = len(new_words) / len(model_tokens)
        if ratio > 0.8:
            scores.append(1.0)
        elif ratio > 0.5:
            scores.append(0.5)
        elif ratio > 0.2:
            scores.append(0.2)
        else:
            scores.append(0.0)
    return scores

# ================
# Core: the dynamic reward
# ================
def dynamic_reward_func(
    prompts,
    completions,
    answers,
    tokenizer=None,
    batch_metadata=None,  # extra metadata, if the Trainer can pass it
    **kwargs
) -> List[float]:
    """
    Dynamically decide how to combine the sub-reward functions based on
    task type, fear-detection results, curiosity level, and so on.
    """
    batch_size = len(completions)
    final_rewards = [0.0] * batch_size

    # First infer which samples are math tasks and which are lyrics tasks
    tasks = []
    if batch_metadata is not None:
        tasks = [meta.get('task', None) for meta in batch_metadata]
    else:
        # Fall back to guessing from the prompts
        for i in range(batch_size):
            user_text = prompts[i][-1]['content'].lower()
            if "math" in user_text or "question" in user_text:
                tasks.append("math")
            else:
                tasks.append("lyrics")

    math_indices = [i for i, t in enumerate(tasks) if t == "math"]
    lyrics_indices = [i for i, t in enumerate(tasks) if t == "lyrics"]

    # Score the math samples
    if math_indices:
        sub_prompts = [prompts[i] for i in math_indices]
        sub_completions = [completions[i] for i in math_indices]
        sub_answers = [answers[i] for i in math_indices]
        c_scores = correctness_reward_func(sub_prompts, sub_completions, sub_answers)
        i_scores = int_reward_func(sub_completions)
        s_scores = strict_format_reward_func(sub_completions)
        so_scores = soft_format_reward_func(sub_completions)
        sub_final = [
            c + i + s + so
            for c, i, s, so in zip(c_scores, i_scores, s_scores, so_scores)
        ]
        for idx, score_val in zip(math_indices, sub_final):
            final_rewards[idx] += score_val

    # Score the lyrics samples
    if lyrics_indices:
        sub_prompts = [prompts[i] for i in lyrics_indices]
        sub_completions = [completions[i] for i in lyrics_indices]
        sub_answers = [answers[i] for i in lyrics_indices]
        if tokenizer is None:
            raise ValueError("Tokenizer is required for lyrics reward functions.")
        x_scores = xmlcount_reward_func(sub_completions)
        so_scores = soft_format_reward_func(sub_completions)
        st_scores = strict_format_reward_func(sub_completions)
        l_scores = length_acc_reward_func(sub_prompts, sub_completions, sub_answers, tokenizer=tokenizer)
        b_scores = bleu_reward_func(sub_prompts, sub_completions, sub_answers, tokenizer=tokenizer)
        t_scores = ter_reward_func(sub_prompts, sub_completions, sub_answers, tokenizer=tokenizer)
        lang_scores = language_recognition(sub_prompts, sub_completions, sub_answers)
        rep_scores = reduce_repetition_reward_func(sub_prompts, sub_completions, sub_answers)
        sub_final = [
            x + so + st + l + b + t + la + r
            for x, so, st, l, b, t, la, r in zip(
                x_scores, so_scores, st_scores, l_scores,
                b_scores, t_scores, lang_scores, rep_scores
            )
        ]
        for idx, score_val in zip(lyrics_indices, sub_final):
            final_rewards[idx] += score_val

    # Add generic rewards/penalties for all samples
    fear_scores = fear_of_unknown_reward_func(prompts, completions, answers)
    contra_scores = contradiction_penalty_func(prompts, completions, answers)
    curious_scores = curiosity_question_reward_func(prompts, completions, answers)
    novelty_scores = [0.0] * batch_size
    if tokenizer is not None:
        novelty_scores = novelty_reward_func(prompts, completions, answers, tokenizer=tokenizer)
    for i in range(batch_size):
        final_rewards[i] += fear_scores[i]
        final_rewards[i] += contra_scores[i]
        final_rewards[i] += curious_scores[i]
        final_rewards[i] += novelty_scores[i]

    return final_rewards

##############################
# Environment (Env) and Agent
##############################
class TextRLTaskEnvironment:
    """
    A simplified example environment. It:
    1. provides a (batch of) "state" (i.e. the prompts) to the agent
    2. lets the agent produce actions (i.e. output text)
    3. feeds the "finished dialogue" + "reference answers" into the reward
       function to compute the reward
    4. returns (obs, reward, done, info)
    Here one full reply is generated at once (rather than a token-level loop),
    to show how this integrates with the reward_func above. A finer
    token-level loop is also possible, but requires more custom TRL components.
    """
    def __init__(self, dataset, tokenizer, reward_func):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.reward_func = reward_func

    def reset(self, indices):
        """
        Given a batch of indices, fetch the corresponding data from the
        dataset and treat it as the state.
        Only 'prompt' and 'answer' are returned for reference here.
        """
        batch = [self.dataset[i] for i in indices]
        # obs can be the prompts, plus other metadata
        obs = [ex["prompt"] for ex in batch]
        references = [ex["answer"] for ex in batch]
        # task info etc. can also serve as metadata
        meta = batch
        return obs, references, meta

    def step(self, prompts, completions, references, meta):
        """
        Compute the rewards.
        prompts:     List[List[Dict]]  # role-based prompts
        completions: List[List[Dict]]  # role=assistant, content=...
        references:  List[str]
        meta:        List[dict]        # contains task info etc.
        """
        # reward_func takes prompts, completions, answers (= references),
        # and optionally batch_metadata=meta
        rewards = self.reward_func(prompts, completions, references,
                                   tokenizer=self.tokenizer,
                                   batch_metadata=meta)
        # done: a single generation ends the episode here
        done = [True] * len(prompts)
        # info can carry some diagnostic information
        info = {}
        return rewards, done, info

class TextRLLanguageAgent:
    """
    Simplified example: text generation with a policy+value network
    (AutoModelForCausalLMWithValueHead). This can also map onto what
    TRL's PPOTrainer handles internally.
    """
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.model.eval()

    def act(self, prompts, max_new_tokens=128):
        """
        Generate replies for a batch of prompts with the current policy.
        One-shot text generation is demonstrated here; a finer token-level
        act is also possible.
        """
        # PPOTrainer in TRL wraps generation itself; this is just a demonstration
        completions = []
        device = next(self.model.parameters()).device
        for p in prompts:
            # Concatenate the role-based prompt into one string
            text_input = ""
            for pm in p:
                if pm['role'] == 'system':
                    text_input += f"[system] {pm['content']}\n"
                elif pm['role'] == 'user':
                    text_input += f"[user] {pm['content']}\n"
                else:
                    text_input += f"[{pm['role']}] {pm['content']}\n"
            inputs = self.tokenizer(text_input, return_tensors="pt").to(device)
            with torch.no_grad():
                output_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    top_k=50,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            gen_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
            # Simply take the newly generated content after the prompt:
            completion_str = gen_text[len(text_input):].strip()
            completions.append([{'role': 'assistant', 'content': completion_str}])
        return completions

##################################
# Demonstration of a PPO/GRPO training flow
##################################
def evaluate_model(trainer, test_dataset, tokenizer):
    """
    Run inference on the test set and compute some metrics
    (e.g. BLEU, TER, length_acc).
    """
    print("\nBegin Evaluation on Test Dataset...")
    model = trainer.model
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    predictions = []
    references = []
    for example in test_dataset:
        prompt = example["prompt"]
        true_answer = example["answer"]
        # Concatenate the role-based prompt
        text_input = ""
        for pm in prompt:
            text_input += f"{pm['role']}:{pm['content']}\n"
        inputs = tokenizer(text_input, return_tensors="pt").to(device)
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=256)
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        predictions.append(generated_text)
        references.append(true_answer)

    # Extract <answer>
    pred_answers = [extract_xml_answer(pred) for pred in predictions]

    # Compute the evaluation metrics
    bleu_scores = compute_bleu(pred_answers, references, tokenizer=tokenizer)
    ter_scores = compute_ter(pred_answers, references, tokenizer=tokenizer)
    length_accs = compute_length_acc(pred_answers, references, tokenizer=tokenizer)

    # Log the results
    results = {
        "BLEU": bleu_scores,
        "TER": ter_scores,
        "LENGTH_ACCURACY": length_accs,
    }
    for metric, scores in results.items():
        avg_score = sum(scores) / len(scores) if scores else 0.0
        print(f"[Test] {metric}: Avg {avg_score:.4f}")
    print("Test evaluation completed.")

def main():
    # 1. Prepare the model & tokenizer.
    #    Any HF CausalLM works here; Qwen etc. is fine as long as it is
    #    compatible with PPOTrainer.
    model_name = "gpt2"  # demonstration only
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # In the trl library, PPO needs a model with a value head.
    # GRPO is also possible: from trl import GRPOTrainer, GRPOConfig, ...
    model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)

    # 2. Prepare the data: math and lyrics examples (replace the paths)
    print("Loading GSM8K dataset for math tasks...")
    gsm8k_dataset = get_gsm8k_questions(split="train", sample_size=3000)
    print("Loading lyric dataset for translation tasks...")
    lyric_dataset_dict = get_lyric_datasets("../data_pack/multi_lyric.json", train_test_ratio=0.8)
    train_dataset_lyrics = lyric_dataset_dict['train']
    test_dataset_lyrics = lyric_dataset_dict['test']

    # 3. Merge the data (for this example, math + lyrics are trained together).
    #    Training in phases (math first, then lyrics) also works.
    full_train_data = concatenate_datasets([gsm8k_dataset, train_dataset_lyrics])

    # 4. PPO configuration
    ppo_config = PPOConfig(
        batch_size=16,
        forward_batch_size=4,
        learning_rate=5e-6,
        log_with=None,  # set log_with="wandb" if you want wandb logging
        early_stopping=False,
        ppo_epochs=1,
        remove_unused_columns=False,
    )

    # 5. Build the PPOTrainer.
    #    Our dynamic_reward_func is applied directly as a custom reward.
    ppo_trainer = PPOTrainer(
        config=ppo_config,
        model=model,
        tokenizer=tokenizer
    )

    # 6. Environment + Agent (if you want to manage the data iteration /
    #    generation / reward computation yourself).
    #    ppo_trainer's own generate & step could be used instead; this is a
    #    more explicit demonstration.
    environment = TextRLTaskEnvironment(
        dataset=full_train_data,
        tokenizer=tokenizer,
        reward_func=dynamic_reward_func
    )
    agent = TextRLLanguageAgent(model=model, tokenizer=tokenizer)

    # 7. Training loop (simplified demonstration)
    num_train_episodes = 100  # number of training episodes
    dataset_size = len(full_train_data)

    # PPOTrainer has its own rollout buffer; some of the collection process
    # is managed manually here
    for episode in range(num_train_episodes):
        # Randomly sample batch_size examples.
        # ppo_config.batch_size is PPO's per-update batch size.
        indices = random.sample(range(dataset_size), ppo_config.batch_size)
        prompts, references, meta = environment.reset(indices)

        # The agent produces actions (full replies)
        completions = agent.act(prompts)

        # Compute the rewards
        rewards, done, info = environment.step(prompts, completions, references, meta)

        # Now hand prompts + completions (actions) + rewards over to
        # ppo_trainer, which does the advantage estimation, value-function
        # update, and policy-gradient update
        texts = []
        for i in range(len(prompts)):
            # Concatenate prompt + completion
            prompt_text = ""
            for pm in prompts[i]:
                prompt_text += pm['content'] + "\n"
            completion_text = completions[i][0]['content']
            texts.append(prompt_text + completion_text)

        # PPOTrainer needs (queries, responses, rewards):
        #   queries:   the prompt part
        #   responses: the generated continuation (without the prompt)
        # prompts + completions could also both go into the response, as long
        # as it stays consistent. Simplification here:
        #   queries = prompt_text, responses = completion_text
        query_texts = [p[-1]['content'] for p in prompts]  # just the user turn
        response_texts = [c[0]['content'] for c in completions]

        # PPO step.
        # NOTE: trl's actual PPOTrainer.step expects tokenized tensors
        # (query/response token ids) rather than raw strings; strings are
        # used here only to keep the demonstration readable.
        train_stats = ppo_trainer.step(
            queries=query_texts,
            responses=response_texts,
            rewards=rewards
        )
        # train_stats also exposes some loss and advantage statistics

        if (episode + 1) % 10 == 0:
            print(f"Episode {episode+1}/{num_train_episodes} done. "
                  f"avg reward={sum(rewards)/len(rewards):.4f}")

    print("=== PPO Training Finished ===")
    # Evaluate after training:
    evaluate_model(ppo_trainer, test_dataset_lyrics, tokenizer)

if __name__ == "__main__":
    main()
```
**Environment (TextRLTaskEnvironment)**

- `reset(indices)`: pulls the samples at the given indices from the dataset and returns `prompts` (the multi-turn context, e.g. system + user), `references` (the reference answers / expected outputs used by some reward functions), and `meta` (extra information such as the task type).
- `step(...)`: feeds the agent's generated `completions`, together with the `prompts`, into `reward_func`, obtains the reward list `rewards`, and returns `(rewards, done, info)`. In this example one generation ends the episode, so `done` is all `True`.

**Agent (TextRLLanguageAgent)**
Receives the `prompts` and generates the `completions`.

**PPOTrainer**

The core class of the `trl` library. Internally it builds a rollout buffer and performs the PPO update: it estimates advantages from the `rewards`, then runs the value-function and policy-gradient updates.

**Advantage computation**

Handled automatically inside the `trl` library: for each generated step (usually token-level) it computes (Q - V), GAE (Generalized Advantage Estimation), or similar.
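For intuition, here is a minimal standalone sketch of how GAE turns per-step rewards and value estimates into advantages. The function, its `gamma`/`lam` defaults, and the toy tensors are illustrative assumptions, not code taken from `trl` internals:

```python
import torch

def gae_advantages(rewards, values, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over one trajectory (illustrative).

    rewards: tensor [T]   per-step rewards (here: 0 everywhere except the
                          final step, which carries the scalar sequence reward)
    values:  tensor [T+1] value estimates V(s_0)..V(s_T) from the value head
    """
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        # TD residual: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        # GAE recursion: A_t = delta_t + gamma * lam * A_{t+1}
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages

# Example: a 4-token completion whose only reward is the final scalar 2.5
r = torch.tensor([0.0, 0.0, 0.0, 2.5])
v = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.0])
print(gae_advantages(r, v))
```

Assigning the whole sequence reward to the last step, as in this toy example, mirrors the common RLHF setup where the reward function scores only the finished text.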
**Action**

The action is the generated text; a finer-grained setup could use the token-level PPO flow that `trl` provides, but the example stays at the conceptual level for simplicity.

**Reward**

The scalar list `rewards` returned by `dynamic_reward_func`. If interactive or human feedback is introduced, an automated score and a human score can be combined per sample:

```python
final_rewards[i] = auto_score + human_score
```

before being handed over as the total `rewards`.
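As a concrete sketch of that combination: the helper below is hypothetical (the code above does not actually collect human feedback). `combine_rewards`, `get_human_score`, and `human_weight` are names introduced here for illustration; `get_human_score` stands in for any interactive rating source (CLI prompt, web UI, stored annotations):

```python
from typing import Callable, List

def combine_rewards(
    auto_rewards: List[float],
    get_human_score: Callable[[int], float],  # hypothetical rating source
    human_weight: float = 0.5,
) -> List[float]:
    # Blend each automated reward with a weighted human rating
    final_rewards = []
    for i, auto_score in enumerate(auto_rewards):
        human_score = get_human_score(i)  # e.g. a rating in [-1, 1]
        final_rewards.append(auto_score + human_weight * human_score)
    return final_rewards

# Usage with a dummy rating source that always returns 1.0:
totals = combine_rewards([2.5, 0.0, 4.0], get_human_score=lambda i: 1.0)
print(totals)  # [3.0, 0.5, 4.5]
```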
**Completeness**

The same pipeline also works with GRPO (simply replace PPOTrainer/PPOConfig with GRPOTrainer/GRPOConfig and adjust the relevant hyperparameters; see the sketch below). With that, you have an example of a "more sensible reinforcement-learning interaction strategy": the environment and the agent divide the work explicitly, actions (generated text) and the reward function (combining multiple evaluation dimensions) produce the return, and the advantage function (computed inside PPO) finally updates the policy network. Hope this helps!
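For reference, a minimal sketch of that swap, mirroring the GRPO setup from the first script; `model`, `tokenizer`, `full_train_data`, and `dynamic_reward_func` are reused from the code above, and the hyperparameters/output path are placeholders:

```python
# Minimal GRPO variant of the same training phase (sketch)
from trl import GRPOConfig, GRPOTrainer

grpo_config = GRPOConfig(
    learning_rate=5e-6,
    per_device_train_batch_size=16,
    num_generations=8,          # GRPO samples a group of completions per prompt
    max_prompt_length=768,
    max_completion_length=768,
    output_dir="outputs_grpo_phase",
)
grpo_trainer = GRPOTrainer(
    model=model,                         # the same policy model as above
    processing_class=tokenizer,
    reward_funcs=[dynamic_reward_func],  # group-relative advantages are computed internally
    args=grpo_config,
    train_dataset=full_train_data,
)
grpo_trainer.train()
```

Because GRPO normalizes rewards within each group of sampled completions, no separate value head or explicit GAE step is needed in this variant.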