Multi-Task Joint Fine-Tuning
Overview
Multi-task joint fine-tuning trains several related tasks in a single model, relying on cross-task knowledge sharing to improve overall performance. It works best when the tasks are genuinely related, for example fine-tuning a generation task and a classification task together.
Advantages of Multi-Task Learning
Knowledge sharing
- Shared representations: low-level features are reused across tasks
- Regularization effect: training on several tasks acts as implicit regularization
- Data efficiency: every task's data contributes to the shared encoder
Performance gains
- Generalization: multi-task learning improves how well the model generalizes
- Robustness: less overfitting to any single task
- Resource efficiency: one model serves several tasks
Multi-Task Architecture Design
Shared encoder with multiple task heads
import torch.nn as nn

class MultiTaskModel(nn.Module):
    def __init__(self, base_model, task_configs):
        super().__init__()
        self.base_model = base_model
        self.task_types = {name: cfg['type'] for name, cfg in task_configs.items()}
        self.task_heads = nn.ModuleDict()
        for task_name, config in task_configs.items():
            if config['type'] == 'classification':
                # Maps the pooled sentence representation to class logits
                self.task_heads[task_name] = nn.Linear(
                    base_model.config.hidden_size,
                    config['num_classes']
                )
            elif config['type'] == 'generation':
                # Maps every token position to vocabulary logits
                self.task_heads[task_name] = nn.Linear(
                    base_model.config.hidden_size,
                    base_model.config.vocab_size
                )

    def forward(self, input_ids, attention_mask, task_name):
        if task_name not in self.task_heads:
            raise ValueError(f"Unknown task: {task_name}")
        # Shared encoder
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        # Task-specific head: classification uses the first ([CLS]) token,
        # generation keeps the full sequence of hidden states
        if self.task_types[task_name] == 'classification':
            return self.task_heads[task_name](hidden_states[:, 0])
        return self.task_heads[task_name](hidden_states)
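A minimal usage sketch, assuming a Hugging Face encoder such as bert-base-chinese (any AutoModel that returns last_hidden_state works the same way; the checkpoint name is just an example):

from transformers import AutoModel, AutoTokenizer

base_model = AutoModel.from_pretrained('bert-base-chinese')
tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')

model = MultiTaskModel(base_model, {
    'intent_classification': {'type': 'classification', 'num_classes': 10},
    'response_generation': {'type': 'generation'},
})

enc = tokenizer(['查询订单状态'], return_tensors='pt')
logits = model(enc['input_ids'], enc['attention_mask'], 'intent_classification')
print(logits.shape)  # torch.Size([1, 10]): one logit vector per example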
Task-specific layer design
class TaskSpecificLayers(nn.Module):
    def __init__(self, hidden_size, task_configs):
        super().__init__()
        self.task_layers = nn.ModuleDict()
        for task_name, config in task_configs.items():
            # Each task gets its own small feed-forward stack
            self.task_layers[task_name] = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_size, config['output_size'])
            )

    def forward(self, shared_features, task_name):
        return self.task_layers[task_name](shared_features)
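These layers sit between the shared encoder and the final output. A sketch of how they plug in; the mean pooling over the sequence dimension is an assumption for illustration, not something the design above prescribes:

import torch

hidden_size = 768
task_layers = TaskSpecificLayers(hidden_size, {'sentiment_analysis': {'output_size': 3}})

# Pool the shared encoder output to one vector per example, then
# route it through the sentiment-specific stack
hidden_states = torch.randn(4, 16, hidden_size)  # (batch, seq_len, hidden)
pooled = hidden_states.mean(dim=1)
logits = task_layers(pooled, 'sentiment_analysis')
print(logits.shape)  # torch.Size([4, 3])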
Loss Function Design
Weighted multi-task loss
import torch.nn.functional as F

def multi_task_loss(outputs, targets, task_weights):
    total_loss = 0.0
    task_losses = {}
    for task_name, output in outputs.items():
        if task_name == 'classification':
            task_loss = F.cross_entropy(output, targets[task_name])
        elif task_name == 'generation':
            # Flatten (batch, seq_len, vocab) logits and (batch, seq_len) labels
            task_loss = F.cross_entropy(
                output.view(-1, output.size(-1)),
                targets[task_name].view(-1)
            )
        elif task_name == 'regression':
            task_loss = F.mse_loss(output, targets[task_name])
        else:
            raise ValueError(f"Unknown task: {task_name}")
        task_losses[task_name] = task_loss
        total_loss += task_weights[task_name] * task_loss
    return total_loss, task_losses
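A quick check with dummy tensors (shapes and weights are illustrative):

import torch

outputs = {
    'classification': torch.randn(4, 10, requires_grad=True),  # (batch, classes)
    'regression': torch.randn(4, 1, requires_grad=True),
}
targets = {
    'classification': torch.randint(0, 10, (4,)),
    'regression': torch.randn(4, 1),
}
weights = {'classification': 1.0, 'regression': 0.5}

total, per_task = multi_task_loss(outputs, targets, weights)
print(total.item(), {k: v.item() for k, v in per_task.items()})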
Dynamic weight adjustment
import torch
import torch.nn.functional as F

class DynamicWeightAveraging:
    def __init__(self, num_tasks, temperature=2.0):
        self.num_tasks = num_tasks
        self.temperature = temperature
        self.task_losses_history = []

    def update_weights(self, task_losses):
        self.task_losses_history.append(task_losses)
        if len(self.task_losses_history) < 2:
            # Not enough history yet: start with equal weights
            return {task: 1.0 for task in task_losses.keys()}
        # Rate of change of each task's loss: a ratio near or above 1 means
        # the task is improving slowly and deserves more weight
        prev_losses = self.task_losses_history[-2]
        loss_ratios = {
            task: task_losses[task] / prev_losses[task]
            for task in task_losses.keys()
        }
        # Softmax over the ratios, scaled by num_tasks so the weights sum to
        # the number of tasks (as in the original DWA formulation)
        ratio_values = torch.tensor(list(loss_ratios.values()))
        softmax_weights = self.num_tasks * F.softmax(ratio_values / self.temperature, dim=0)
        return {
            task: softmax_weights[i].item()
            for i, task in enumerate(loss_ratios.keys())
        }
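A small numeric demonstration: after two updates, the task whose loss dropped less (task_b) receives the larger weight. The loss values are made up for illustration:

dwa = DynamicWeightAveraging(num_tasks=2)
dwa.update_weights({'task_a': 2.0, 'task_b': 2.0})   # first call: equal weights
weights = dwa.update_weights({'task_a': 1.0, 'task_b': 1.9})
print(weights)  # ratios 0.5 vs 0.95 -> task_b gets the higher weight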
Data Processing Strategies
Mixed-batch sampling
import random

class MultiTaskDataLoader:
    def __init__(self, task_dataloaders, sampling_strategy='round_robin'):
        self.task_dataloaders = task_dataloaders
        self.sampling_strategy = sampling_strategy
        self.task_iterators = {
            task: iter(dataloader)
            for task, dataloader in task_dataloaders.items()
        }

    def __iter__(self):
        if self.sampling_strategy == 'round_robin':
            return self._round_robin_sampling()
        elif self.sampling_strategy == 'proportional':
            return self._proportional_sampling()
        raise ValueError(f"Unknown sampling strategy: {self.sampling_strategy}")

    def _round_robin_sampling(self):
        # Cycle through the tasks one batch at a time; the epoch ends when the
        # smallest dataloader is exhausted, so no task is over-sampled
        task_names = list(self.task_dataloaders.keys())
        task_idx = 0
        while True:
            task_name = task_names[task_idx]
            try:
                batch = next(self.task_iterators[task_name])
            except StopIteration:
                # Re-initialize the exhausted iterator for the next epoch
                self.task_iterators[task_name] = iter(self.task_dataloaders[task_name])
                return
            batch['task_name'] = task_name
            yield batch
            task_idx = (task_idx + 1) % len(task_names)

    def _proportional_sampling(self):
        # Sample tasks with probability proportional to dataloader length,
        # so larger datasets contribute proportionally more batches
        task_names = list(self.task_dataloaders.keys())
        sizes = [len(self.task_dataloaders[t]) for t in task_names]
        for _ in range(sum(sizes)):
            task_name = random.choices(task_names, weights=sizes, k=1)[0]
            try:
                batch = next(self.task_iterators[task_name])
            except StopIteration:
                self.task_iterators[task_name] = iter(self.task_dataloaders[task_name])
                batch = next(self.task_iterators[task_name])
            batch['task_name'] = task_name
            yield batch
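A usage sketch with two toy DataLoaders (the datasets here are stand-in batch dicts, not real corpora):

import torch
from torch.utils.data import DataLoader

cls_data = [{'input_ids': torch.ones(8, dtype=torch.long)}] * 6
gen_data = [{'input_ids': torch.ones(8, dtype=torch.long)}] * 10

loader = MultiTaskDataLoader(
    {'classification': DataLoader(cls_data, batch_size=2),
     'generation': DataLoader(gen_data, batch_size=2)},
    sampling_strategy='round_robin'
)
for batch in loader:
    print(batch['task_name'])  # alternates until the smaller loader is exhausted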
Unifying task data formats
def unify_data_format(batch, task_name):
    """Normalize batches from different tasks into one schema."""
    unified_batch = {
        'input_ids': batch['input_ids'],
        'attention_mask': batch['attention_mask'],
        'task_name': task_name
    }
    if task_name == 'classification':
        unified_batch['labels'] = batch['labels']
    elif task_name == 'generation':
        # For causal LM training the labels are the input ids themselves;
        # the one-position shift is applied when the loss is computed
        unified_batch['labels'] = batch['input_ids']
    elif task_name == 'ner':
        unified_batch['labels'] = batch['ner_labels']
    return unified_batch
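For example, a raw NER batch only needs its label field renamed:

import torch

raw_batch = {
    'input_ids': torch.ones(2, 16, dtype=torch.long),
    'attention_mask': torch.ones(2, 16, dtype=torch.long),
    'ner_labels': torch.zeros(2, 16, dtype=torch.long),
}
batch = unify_data_format(raw_batch, 'ner')
print(sorted(batch.keys()))  # ['attention_mask', 'input_ids', 'labels', 'task_name']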
Case Study: An Intelligent Customer-Service System
Task definition
# Define several related tasks
task_configs = {
    'intent_classification': {
        'type': 'classification',
        'num_classes': 10,   # number of intent classes
        'weight': 1.0
    },
    'sentiment_analysis': {
        'type': 'classification',
        'num_classes': 3,    # positive, negative, neutral
        'weight': 0.8
    },
    'response_generation': {
        'type': 'generation',
        'vocab_size': 21128,
        'weight': 1.2
    },
    'entity_extraction': {
        'type': 'sequence_labeling',
        'num_labels': 9,     # BIO tagging scheme
        'weight': 0.9
    }
}
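Note that MultiTaskModel above only builds heads for the classification and generation types, so entity_extraction would be rejected as configured. A sketch of the missing branch, written here as a patch applied to an existing model instance:

import torch.nn as nn

# Extension sketch: add a per-token head for sequence-labeling tasks
# (mirrors the classification/generation branches in MultiTaskModel.__init__)
for task_name, config in task_configs.items():
    if config['type'] == 'sequence_labeling':
        model.task_heads[task_name] = nn.Linear(
            model.base_model.config.hidden_size,
            config['num_labels']
        )
        model.task_types[task_name] = 'sequence_labeling'

Because forward routes every non-classification task through the full hidden-state sequence, the new head already yields the per-token logits that BIO tagging needs.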
Training loop
from torch.optim import AdamW

def train_multi_task_model(model, multi_task_dataloader, task_configs, num_epochs=3):
    optimizer = AdamW(model.parameters(), lr=2e-5)
    weight_averager = DynamicWeightAveraging(len(task_configs))
    # Start from the static per-task weights given in the config
    task_weights = {task: cfg.get('weight', 1.0) for task, cfg in task_configs.items()}
    for epoch in range(num_epochs):
        epoch_losses = {task: 0.0 for task in task_configs.keys()}
        batch_counts = {task: 0 for task in task_configs.keys()}
        for batch in multi_task_dataloader:
            task_name = batch['task_name']
            optimizer.zero_grad()
            # Forward pass through the shared encoder and the task head
            outputs = model(
                batch['input_ids'],
                batch['attention_mask'],
                task_name
            )
            # compute_task_loss dispatches on task type, as in multi_task_loss above
            task_loss = compute_task_loss(outputs, batch['labels'], task_name)
            epoch_losses[task_name] += task_loss.item()
            batch_counts[task_name] += 1
            # Backward pass on the weighted loss, then update parameters
            (task_weights[task_name] * task_loss).backward()
            optimizer.step()
        # Refresh the per-task weights from this epoch's average (unweighted) losses
        avg_losses = {task: loss / max(batch_counts[task], 1)
                      for task, loss in epoch_losses.items()}
        task_weights = weight_averager.update_weights(avg_losses)
        print(f"Epoch {epoch}, Losses: {avg_losses}, Weights: {task_weights}")
Multimodal Fine-Tuning
A joint vision-language model
import torch
import torch.nn as nn

class MultiModalModel(nn.Module):
    def __init__(self, text_encoder, image_encoder, fusion_dim, num_classes, vocab_size):
        super().__init__()
        self.text_encoder = text_encoder
        self.image_encoder = image_encoder
        # Modality projection layers
        self.text_projection = nn.Linear(text_encoder.config.hidden_size, fusion_dim)
        self.image_projection = nn.Linear(image_encoder.config.hidden_size, fusion_dim)
        # Multi-task heads; concatenation below doubles the fused feature size
        self.classification_head = nn.Linear(fusion_dim * 2, num_classes)
        self.generation_head = nn.Linear(fusion_dim * 2, vocab_size)

    def forward(self, text_inputs, image_inputs, task_type):
        # Encode text and image
        text_features = self.text_encoder(**text_inputs).pooler_output
        image_features = self.image_encoder(**image_inputs).pooler_output
        # Project both modalities into a shared space
        text_proj = self.text_projection(text_features)
        image_proj = self.image_projection(image_features)
        # Fuse the modalities (simple concatenation here)
        fused_features = torch.cat([text_proj, image_proj], dim=-1)
        # Task-specific output
        if task_type == 'classification':
            return self.classification_head(fused_features)
        elif task_type == 'generation':
            return self.generation_head(fused_features)
        raise ValueError(f"Unknown task type: {task_type}")
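A construction sketch, assuming a BERT text encoder and a ViT image encoder from transformers (both expose pooler_output and config.hidden_size); the checkpoint names are examples:

from transformers import AutoModel, ViTModel

text_encoder = AutoModel.from_pretrained('bert-base-chinese')
image_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

model = MultiModalModel(
    text_encoder, image_encoder,
    fusion_dim=512, num_classes=10, vocab_size=21128
)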
Multimodal data processing
from PIL import Image

def process_multimodal_data(text, image_path, tokenizer, image_processor):
    # Tokenize the text
    text_inputs = tokenizer(
        text,
        return_tensors='pt',
        padding=True,
        truncation=True
    )
    # Load and preprocess the image
    image = Image.open(image_path).convert('RGB')
    image_inputs = image_processor(image, return_tensors='pt')
    return {
        'text_inputs': text_inputs,
        'image_inputs': image_inputs
    }
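Called with the matching preprocessors (checkpoint names as above; product.jpg is a placeholder path):

from transformers import AutoTokenizer, AutoImageProcessor

tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
image_processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

inputs = process_multimodal_data('这件商品的图片怎么样?', 'product.jpg', tokenizer, image_processor)
logits = model(inputs['text_inputs'], inputs['image_inputs'], 'classification')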
Evaluation Strategies
Task-specific evaluation
import torch
from sklearn.metrics import accuracy_score, f1_score

def evaluate_multi_task_model(model, eval_dataloaders, task_configs):
    model.eval()
    task_metrics = {}
    for task_name, eval_dataloader in eval_dataloaders.items():
        task_predictions = []
        task_labels = []
        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = model(
                    batch['input_ids'],
                    batch['attention_mask'],
                    task_name
                )
            if task_configs[task_name]['type'] == 'classification':
                predictions = torch.argmax(outputs, dim=-1)
                task_predictions.extend(predictions.cpu().tolist())
                task_labels.extend(batch['labels'].cpu().tolist())
        # Compute task-specific metrics
        if task_configs[task_name]['type'] == 'classification':
            accuracy = accuracy_score(task_labels, task_predictions)
            f1 = f1_score(task_labels, task_predictions, average='weighted')
            task_metrics[task_name] = {'accuracy': accuracy, 'f1': f1}
    return task_metrics
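The function above only scores classification tasks. For the generation task, perplexity on held-out text is a common complement; a minimal sketch, assuming the generation head returns per-token vocabulary logits and that labels equal the input ids (as in unify_data_format):

import math
import torch
import torch.nn.functional as F

def evaluate_generation_perplexity(model, eval_dataloader, task_name='response_generation'):
    model.eval()
    total_nll, total_tokens = 0.0, 0
    for batch in eval_dataloader:
        with torch.no_grad():
            logits = model(batch['input_ids'], batch['attention_mask'], task_name)
        # Shift so that position t predicts token t+1; for simplicity this
        # does not mask padding tokens, which a real evaluation should do
        shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
        shift_labels = batch['input_ids'][:, 1:].reshape(-1)
        total_nll += F.cross_entropy(shift_logits, shift_labels, reduction='sum').item()
        total_tokens += shift_labels.numel()
    return math.exp(total_nll / total_tokens)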
Optimization Techniques
Handling gradient conflicts
import torch

def resolve_gradient_conflicts(model, task_losses, task_weights):
    """Combine per-task gradients when tasks pull the shared parameters in different directions."""
    # Compute and flatten each task's gradient separately
    task_gradients = {}
    for task_name, loss in task_losses.items():
        model.zero_grad()
        loss.backward(retain_graph=True)
        task_grad = []
        for param in model.parameters():
            if param.grad is not None:
                task_grad.append(param.grad.clone().flatten())
        task_gradients[task_name] = torch.cat(task_grad)
    # PCGrad or a similar projection method would resolve conflicts here;
    # this version simplifies to a weighted average, which is equivalent to
    # taking the gradient of the weighted sum of the losses
    final_gradient = torch.zeros_like(next(iter(task_gradients.values())))
    for task_name, grad in task_gradients.items():
        final_gradient += task_weights[task_name] * grad
    # Write the combined gradient back into the parameters
    param_idx = 0
    for param in model.parameters():
        if param.grad is not None:
            param_size = param.numel()
            param.grad = final_gradient[param_idx:param_idx + param_size].view(param.shape)
            param_idx += param_size
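For reference, the projection step that PCGrad (Yu et al., 2020) uses instead of the weighted average: whenever two task gradients conflict (negative dot product), the conflicting component is projected out before the gradients are combined. A minimal sketch over the flattened per-task gradients collected above:

import random
import torch

def pcgrad_combine(task_gradients):
    """PCGrad-style combination of flattened per-task gradients."""
    names = list(task_gradients.keys())
    projected = []
    for name in names:
        g = task_gradients[name].clone()
        others = [n for n in names if n != name]
        random.shuffle(others)  # PCGrad projects against the other tasks in random order
        for other in others:
            g_other = task_gradients[other]
            dot = torch.dot(g, g_other)
            if dot < 0:
                # Remove the component of g that conflicts with g_other
                g = g - (dot / g_other.norm() ** 2) * g_other
        projected.append(g)
    # Sum the de-conflicted gradients into one update direction
    return torch.stack(projected).sum(dim=0)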