OFA-Image-Caption模型知识蒸馏入门训练一个轻量化的学生模型想用AI给图片生成描述但模型太大、推理太慢部署到手机或边缘设备上跑不动这可能是很多开发者遇到的头疼问题。直接用那些效果好的大模型比如OFA-Image-Caption参数动辄几十亿对计算资源要求很高。今天我们就来聊聊一个非常实用的解决方案知识蒸馏。简单来说知识蒸馏就像一位经验丰富的老师大模型在教一个聪明的学生小模型。老师不仅告诉学生标准答案这张图片里有“一只猫”还会分享自己的思考过程“我觉得有80%可能是猫15%可能是狸花猫5%可能光线问题看错了”。学生吸收了老师这些更丰富的“软知识”后往往能比单纯背标准答案学得更好、更接近老师的水平。这篇文章我就手把手带你走一遍这个过程。我们会用强大的OFA-Image-Caption模型作为教师生成高质量的“软标签”然后去训练一个结构简单、参数少得多的小型Transformer学生模型。最终目标是得到一个推理速度快、适合在资源受限环境下部署的轻量级图片描述模型。1. 知识蒸馏为什么“软标签”比“硬标签”更香在开始动手之前我们先花点时间搞清楚核心思想。这能帮你更好地理解后续每一步为什么要那么做。传统的模型训练我们用的是“硬标签”。比如一张猫的图片它的标签就是“一只猫”。模型训练的目标就是让自己的输出尽可能逼近这个唯一的正确答案。这种方式简单直接但有个问题它丢弃了模型在判断过程中丰富的概率信息。而知识蒸馏引入了“软标签”的概念。还是那张猫的图片我们让训练好的大模型教师模型去预测它输出的可能是一个概率分布[“一只猫”: 0.85, “一只狸花猫”: 0.1, “一只蜷缩的动物”: 0.05]。这个分布就是软标签它包含了教师模型的“知识”——它不仅知道最可能是猫还知道和“狸花猫”这类近义词的关联程度。那么用软标签训练学生模型好在哪里呢信息更丰富学生模型能学到类别之间的相似性关系猫和狸花猫是相近的而不仅仅是非此即彼的判别。这通常能让模型泛化能力更强。正则化效果软标签的分布相对平滑可以看作一种有效的正则化能防止学生模型对训练数据过拟合。迁移“暗知识”教师模型在大量数据上学到的一些隐含的、难以用硬标签表述的模式可以通过软标签传递给学生。在我们的场景里OFA教师模型在图文匹配和生成上能力很强它生成的描述概率分布蕴含了对图片内容、语法、常见搭配的深刻理解。我们的学生模型目标就是“偷师”这份理解并用更小的模型体量去近似它。2. 环境搭建与数据准备工欲善其事必先利其器。我们先来把环境和数据准备好。2.1 安装必要的软件包我们需要用到一些深度学习相关的库。建议创建一个新的Python虚拟环境来管理依赖。# 使用conda创建环境可选 conda create -n ofa_distill python3.8 conda activate ofa_distill # 安装PyTorch请根据你的CUDA版本访问PyTorch官网获取对应命令 # 例如对于CUDA 11.3 pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu113 # 安装Transformer相关库和OFA pip install transformers pip install fairseq # 安装用于图像处理的库 pip install Pillow pip install timm2.2 准备教师模型OFA-Image-CaptionOFAOne-For-All是一个统一的跨模态预训练模型我们这里使用它的图片描述版本。我们可以通过transformers库方便地加载。from transformers import OFATokenizer, OFAModel from PIL import Image # 加载预训练的OFA模型和分词器 model_name OFA-Sys/ofa-base # 也可以选择 ofa-large 等更大模型 tokenizer OFATokenizer.from_pretrained(model_name) teacher_model OFAModel.from_pretrained(model_name, use_cacheFalse) teacher_model.eval() # 设置为评估模式不进行梯度更新 # 将模型移动到GPU如果可用 import torch device torch.device(cuda if torch.cuda.is_available() else cpu) teacher_model.to(device) print(f教师模型加载完毕运行在 {device} 上。)2.3 准备训练数据集我们需要一个带有图片和对应描述的数据集。这里以常用的COCO Captions数据集为例。你可以从官网下载或者使用torchvision或datasets库来加载。为了简化流程我们假设你已经将COCO数据集的图片和标注文件如annotations/captions_train2017.json放在了本地目录./data/coco下。下面的代码演示如何创建一个简单的数据集类用于加载图片和对应的硬标签原始描述。import json import os from torch.utils.data import Dataset from PIL import Image import torchvision.transforms as transforms class COCOCaptionDataset(Dataset): def __init__(self, image_dir, annotation_path, tokenizer, max_length20): self.image_dir image_dir self.tokenizer tokenizer self.max_length max_length # 加载标注文件 with open(annotation_path, r) as f: annotations json.load(f) # 构建图像id到文件名的映射以及图像id到描述列表的映射 self.id_to_filename {img[id]: img[file_name] for img in annotations[images]} self.id_to_captions {} for ann in annotations[annotations]: img_id ann[image_id] self.id_to_captions.setdefault(img_id, []).append(ann[caption]) # 创建样本列表每个样本是图像id 其中一个描述 self.samples [] for img_id, captions in self.id_to_captions.items(): for caption in captions[:1]: # 这里为了简单每张图片只取第一个描述 self.samples.append((img_id, caption)) # 图像预处理变换 self.transform transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize(mean[0.5, 0.5, 0.5], std[0.5, 0.5, 0.5]) ]) def __len__(self): return len(self.samples) def __getitem__(self, idx): img_id, caption self.samples[idx] img_path os.path.join(self.image_dir, self.id_to_filename[img_id]) image Image.open(img_path).convert(RGB) image self.transform(image) # 对硬标签描述文本进行编码 hard_label_ids self.tokenizer.encode( caption, max_lengthself.max_length, paddingmax_length, truncationTrue, return_tensorspt ).squeeze(0) # 去掉batch维度 return { image: image, # 图像张量 hard_label_ids: hard_label_ids, # 硬标签的token id caption_text: caption # 保留原始文本方便查看 } # 初始化数据集 train_dataset COCOCaptionDataset( image_dir./data/coco/train2017, annotation_path./data/coco/annotations/captions_train2017.json, tokenizertokenizer ) print(f训练数据集大小: {len(train_dataset)})3. 生成“软标签”与设计学生模型接下来是核心步骤让教师模型给我们生成软标签并设计一个待训练的学生模型。3.1 利用教师模型生成软标签我们不会在训练时动态生成软标签那样太慢而是预先用教师模型对整个训练集进行推理把生成的软标签保存下来。这个过程通常称为“离线蒸馏”。我们需要修改一下数据集让它能返回教师模型生成的软标签概率分布。由于文本生成是序列问题教师模型在每个时间步都会输出一个在整个词表上的概率分布。保存完整的分布矩阵会非常庞大一个常见的简化方法是保存教师模型在每个时间步对硬标签词的预测概率以及整个序列的生成概率通过beam search等。这里我们采用一个更实用的方法保存教师模型通过beam search生成的前k个候选序列及其概率。学生模型可以学习到“哪些描述是老师认为好的”。def generate_teacher_soft_labels(batch_images, teacher_model, tokenizer, device, num_beams5): 为一批图像生成教师模型的软标签top-k候选序列及分数 teacher_model.eval() with torch.no_grad(): # 将图像移动到设备 batch_images batch_images.to(device) # 使用beam search生成描述 inputs {input_ids: None, patch_images: batch_images} # 注意OFA模型的生成API可能需要根据具体版本调整这里为示意 # 实际请参考OFA官方文档或transformers的GenerativeMixin generated teacher_model.generate( **inputs, num_beamsnum_beams, max_length20, early_stoppingTrue, num_return_sequencesnum_beams, # 返回多个序列 output_scoresTrue, # 输出分数 return_dict_in_generateTrue ) # generated.sequences 形状: (batch_size * num_beams, seq_len) # generated.sequences_scores 形状: (batch_size * num_beams,) sequences generated.sequences sequences_scores generated.sequences_scores # 将分数转换为概率softmax over beams # 对每个样本的num_beams个分数做softmax batch_size batch_images.size(0) beam_scores sequences_scores.view(batch_size, num_beams) beam_probs torch.softmax(beam_scores, dim-1) # 形状: (batch_size, num_beams) return sequences, beam_probs # 示例处理一个批次并保存结果此步骤需要对整个数据集循环耗时较长建议保存结果到文件 # 这里仅展示逻辑 dataloader torch.utils.data.DataLoader(train_dataset, batch_size4, shuffleFalse) all_teacher_sequences [] all_teacher_beam_probs [] for batch in dataloader: images batch[image] seq, probs generate_teacher_soft_labels(images, teacher_model, tokenizer, device) all_teacher_sequences.append(seq.cpu()) all_teacher_beam_probs.append(probs.cpu()) # 处理一定批次后保存到文件避免内存不足 # torch.save(...)在实际操作中你需要编写脚本离线处理整个训练集将(image_id, teacher_topk_sequences, teacher_beam_probs)保存下来然后创建一个新的数据集类来加载这些预生成的软标签。3.2 设计轻量级学生模型我们的学生模型需要比OFA小很多。一个典型的选择是设计一个更浅或更窄的Transformer解码器。这里为了演示我们使用一个非常简单的、只有几层的小型Transformer解码器。在实际应用中你可以选择T5-small、DistilGPT2等现成的轻量模型作为起点。import torch.nn as nn from transformers import PreTrainedModel, PretrainedConfig class SmallTransformerConfig(PretrainedConfig): model_type small_transformer def __init__( self, vocab_size50265, # 与OFA词表一致 max_position_embeddings1024, d_model256, # 远小于OFA-base的768 nhead8, num_decoder_layers4, # 层数减少 dim_feedforward1024, dropout0.1, **kwargs ): super().__init__(**kwargs) self.vocab_size vocab_size self.max_position_embeddings max_position_embeddings self.d_model d_model self.nhead nhead self.num_decoder_layers num_decoder_layers self.dim_feedforward dim_feedforward self.dropout dropout class SmallTransformerForCaptioning(PreTrainedModel): config_class SmallTransformerConfig def __init__(self, config): super().__init__(config) self.config config # 词嵌入层 self.embedding nn.Embedding(config.vocab_size, config.d_model) # 位置编码这里使用可学习的位置编码更简单 self.pos_encoder nn.Embedding(config.max_position_embeddings, config.d_model) # 图像特征投影层将图像CNN特征投影到d_model空间 self.image_proj nn.Linear(512, config.d_model) # 假设图像特征维度是512 # Transformer解码器层 decoder_layer nn.TransformerDecoderLayer( d_modelconfig.d_model, nheadconfig.nhead, dim_feedforwardconfig.dim_feedforward, dropoutconfig.dropout, activationgelu, batch_firstTrue # 使用batch_first更直观 ) self.transformer_decoder nn.TransformerDecoder(decoder_layer, num_layersconfig.num_decoder_layers) # 输出层 self.output_layer nn.Linear(config.d_model, config.vocab_size) # 损失函数 self.loss_fct nn.CrossEntropyLoss(ignore_indextokenizer.pad_token_id) self.init_weights() def init_weights(self): # 简单的参数初始化 for p in self.parameters(): if p.dim() 1: nn.init.xavier_uniform_(p) def forward(self, image_features, tgt_ids, tgt_maskNone): image_features: (batch_size, feat_seq_len, feat_dim) 例如CNN特征 tgt_ids: (batch_size, seq_len) 目标描述token id训练时右移一位作为decoder输入 tgt_mask: (seq_len, seq_len) 防止看到未来词的mask batch_size, seq_len tgt_ids.size() # 1. 处理图像特征 image_feats_proj self.image_proj(image_features) # (batch, feat_seq_len, d_model) # 2. 处理文本嵌入 tgt_emb self.embedding(tgt_ids) # (batch, seq_len, d_model) positions torch.arange(seq_len, devicetgt_ids.device).unsqueeze(0).expand(batch_size, -1) pos_emb self.pos_encoder(positions) tgt_emb tgt_emb pos_emb # 3. 通过Transformer解码器 # 将图像特征作为memory文本嵌入作为tgt decoder_output self.transformer_decoder( tgttgt_emb, memoryimage_feats_proj, tgt_masktgt_mask, memory_maskNone ) # (batch, seq_len, d_model) # 4. 预测词表概率 logits self.output_layer(decoder_output) # (batch, seq_len, vocab_size) return logits # 初始化学生模型 student_config SmallTransformerConfig(vocab_sizetokenizer.vocab_size) student_model SmallTransformerForCaptioning(student_config) student_model.to(device) print(f学生模型参数量: {sum(p.numel() for p in student_model.parameters()):,})注意上面的学生模型是一个极简示例省略了图像编码器如CNN。在实际操作中你需要一个固定的CNN如ResNet来提取图像特征或者将一个小型图像编码器与学生解码器一起训练。4. 设计损失函数与训练流程知识蒸馏的精髓在于损失函数的设计。我们需要让学生模型同时向硬标签和教师模型的软标签学习。4.1 组合损失函数常见的组合是蒸馏损失KL散度学生任务损失交叉熵。蒸馏损失 (L_KD)让学生模型输出的概率分布经过softmax和温度系数T软化去逼近教师模型的软标签分布。这通常使用KL散度来衡量。学生损失 (L_CE)让学生模型的输出去匹配真实的硬标签标准答案。这是传统的交叉熵损失。总损失是两者的加权和L_total α * L_KD (1-α) * L_CEdef knowledge_distillation_loss(student_logits, teacher_logits, hard_labels, temperature4.0, alpha0.7): 计算知识蒸馏损失。 student_logits: 学生模型输出的原始分数 (batch, seq_len, vocab_size) teacher_logits: 教师模型输出的原始分数 (需要与student_logits形状一致) hard_labels: 硬标签的token id (batch, seq_len) temperature: 温度系数软化概率分布 alpha: 蒸馏损失的权重 # 1. 计算蒸馏损失 (KL散度) # 对logits应用温度系数并取softmax student_soft nn.functional.log_softmax(student_logits / temperature, dim-1) teacher_soft nn.functional.softmax(teacher_logits / temperature, dim-1) # KL散度: sum( teacher_soft * log(teacher_soft/student_soft) ) sum(teacher_soft * log_teacher) - sum(teacher_soft * log_student) # 我们只计算后一项与student相关因为前一项是常数 kd_loss nn.functional.kl_div(student_soft, teacher_soft, reductionbatchmean) * (temperature ** 2) # 乘以 temperature^2 是因为在计算梯度时需要补偿温度缩放的影响 # 2. 计算学生任务损失 (交叉熵) ce_loss nn.functional.cross_entropy( student_logits.view(-1, student_logits.size(-1)), hard_labels.view(-1), ignore_indextokenizer.pad_token_id ) # 3. 组合损失 total_loss alpha * kd_loss (1 - alpha) * ce_loss return total_loss, kd_loss, ce_loss4.2 完整的训练循环假设我们已经有了一个准备好的数据集它每次能提供(图像特征, 硬标签token id, 教师软标签logits)。下面是一个简化的训练循环框架。import torch.optim as optim from torch.utils.data import DataLoader from tqdm import tqdm # 用于显示进度条 # 假设我们已经有了封装好的蒸馏数据集 DistillationDataset # train_distill_dataset DistillationDataset(...) train_loader DataLoader(train_distill_dataset, batch_size32, shuffleTrue, num_workers4) # 定义优化器 optimizer optim.AdamW(student_model.parameters(), lr5e-5) # 训练轮数 num_epochs 10 student_model.train() for epoch in range(num_epochs): epoch_loss 0 epoch_kd_loss 0 epoch_ce_loss 0 progress_bar tqdm(train_loader, descfEpoch {epoch1}/{num_epochs}) for batch in progress_bar: # 获取数据 img_feats batch[image_features].to(device) hard_label_ids batch[hard_label_ids].to(device) teacher_logits batch[teacher_logits].to(device) # 预生成的教师logits # 前向传播学生模型 # 注意训练时decoder的输入是目标序列右移一位自回归训练 decoder_input_ids hard_label_ids[:, :-1] # 去掉最后一个token decoder_target_ids hard_label_ids[:, 1:] # 去掉第一个token通常是起始符 teacher_logits teacher_logits[:, :-1, :] # 教师logits也做相应调整 student_logits student_model(img_feats, decoder_input_ids) # 计算损失 total_loss, kd_loss, ce_loss knowledge_distillation_loss( student_logits, teacher_logits, decoder_target_ids, temperature4.0, alpha0.7 ) # 反向传播与优化 optimizer.zero_grad() total_loss.backward() torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm1.0) # 梯度裁剪 optimizer.step() # 记录损失 epoch_loss total_loss.item() epoch_kd_loss kd_loss.item() epoch_ce_loss ce_loss.item() # 更新进度条描述 progress_bar.set_postfix({ Loss: total_loss.item(), KD: kd_loss.item(), CE: ce_loss.item() }) # 打印本轮平均损失 avg_loss epoch_loss / len(train_loader) avg_kd epoch_kd_loss / len(train_loader) avg_ce epoch_ce_loss / len(train_loader) print(fEpoch {epoch1} 结束: Avg Loss {avg_loss:.4f}, Avg KD {avg_kd:.4f}, Avg CE {avg_ce:.4f}) # 可以在这里保存模型检查点 # torch.save(student_model.state_dict(), fstudent_model_epoch_{epoch1}.pt)5. 模型评估与部署建议训练完成后我们需要看看这个“学生”学得怎么样。5.1 评估学生模型评估生成式任务常用BLEU、ROUGE、CIDEr、SPICE等指标。我们可以用nlg-eval等库方便计算。这里简单演示如何用训练好的模型进行推理生成。def generate_caption(student_model, image_features, tokenizer, max_len20): 用学生模型为单张图片生成描述 student_model.eval() device next(student_model.parameters()).device # 准备起始token start_token tokenizer.bos_token_id decoder_input torch.tensor([[start_token]], devicedevice) generated_ids [start_token] with torch.no_grad(): image_feats image_features.unsqueeze(0).to(device) # 增加batch维度 for _ in range(max_len): logits student_model(image_feats, decoder_input) # 取最后一个时间步的logits并选择概率最大的词 next_token_logits logits[:, -1, :] next_token_id torch.argmax(next_token_logits, dim-1).item() generated_ids.append(next_token_id) decoder_input torch.tensor([generated_ids], devicedevice) if next_token_id tokenizer.eos_token_id: break # 将token id解码为文本 caption tokenizer.decode(generated_ids[1:], skip_special_tokensTrue) # 跳过起始符 return caption # 示例加载一张测试图片并生成描述 test_image_path your_test_image.jpg test_image Image.open(test_image_path).convert(RGB) # 1. 提取图像特征使用与训练时相同的CNN # image_features cnn_encoder(preprocess(test_image)) # 2. 生成描述 # predicted_caption generate_caption(student_model, image_features, tokenizer) # print(f生成的描述: {predicted_caption})将学生模型在测试集上的生成结果与教师模型、以及原始硬标签进行对比计算上述自动化指标就能定量评估蒸馏效果。通常学生模型在指标上会略低于教师模型但远高于同参数规模下只用硬标签训练的模型。5.2 部署建议得到轻量化的学生模型后部署就简单多了模型导出可以使用torch.jit.trace或torch.jit.script将PyTorch模型转换为TorchScript或者使用ONNX格式导出以获得更好的跨平台推理性能。优化加速利用TensorRT、OpenVINO、Core ML苹果设备或NCNN、MNN移动端等推理框架对模型进行进一步的优化、量化和加速。端侧集成将优化后的模型集成到Android通过PyTorch Mobile或TFLite、iOS或嵌入式设备中。由于模型体积小、计算量低实时生成图片描述成为可能。获取更多AI镜像想探索更多AI镜像和应用场景访问 CSDN星图镜像广场提供丰富的预置镜像覆盖大模型推理、图像生成、视频生成、模型微调等多个领域支持一键部署。