DPO Fine-Tuning

What is DPO?

DPO (Direct Preference Optimization) trains a model by contrasting paired responses so that it assigns higher probability to the human-labeled "good" answer, without training a separate reward model.

Core Principles

Traditional RLHF vs. DPO

Traditional RLHF: SFT → train a reward model → PPO optimization
DPO: SFT → direct preference optimization (the reward model is skipped)

Mathematical Foundation

DPO builds on the Bradley-Terry model and directly optimizes the preference probability:

P(y_w ≻ y_l | x) = σ( β log[π_θ(y_w|x) / π_ref(y_w|x)] − β log[π_θ(y_l|x) / π_ref(y_l|x)] )

where:

  • y_w: the preferred (winning) response
  • y_l: the dispreferred (losing) response
  • π_θ: the current policy model
  • π_ref: the reference model (usually the SFT model)
  • β: the temperature parameter
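
Minimizing the negative log of this preference probability over a dataset D of triples (x, y_w, y_l) yields the DPO training objective that the code below implements:

L_DPO(θ) = −E_{(x, y_w, y_l) ~ D}[ log σ( β log[π_θ(y_w|x) / π_ref(y_w|x)] − β log[π_θ(y_l|x) / π_ref(y_l|x)] ) ]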

Technical Implementation

DPO Loss Function

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             reference_chosen_logps, reference_rejected_logps,
             beta=0.1):
    """
    Compute the DPO loss.

    Args:
        policy_chosen_logps: policy model log-probabilities of the chosen responses
        policy_rejected_logps: policy model log-probabilities of the rejected responses
        reference_chosen_logps: reference model log-probabilities of the chosen responses
        reference_rejected_logps: reference model log-probabilities of the rejected responses
        beta: temperature parameter
    """

    # Log-ratios between the policy and the reference model
    policy_ratio_chosen = policy_chosen_logps - reference_chosen_logps
    policy_ratio_rejected = policy_rejected_logps - reference_rejected_logps

    # DPO loss
    logits = beta * (policy_ratio_chosen - policy_ratio_rejected)
    loss = -F.logsigmoid(logits).mean()

    # Accuracy: fraction of pairs where the chosen response is scored higher
    accuracy = (logits > 0).float().mean()

    return loss, accuracy
 
def compute_log_probs(model, input_ids, attention_mask, labels):
    """Compute the summed log-probability of the label tokens for each sequence."""

    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits

    # Causal LM: logits at position t predict the token at position t+1
    logits = logits[:, :-1, :]
    labels = labels[:, 1:].clone()

    # Build the loss mask BEFORE overwriting the ignored positions
    mask = (labels != -100).float()
    labels[labels == -100] = 0  # dummy index so gather() is valid; masked out below

    # Per-token log probabilities
    log_probs = F.log_softmax(logits, dim=-1)
    gathered_log_probs = torch.gather(
        log_probs, dim=-1, index=labels.unsqueeze(-1)
    ).squeeze(-1)

    # Sum only over the non-ignored (response) tokens
    sequence_log_prob = (gathered_log_probs * mask).sum(dim=-1)

    return sequence_log_prob
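
A quick sanity check of dpo_loss with hand-picked, purely hypothetical log-probabilities; the chosen responses are given larger policy-vs-reference margins, so the reported accuracy should be 1.0:

# Hypothetical per-sequence log-probabilities, just to exercise the function
policy_chosen = torch.tensor([-10.0, -12.0])
policy_rejected = torch.tensor([-14.0, -15.0])
ref_chosen = torch.tensor([-11.0, -12.5])
ref_rejected = torch.tensor([-13.0, -14.0])

loss, acc = dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1)
print(f"loss={loss.item():.4f}, accuracy={acc.item():.2f}")  # accuracy is 1.0 for these values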

DPO Trainer

from transformers import Trainer
import torch.nn as nn
 
class DPOTrainer(Trainer):
    def __init__(self, model, ref_model, beta=0.1, **kwargs):
        super().__init__(model=model, **kwargs)
        self.ref_model = ref_model
        self.beta = beta
        
        # Freeze the reference model
        for param in self.ref_model.parameters():
            param.requires_grad = False
        self.ref_model.eval()
    
    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the DPO loss."""

        # Unpack the inputs
        chosen_input_ids = inputs["chosen_input_ids"]
        chosen_attention_mask = inputs["chosen_attention_mask"] 
        chosen_labels = inputs["chosen_labels"]
        
        rejected_input_ids = inputs["rejected_input_ids"]
        rejected_attention_mask = inputs["rejected_attention_mask"]
        rejected_labels = inputs["rejected_labels"]
        
        # Policy model forward pass
        policy_chosen_logps = compute_log_probs(
            model, chosen_input_ids, chosen_attention_mask, chosen_labels
        )
        policy_rejected_logps = compute_log_probs(
            model, rejected_input_ids, rejected_attention_mask, rejected_labels
        )
        
        # Reference model forward pass (no gradients)
        with torch.no_grad():
            reference_chosen_logps = compute_log_probs(
                self.ref_model, chosen_input_ids, chosen_attention_mask, chosen_labels
            )
            reference_rejected_logps = compute_log_probs(
                self.ref_model, rejected_input_ids, rejected_attention_mask, rejected_labels
            )
        
        # Compute the DPO loss
        loss, accuracy = dpo_loss(
            policy_chosen_logps, policy_rejected_logps,
            reference_chosen_logps, reference_rejected_logps,
            beta=self.beta
        )
        
        # Log metrics
        self.log({
            "train/dpo_loss": loss.item(),
            "train/accuracy": accuracy.item(),
            "train/chosen_logps": policy_chosen_logps.mean().item(),
            "train/rejected_logps": policy_rejected_logps.mean().item(),
        })
        
        return (loss, {"accuracy": accuracy}) if return_outputs else loss

Data Preparation

Preference Data Format

# DPO training data format
dpo_data_example = {
    "prompt": "Please explain what machine learning is",
    "chosen": "Machine learning is a branch of artificial intelligence in which algorithms learn patterns from data, enabling computers to make predictions or decisions without being explicitly programmed. It includes supervised, unsupervised, and reinforcement learning.",
    "rejected": "Machine learning is just technology that makes machines smart."
}
 
def format_dpo_data(examples, tokenizer, max_length=512):
    """Tokenize and format a batch of DPO training examples."""
    
    batch_size = len(examples["prompt"])
    formatted_data = {
        "chosen_input_ids": [],
        "chosen_attention_mask": [],
        "chosen_labels": [],
        "rejected_input_ids": [],
        "rejected_attention_mask": [],
        "rejected_labels": []
    }
    
    for i in range(batch_size):
        prompt = examples["prompt"][i]
        chosen = examples["chosen"][i]
        rejected = examples["rejected"][i]
        
        # Tokenize the chosen sample
        chosen_text = f"{prompt}\n{chosen}"
        chosen_encoded = tokenizer(
            chosen_text,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt"
        )
        
        # Tokenize the rejected sample
        rejected_text = f"{prompt}\n{rejected}"
        rejected_encoded = tokenizer(
            rejected_text,
            truncation=True,
            padding="max_length", 
            max_length=max_length,
            return_tensors="pt"
        )
        
        # Create labels (loss is computed only on the response part).
        # Note: tokenizing the prompt on its own only approximates its length
        # inside the concatenated text; the two can differ by a token or two.
        prompt_length = len(tokenizer(prompt)["input_ids"])

        chosen_labels = chosen_encoded["input_ids"].clone()
        chosen_labels[:, :prompt_length] = -100  # ignore the prompt tokens
        chosen_labels[chosen_encoded["attention_mask"] == 0] = -100  # ignore padding

        rejected_labels = rejected_encoded["input_ids"].clone()
        rejected_labels[:, :prompt_length] = -100  # ignore the prompt tokens
        rejected_labels[rejected_encoded["attention_mask"] == 0] = -100  # ignore padding
        
        # Append to the batch
        formatted_data["chosen_input_ids"].append(chosen_encoded["input_ids"])
        formatted_data["chosen_attention_mask"].append(chosen_encoded["attention_mask"])
        formatted_data["chosen_labels"].append(chosen_labels)
        
        formatted_data["rejected_input_ids"].append(rejected_encoded["input_ids"])
        formatted_data["rejected_attention_mask"].append(rejected_encoded["attention_mask"])
        formatted_data["rejected_labels"].append(rejected_labels)
    
    # Concatenate the per-example tensors
    for key in formatted_data:
        formatted_data[key] = torch.cat(formatted_data[key], dim=0)
    
    return formatted_data
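
A minimal usage sketch: format_dpo_data expects list-valued fields, so the single example above is wrapped into a batch of one (the "gpt2" tokenizer is purely illustrative; any causal-LM tokenizer works):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative choice, not required by DPO
tok.pad_token = tok.eos_token

batch = {key: [value] for key, value in dpo_data_example.items()}
features = format_dpo_data(batch, tok, max_length=128)
print(features["chosen_input_ids"].shape)  # torch.Size([1, 128])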

Data Collection Strategies

def collect_preference_data():
    """Strategies for collecting preference data"""

    strategies = {
        "Human annotation": {
            "description": "Human raters compare two responses and pick the better one",
            "pros": "High quality and reliability",
            "cons": "Expensive and limited in scale",
            "use case": "High-quality benchmark datasets"
        },

        "Cross-model comparison": {
            "description": "Generate responses with different models and take the better one as chosen",
            "pros": "Large scale at low cost",
            "cons": "Quality can be inconsistent",
            "use case": "Large-scale training data"
        },

        "Self-comparison": {
            "description": "Sample multiple responses from the same model and keep the best one",
            "pros": "Consistent and easy to implement",
            "cons": "Limited diversity",
            "use case": "Self-improvement training"
        },

        "Rule-based filtering": {
            "description": "Filter by rules such as length or safety",
            "pros": "Controllable and interpretable",
            "cons": "May be overly simplistic",
            "use case": "Optimizing for specific constraints"
        }
    }

    return strategies

Training Pipeline

End-to-End DPO Training

from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from datasets import Dataset
 
def train_dpo_model(model_name, train_data, eval_data):
    """End-to-end DPO training pipeline"""

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Load the policy model
    policy_model = AutoModelForCausalLM.from_pretrained(model_name)

    # Load the reference model (usually the SFT checkpoint)
    reference_model = AutoModelForCausalLM.from_pretrained(model_name)

    # Prepare the datasets
    train_dataset = Dataset.from_list(train_data)
    eval_dataset = Dataset.from_list(eval_data)

    # Preprocess the data
    def preprocess_function(examples):
        return format_dpo_data(examples, tokenizer)
    
    train_dataset = train_dataset.map(
        preprocess_function, 
        batched=True,
        remove_columns=train_dataset.column_names
    )
    
    eval_dataset = eval_dataset.map(
        preprocess_function,
        batched=True, 
        remove_columns=eval_dataset.column_names
    )
    
    # Training arguments
    training_args = TrainingArguments(
        output_dir="./dpo_output",
        num_train_epochs=3,
        per_device_train_batch_size=2,  # DPO typically needs a small batch size
        gradient_accumulation_steps=8,
        learning_rate=5e-7,  # DPO uses a very small learning rate
        warmup_ratio=0.1,
        logging_steps=10,
        save_strategy="epoch",
        evaluation_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        remove_unused_columns=False,
    )
    
    # Create the DPO trainer
    trainer = DPOTrainer(
        model=policy_model,
        ref_model=reference_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        beta=0.1  # DPO temperature parameter
    )

    # Start training
    trainer.train()

    # Save the model
    trainer.save_model()
    
    return policy_model

Combining DPO with LoRA

from peft import LoraConfig, get_peft_model
 
def train_dpo_with_lora(base_model_name, train_data):
    """DPO + LoRA training"""

    # Load the base model
    model = AutoModelForCausalLM.from_pretrained(base_model_name)

    # Configure LoRA
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM"
    )
    
    # Apply LoRA to the policy model
    policy_model = get_peft_model(model, lora_config)
    
    # Keep the reference model in its original, frozen state
    reference_model = AutoModelForCausalLM.from_pretrained(base_model_name)
    
    # The rest of the training loop mirrors train_dpo_model above;
    # train_dpo_model_with_peft stands in for that routine adapted to PEFT models.
    return train_dpo_model_with_peft(policy_model, reference_model, train_data)

Advanced Techniques

Adaptive β Scheduling

class AdaptiveDPOTrainer(DPOTrainer):
    def __init__(self, *args, initial_beta=0.1, beta_schedule="linear", **kwargs):
        super().__init__(*args, beta=initial_beta, **kwargs)
        self.initial_beta = initial_beta
        self.beta_schedule = beta_schedule
    
    def compute_loss(self, model, inputs, return_outputs=False):
        # Adjust β according to training progress
        current_step = self.state.global_step
        total_steps = max(self.state.max_steps, 1)  # guard against division by zero

        if self.beta_schedule == "linear":
            # Linear decay
            self.beta = self.initial_beta * (1 - current_step / total_steps)
        elif self.beta_schedule == "cosine":
            # Cosine decay
            import math
            self.beta = self.initial_beta * 0.5 * (1 + math.cos(math.pi * current_step / total_steps))

        return super().compute_loss(model, inputs, return_outputs)

Iterative DPO

def iterative_dpo_training(model, initial_data, num_iterations=3):
    """Iterative DPO training"""

    current_model = model
    previous_performance = None

    for iteration in range(num_iterations):
        print(f"DPO iteration {iteration + 1}/{num_iterations}")

        # Generate new comparison data with the current model
        # (generate_comparison_data and evaluate_model are assumed helpers)
        new_data = generate_comparison_data(current_model, initial_data)

        # Merge old and new data
        combined_data = initial_data + new_data

        # Run a round of DPO training (schematic: reuses the DPO routine above,
        # adapted to accept an in-memory model rather than a model name)
        current_model = train_dpo_model(current_model, combined_data, eval_data=[])

        # Evaluate model performance
        performance = evaluate_model(current_model)
        print(f"Iteration {iteration + 1} performance: {performance}")

        # Stop early if performance no longer improves
        if previous_performance is not None and performance <= previous_performance:
            print("Performance stopped improving; stopping early")
            break

        previous_performance = performance

    return current_model

Evaluation Metrics

DPO-Specific Metrics

def evaluate_dpo_model(model, ref_model, eval_data, tokenizer):
    """Evaluate a DPO-trained model"""

    metrics = {
        "preference_accuracy": 0,
        "kl_divergence": 0,
        "response_length": 0,
        "diversity": 0  # placeholder: not computed in this sketch
    }

    total_samples = len(eval_data)

    for sample in eval_data:
        prompt = sample["prompt"]
        chosen = sample["chosen"]
        rejected = sample["rejected"]

        # Preference accuracy: does the model score the chosen response higher?
        chosen_score = compute_response_score(model, tokenizer, prompt, chosen)
        rejected_score = compute_response_score(model, tokenizer, prompt, rejected)

        if chosen_score > rejected_score:
            metrics["preference_accuracy"] += 1

        # KL divergence from the reference model
        # (compute_kl_divergence and generate_response are assumed helpers)
        kl_div = compute_kl_divergence(model, ref_model, prompt)
        metrics["kl_divergence"] += kl_div

        # Generate a response and track its length
        generated_response = generate_response(model, prompt)
        metrics["response_length"] += len(generated_response.split())

    # Average over the evaluation set
    for key in metrics:
        metrics[key] /= total_samples

    return metrics

def compute_response_score(model, tokenizer, prompt, response):
    """Score a response by its sequence log-probability under the model"""

    full_text = f"{prompt}\n{response}"
    inputs = tokenizer(full_text, return_tensors="pt")

    with torch.no_grad():
        # Note: this scores the full prompt+response sequence; masking the prompt
        # tokens (as in training) would score the response alone.
        sequence_log_prob = compute_log_probs(
            model, inputs["input_ids"], inputs["attention_mask"], inputs["input_ids"]
        )

    return sequence_log_prob.item()

Human Evaluation

def human_evaluation_framework():
    """Human evaluation framework"""

    evaluation_criteria = {
        "Helpfulness": {
            "description": "Does the response help solve the user's problem?",
            "scale": "1-5",
            "weight": 0.4
        },

        "Accuracy": {
            "description": "Is the response factually correct?",
            "scale": "1-5",
            "weight": 0.3
        },

        "Safety": {
            "description": "Is the response safe and harmless?",
            "scale": "1-5",
            "weight": 0.2
        },

        "Fluency": {
            "description": "Is the response natural and fluent?",
            "scale": "1-5",
            "weight": 0.1
        }
    }

    return evaluation_criteria

Strengths and Limitations

Strengths

  1. Simpler pipeline: no separate reward model to train
  2. Stable training: avoids the instability of RLHF
  3. Direct optimization: optimizes human preferences directly
  4. Compute efficiency: lower overhead than RLHF
  5. Easy to implement: relatively straightforward to build

Limitations

  1. Data dependence: requires high-quality preference pairs
  2. Preference lock-in: can overfit the preferences present in the training data
  3. Reduced diversity: may make responses less diverse
  4. Complex preferences: struggles with multi-dimensional preferences
  5. Distribution shift: may drift away from the original model distribution

Best Practices

Hyperparameter Tuning

def dpo_hyperparameter_guide():
    """DPO hyperparameter tuning guide"""

    return {
        "β (beta)": {
            "range": "0.01 - 0.5",
            "recommended": "0.1",
            "effect": "Strength of the implicit KL constraint toward the reference model; smaller values allow larger deviation from it"
        },

        "learning rate": {
            "range": "1e-7 - 1e-5",
            "recommended": "5e-7",
            "effect": "DPO requires a very small learning rate"
        },

        "batch size": {
            "range": "1 - 8",
            "recommended": "2-4",
            "effect": "Usually small due to GPU memory limits"
        },

        "epochs": {
            "range": "1 - 5",
            "recommended": "3",
            "effect": "Too many epochs can cause overfitting"
        }
    }
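
A minimal numeric sketch of how β scales the policy-vs-reference margin before the sigmoid in the DPO loss (the log-ratio gap of 2.0 is purely illustrative):

import torch
import torch.nn.functional as F

margin = torch.tensor(2.0)  # (chosen - rejected) gap of policy-vs-reference log-ratios
for beta in (0.01, 0.1, 0.5):
    per_pair_loss = -F.logsigmoid(beta * margin)
    print(f"beta={beta}: per-pair loss={per_pair_loss.item():.4f}")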

Data Quality Control

def ensure_data_quality():
    """Quality checks for DPO data"""

    quality_checks = [
        "Annotation consistency checks on preferences",
        "Sanity checks on response length",
        "Content safety filtering",
        "Language fluency assessment",
        "Factual accuracy verification",
        "Diversity balance checks"
    ]

    return quality_checks
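
A minimal sketch of the two simplest checks from this list (length bounds and identical-pair filtering), with hypothetical thresholds:

def basic_quality_filter(sample, min_chars=5, max_chars=2000):
    """Rule-based sanity checks for one preference pair (illustrative thresholds)."""
    chosen = sample["chosen"].strip()
    rejected = sample["rejected"].strip()

    # Length sanity check on the chosen response
    if not (min_chars <= len(chosen) <= max_chars):
        return False

    # An identical pair carries no preference signal
    if chosen == rejected:
        return False

    return True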

Related Concepts