完整实战BERT原理、数据预处理、模型训练、早停、混合精度加速、Checkpoint断点续训、多指标评估及FastAPI生产级部署完整代码下载包含数据集https://pan.baidu.com/s/1IHIXrC2GET_aLO1Lhh5Z7Q?pwdyiix在电商平台商品发布往往需要手动填写品牌、品类等大量信息既耗时又容易出错。智能商品录入系统应运而生——它能根据商家输入的商品标题自动预测并填写正确的分类大幅提升上架效率。本文将从零开始手把手带你实现这样一个系统基于BERT中文预训练模型完成文本多分类任务并最终通过FastAPI提供RESTful服务。1. 项目概述1.1 业务背景在电商场景中商品标题通常包含丰富的描述信息例如“唯美小清新连衣裙 吊带 短裙 无袖 网纱”。人工判断其所属类目如“服装鞋包 女装/女士精品 连衣裙”需要经验和时间。本项目的目标是输入商品标题自动输出其对应的商品分类标签。从机器学习角度看这是一个典型的文本多分类问题。我们将使用BERT模型作为编码器在其顶部添加一个全连接分类层通过微调Fine-tuning使模型适应商品标题分类任务。1.2 技术选型技术点核心价值实现方式BERT微调利用预训练语言模型的语义理解能力加载bert-base-chinese添加线性分类头CLS Token分类获取句子级别的向量表示取BERT输出的[CLS]向量作为特征分词与数据预处理将文本转换为模型可接受的输入格式AutoTokenizerdatasets库混合精度训练(AMP)减少显存占用约50%加速训练1.5~2倍torch.autocastGradScaler早停机制防止过拟合自动保存最优模型监控验证集损失连续N轮不下降则停止Checkpoint断点续训训练意外中断后可恢复保存模型、优化器、scaler、epoch状态多分类评估指标全面衡量模型性能准确率、精确率、召回率、F1宏平均FastAPI部署提供高性能HTTP预测接口FastAPIUvicorn自动生成API文档CLI统一入口方便执行不同任务argparse实现命令行工具2. 环境准备与项目结构2.1 Conda环境搭建创建独立的Python环境推荐Python 3.12避免依赖冲突conda create -n product-classify python3.12 conda activate product-classify2.2 依赖安装核心依赖说明PyTorch深度学习框架支持GPU加速。TransformersHugging Face提供的预训练模型库用于加载BERT和分词器。Datasets高效加载和处理大规模数据集。scikit-learn提供准确率、精确率、召回率、F1分数等评估指标。TensorBoard可视化训练过程中的损失曲线。tqdm显示训练进度条。FastAPI Uvicorn构建和部署高性能API服务。安装命令# 根据CUDA版本选择合适的PyTorch以CUDA 12.6为例 pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 # 安装其余依赖 pip install transformers datasets scikit-learn tensorboard tqdm jupyter fastapi uvicorn2.3 项目目录结构product_classify/ ├── data/ # 原始数据train.txt, valid.txt, test.txt ├── logs/ # TensorBoard日志 ├── models/ # 保存模型权重model.pt, checkpoint.pt ├── pretrained/ # 预训练模型文件bert-base-chinese ├── src/ │ ├── configuration/ # 配置文件路径、超参数 │ ├── model/ # 模型定义classifier.py │ ├── preprocess/ # 数据预处理process.py, dataset.py │ ├── runner/ # 训练、评估、预测脚本 │ └── web/ # FastAPI服务app.py, routers.py, schemas.py └── main.py # CLI入口3. 数据预处理3.1 数据集说明本实验使用的商品标题分类数据集来自百度AI Studio包含训练集、验证集和测试集格式为TSV制表符分隔。示例数据如下labeltext_a母婴好奇心钻装纸尿裤 L40片 9-14kg蔬菜基地玉米酒饮冲调240ML*15 养元2430 六个核桃玩具911-267 遥控车乳品125ML*4 伊利臻浓牛奶3.2 BERT分词原理BERT使用WordPiece子词分词算法将汉字和常见子词单元拆分为token。例如“智能手机”可能被拆分为[智, 能, 手, 机]。每个token对应一个唯一的IDinput_ids同时生成attention_mask用于指示哪些位置是真实token、哪些是填充padding。此外BERT要求输入序列以[CLS]开头以[SEP]分隔或结尾。为什么要使用BERT分词器有效解决未登录词OOV问题。灵活处理新词、品牌名、型号等非常规表达。配合内置的vocab.txt共21128个token深度捕捉中文语义。3.3 预处理流程与代码预处理共分五步加载数据使用datasets.load_dataset读取CSV/TSV文件。过滤空值移除text_a或label为空的样本。标签编码提取所有唯一标签转换为HuggingFace的ClassLabel类型。分词与编码调用tokenizer将文本转换为input_ids和attention_mask并统一填充/截断到固定长度如128。保存为磁盘格式使用save_to_disk保存处理后的数据集。# src/preprocess/process.py import datasets from datasets import ClassLabel from transformers import AutoTokenizer from configuration import config # 集中管理路径和超参数 def process(): # 1. 加载数据 dataset_dic datasets.load_dataset( csv, data_files{ train: str(config.RAW_DATA_DIR / train.txt), test: str(config.RAW_DATA_DIR / test.txt), valid: str(config.RAW_DATA_DIR / valid.txt) }, delimiter\t ) # 2. 过滤空值 dataset_dic dataset_dic.filter( lambda x: x[text_a] is not None and x[label] is not None ) # 3. 标签处理提取所有标签并转为ClassLabel all_labels sorted(set(dataset_dic[train][label])) dataset_dic dataset_dic.cast_column(label, ClassLabel(namesall_labels)) # 4. 加载BERT中文分词器 tokenizer AutoTokenizer.from_pretrained( str(config.PRE_TRAINED_DIR / bert-base-chinese) ) def tokenize(example): # 对单条文本进行分词、填充、截断 encoded tokenizer( example[text_a], truncationTrue, paddingmax_length, max_lengthconfig.SEQ_LEN # 例如128 ) example[input_ids] encoded[input_ids] example[attention_mask] encoded[attention_mask] return example # 5. 批量应用分词并删除原始文本列 dataset_dic dataset_dic.map(tokenize, batchedTrue, remove_columns[text_a]) # 6. 保存到磁盘 dataset_dic[train].save_to_disk(str(config.PROCESSED_DATA_DIR / train)) dataset_dic[test].save_to_disk(str(config.PROCESSED_DATA_DIR / test)) dataset_dic[valid].save_to_disk(str(config.PROCESSED_DATA_DIR / valid))3.4 构建DataLoader为了高效批次训练我们将磁盘上的数据集加载为PyTorch的Dataset并封装成DataLoader# src/preprocess/dataset.py from enum import StrEnum from datasets import load_from_disk from torch.utils.data import DataLoader from configuration import config class DataType(StrEnum): TRAIN train TEST test VALID valid def get_dataset(data_type: DataType): dataset load_from_disk(str(config.PROCESSED_DATA_DIR / data_type)) # 指定返回的列并转为PyTorch张量格式 dataset.set_format(torch, columns[input_ids, attention_mask, label]) return dataset def get_dataloader(data_type: DataType DataType.TRAIN): dataset get_dataset(data_type) shuffle (data_type DataType.TRAIN) # 仅训练集打乱 return DataLoader(dataset, batch_sizeconfig.BATCH_SIZE, shuffleshuffle)4. 模型定义4.1 BERT微调原理BERTBidirectional Encoder Representations from Transformers是Google在2018年提出的预训练语言模型其核心架构包含12层Transformer编码器隐藏层维度为768共约1.1亿参数。BERT的强大之处在于其独特的双向预训练机制MLM掩码语言模型随机遮蔽输入句子中15%的汉字让模型根据上下文预测被遮蔽的字词实现真正的双向语义建模。NSP下一句预测判断两个句子是否连续出现增强模型对句间关系的理解。对于文本分类任务BERT会在输入序列开头添加一个特殊的[CLS] token。该token经过所有Transformer层后其输出向量聚合了整个序列的语义信息因此常被用作句子的整体表示再通过一个线性分类层完成类别预测。4.2 分类头设计我们只需要在这个768维的向量上接一个线性层nn.Linear(768, num_classes)即可得到每个类别的logits。微调策略有两种冻结BERTfreeze_bertTrue只训练分类头适合小数据集训练快但效果可能欠佳。全参数微调freeze_bertFalseBERT的全部参数也参与训练能够更好地适应领域数据但需要更多显存和时间。本项目中由于商品标题领域词汇较特殊建议使用全参数微调。4.3 代码实现# src/model/classifier.py import torch.nn as nn from transformers import AutoModel from configuration import config class BertTitleClassifier(nn.Module): 基于BERT的商品标题分类器 结构BERT编码器 线性分类头 def __init__(self, freeze_bert: bool False): super().__init__() # 加载预训练的BERT中文模型 self.bert AutoModel.from_pretrained( str(config.PRE_TRAINED_DIR / bert-base-chinese) ) # 分类头768 - num_classes self.classifier nn.Linear( self.bert.config.hidden_size, # 768 config.NUM_CLASSES ) # 可选冻结BERT参数 if freeze_bert: for param in self.bert.parameters(): param.requires_grad False def forward(self, input_ids, attention_maskNone): 前向传播 input_ids: (batch_size, seq_len) attention_mask: (batch_size, seq_len) 返回 logits: (batch_size, num_classes) outputs self.bert(input_idsinput_ids, attention_maskattention_mask) # 取[CLS] token的向量第一个位置 cls_output outputs.last_hidden_state[:, 0, :] # (batch_size, 768) logits self.classifier(cls_output) # (batch_size, num_classes) return logits5. 模型训练5.1 训练流程训练脚本需要完成以下任务加载数据通过get_dataloader获取训练和验证数据加载器。定义优化器与损失函数使用Adam优化器和交叉熵损失。训练循环每个epoch执行前向传播、损失计算、反向传播和参数更新。日志记录使用TensorBoard记录训练和验证损失。保存最佳模型保存验证损失最低的模型权重。5.2 基础训练代码# src/runner/train.py import time import torch from torch import nn from torch.optim import Adam from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from configuration import config from model.classifier import BertTitleClassifier from preprocess.dataset import get_dataloader, DataType def run_one_epoch(model, dataloader, device, loss_fn, optimizerNone, is_trainTrue): 执行一个epoch的训练或验证 - is_trainTrue: 训练模式更新梯度 - is_trainFalse: 验证模式只计算损失 epoch_loss 0 model.train() if is_train else model.eval() with torch.set_grad_enabled(is_train): desc 训练 if is_train else 验证 for batch in tqdm(dataloader, descdesc): input_ids batch[input_ids].to(device) attention_mask batch[attention_mask].to(device) labels batch[label].to(device) outputs model(input_ids, attention_mask) loss loss_fn(outputs, labels) if is_train: optimizer.zero_grad() loss.backward() optimizer.step() epoch_loss loss.item() return epoch_loss / len(dataloader) def train(): device torch.device(cuda if torch.cuda.is_available() else cpu) print(f设备: {device}) # TensorBoard日志 log_dir config.LOGS_DIR / time.strftime(%Y%m%d-%H%M%S) writer SummaryWriter(log_dirlog_dir) model BertTitleClassifier(freeze_bertFalse).to(device) train_loader get_dataloader(DataType.TRAIN) valid_loader get_dataloader(DataType.VALID) loss_fn nn.CrossEntropyLoss() optimizer Adam(model.parameters(), lrconfig.LEARNING_RATE) best_loss float(inf) for epoch in range(1, config.EPOCHS 1): print(f Epoch {epoch} ) train_loss run_one_epoch(model, train_loader, device, loss_fn, optimizer, is_trainTrue) valid_loss run_one_epoch(model, valid_loader, device, loss_fn, is_trainFalse) print(f训练集 loss: {train_loss:.4f} | 验证集 loss: {valid_loss:.4f}) writer.add_scalar(Loss/train, train_loss, epoch) writer.add_scalar(Loss/valid, valid_loss, epoch) if valid_loss best_loss: best_loss valid_loss torch.save(model.state_dict(), config.MODELS_DIR / model.pt) print(f✓ 保存最佳模型验证损失: {best_loss:.4f}) writer.close()6. 模型预测6.1 推理原理预测时模型处于eval()模式关闭Dropout和BatchNorm的训练行为。使用torch.no_grad()禁用梯度计算加速推理并节省显存。对于单条文本同样需要经过分词、编码、填充/截断然后输入模型得到logits最后取argmax获得类别索引再通过标签映射转换为类别名称。6.2 代码实现# src/runner/predict.py import torch from datasets import load_from_disk from transformers import AutoTokenizer from configuration import config from model.classifier import BertTitleClassifier def predict_batch(model, input_ids, attention_mask): 批量预测返回类别索引 with torch.no_grad(): outputs model(input_ids, attention_mask) preds torch.argmax(outputs, dim1) return preds def predict_text(text, model, tokenizer, device, label_feature): 单条文本预测返回 (类别ID, 类别名称) encoded tokenizer( text, return_tensorspt, paddingmax_length, truncationTrue, max_lengthconfig.SEQ_LEN ) input_ids encoded[input_ids].to(device) attention_mask encoded[attention_mask].to(device) preds predict_batch(model, input_ids, attention_mask) pred_id preds[0].item() pred_label label_feature.int2str(pred_id) return pred_id, pred_label def run_predict(): device torch.device(cuda if torch.cuda.is_available() else cpu) model BertTitleClassifier().to(device) model.load_state_dict(torch.load(config.MODELS_DIR / model.pt, map_locationdevice)) model.eval() tokenizer AutoTokenizer.from_pretrained(str(config.PRE_TRAINED_DIR / bert-base-chinese)) label_feature load_from_disk(str(config.PROCESSED_DATA_DIR / train)).features[label] print(请输入商品标题输入q退出) while True: text input( ).strip() if text.lower() in [q, quit]: break if not text: continue pred_id, pred_label predict_text(text, model, tokenizer, device, label_feature) print(f预测类别: {pred_label} (ID: {pred_id}))7. 模型评估7.1 多分类评估指标对于多分类任务单纯看准确率Accuracy往往不够全面。我们需要从多个维度评估模型性能指标计算公式业务含义准确率(TPTN) / 总数所有预测中正确的比例精确率TP / (TPFP)预测为正例的样本中实际为正的比例召回率TP / (TPFN)实际为正例的样本中被正确预测的比例F1分数2×P×R/(PR)精确率和召回率的调和平均在多分类场景下通常使用宏平均Macro Average方式计算上述指标——先对每个类别分别计算指标再取算术平均这样每个类别获得相同的权重不受样本数量影响。7.2 代码实现# src/runner/evaluate.py from enum import StrEnum import torch from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score from tqdm import tqdm from configuration import config from preprocess.dataset import get_dataloader, DataType from model.classifier import BertTitleClassifier class Metric(StrEnum): ACCURACY accuracy PRECISION precision RECALL recall F1 f1 def evaluate_model(model, dataloader, device, metrics): model.eval() all_preds, all_labels [], [] for batch in tqdm(dataloader, desc评估): input_ids batch[input_ids].to(device) attention_mask batch[attention_mask].to(device) labels batch[label].to(device) with torch.no_grad(): outputs model(input_ids, attention_mask) preds torch.argmax(outputs, dim1) all_preds.extend(preds.tolist()) all_labels.extend(labels.tolist()) results {} if Metric.ACCURACY in metrics: results[Metric.ACCURACY] accuracy_score(all_labels, all_preds) if Metric.F1 in metrics: results[Metric.F1] f1_score(all_labels, all_preds, averagemacro, zero_division0) if Metric.PRECISION in metrics: results[Metric.PRECISION] precision_score(all_labels, all_preds, averagemacro, zero_division0) if Metric.RECALL in metrics: results[Metric.RECALL] recall_score(all_labels, all_preds, averagemacro, zero_division0) return results def run_evaluation(): device torch.device(cuda if torch.cuda.is_available() else cpu) model BertTitleClassifier().to(device) model.load_state_dict(torch.load(config.MODELS_DIR / model.pt, map_locationdevice)) dataloader get_dataloader(DataType.TEST) metrics [Metric.ACCURACY, Metric.F1, Metric.PRECISION, Metric.RECALL] results evaluate_model(model, dataloader, device, metrics) print( 评估结果 ) for name, value in results.items(): print(f{name}: {value:.4f})8. 训练优化技术8.1 早停机制技术点早停Early Stopping是一种防止过拟合的技术。当模型在验证集上的损失停止下降且持续若干个epoch没有改进时训练将提前终止并保存当前最佳模型。这避免了过度训练导致的泛化能力下降。实现思路维护一个计数器counter每当验证损失没有优于历史最佳时加1一旦达到patience阈值触发早停。同时保存最佳模型。# src/runner/train.py 中添加 EarlyStopping 类 class EarlyStopping: def __init__(self, patience2, pathNone): self.patience patience self.counter 0 self.best_score None self.early_stop False self.path path def __call__(self, val_loss, model): score -val_loss # 损失越小越好取负后越大越好 if self.best_score is None or score self.best_score: self.best_score score self.counter 0 self.save_model(model) else: self.counter 1 if self.counter self.patience: self.early_stop True def save_model(self, model): torch.save(model.state_dict(), self.path)8.2 混合精度训练技术点混合精度训练AMP是指部分操作使用float16半精度以减少显存占用和加快计算速度而关键操作如损失计算、归一化仍使用float32以保证数值稳定性。PyTorch通过torch.autocast自动选择精度GradScaler动态缩放损失以防止梯度下溢。优势显存占用减少约50%。训练速度提升1.5~2倍尤其在GPU上。几乎不影响模型精度。代码集成在训练循环中使用autocast上下文并配合GradScaler进行反向传播。from torch.cuda.amp import autocast, GradScaler scaler GradScaler() # 在训练epoch中 with autocast(): outputs model(input_ids, attention_mask) loss loss_fn(outputs, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() optimizer.zero_grad()8.3 检查点机制技术点Checkpointing检查点机制是指在训练中定期保存模型权重、优化器状态、当前epoch、随机数状态等关键信息以便训练意外中断如断电、OOM后能够从断点恢复无需重新开始。保存内容model.state_dict()optimizer.state_dict()scaler.state_dict()如果使用AMPepoch当前轮数early_stopping的状态best_score, counter恢复逻辑训练开始前检测检查点文件是否存在若存在则加载状态并设置start_epoch checkpoint[epoch] 1。# 保存检查点 checkpoint { epoch: epoch, model_state_dict: model.state_dict(), optimizer_state_dict: optimizer.state_dict(), scaler_state_dict: scaler.state_dict(), early_stopping_best_score: early_stopping.best_score, early_stopping_counter: early_stopping.counter, } torch.save(checkpoint, config.MODELS_DIR / checkpoint.pt) # 恢复检查点 if checkpoint_path.exists(): checkpoint torch.load(checkpoint_path, map_locationdevice) model.load_state_dict(checkpoint[model_state_dict]) optimizer.load_state_dict(checkpoint[optimizer_state_dict]) scaler.load_state_dict(checkpoint[scaler_state_dict]) early_stopping.best_score checkpoint.get(early_stopping_best_score) early_stopping.counter checkpoint.get(early_stopping_counter, 0) start_epoch checkpoint[epoch] 19. 模型部署FastAPI9.1 API设计将训练好的模型部署为Web服务对外提供HTTP预测接口。设计如下端点POST /predict请求体{text: 商品标题}响应体{text: 商品标题, pred_id: 12, pred_label: 连衣裙}FastAPI框架自动生成交互式API文档/docs便于测试。9.2 代码实现服务逻辑层service.py负责加载模型、分词器、标签映射并提供预测函数。# src/web/service.py import torch from datasets import load_from_disk from transformers import AutoTokenizer from configuration import config from model.classifier import BertTitleClassifier from runner.predict import predict_text device torch.device(cuda if torch.cuda.is_available() else cpu) model BertTitleClassifier().to(device) model.load_state_dict(torch.load(config.MODELS_DIR / model.pt, map_locationdevice)) model.eval() tokenizer AutoTokenizer.from_pretrained(str(config.PRE_TRAINED_DIR / bert-base-chinese)) label_feature load_from_disk(str(config.PROCESSED_DATA_DIR / train)).features[label] def predict_service(text: str): return predict_text(text, model, tokenizer, device, label_feature)数据模型schemas.py定义请求和响应的Pydantic模型。# src/web/schemas.py from pydantic import BaseModel class PredictRequest(BaseModel): text: str class PredictResponse(BaseModel): text: str pred_id: int pred_label: str路由routers.py定义API端点逻辑。# src/web/routers.py from fastapi import APIRouter, HTTPException from web.schemas import PredictRequest, PredictResponse from web.service import predict_service predict_router APIRouter(tags[预测接口]) predict_router.post(/predict, response_modelPredictResponse) def predict(request: PredictRequest): text request.text.strip() if not text: raise HTTPException(status_code400, detail输入文本不能为空) try: pred_id, pred_label predict_service(text) return PredictResponse(texttext, pred_idpred_id, pred_labelpred_label) except Exception as e: raise HTTPException(status_code500, detailf预测失败: {str(e)})主应用app.py创建FastAPI实例注册路由启动服务器。# src/web/app.py import uvicorn from fastapi import FastAPI from web.routers import predict_router app FastAPI(title商品标题分类API) app.include_router(predict_router) def run_app(): uvicorn.run(web.app:app, host0.0.0.0, port8000)9.3 接口测试启动服务后访问 http://localhost:8000/docs 可以看到自动生成的API文档并可直接在网页上测试预测接口。# 启动命令通过CLI python main.py serve # 或直接运行 uvicorn src.web.app:app --reload测试示例请求POST /predictBody{text: 唯美小清新连衣裙}响应{text: 唯美小清新连衣裙, pred_id: 5, pred_label: 连衣裙}10. 统一入口脚本为了方便执行不同任务数据处理、训练、评估、预测、服务编写一个统一的CLI入口脚本main.py使用argparse解析命令行参数。# main.py import argparse def main(): parser argparse.ArgumentParser(description商品标题分类器 CLI) parser.add_argument( action, choices[process, train, evaluate, predict, serve], help操作类型process(预处理) | train(训练) | evaluate(评估) | predict(交互预测) | serve(启动API) ) args parser.parse_args() if args.action process: from preprocess.process import process process() elif args.action train: from runner.train import train train() elif args.action evaluate: from runner.evaluate import run_evaluation run_evaluation() elif args.action predict: from runner.predict import run_predict run_predict() elif args.action serve: from web.app import run_app run_app() else: print(未知操作请选择process / train / evaluate / predict / serve) if __name__ __main__: main()使用示例python main.py process # 数据预处理 python main.py train # 训练模型 python main.py evaluate # 评估模型 python main.py predict # 交互式预测 python main.py serve # 启动API服务11. 项目总结本文从零开始完整实现了一个基于BERT的商品标题智能分类系统涵盖了从数据处理、模型构建、训练优化到生产部署的全流程。以下是关键收获BERT微调实战理解了[CLS]向量的作用以及如何添加分类头掌握了全参数微调与冻结训练的权衡。数据预处理规范学会了使用datasets库高效处理大规模文本数据以及BERT分词器的正确用法。训练优化技巧早停机制有效防止过拟合节省训练时间。混合精度训练AMP显著降低显存占用并加速训练。Checkpoint断点续训让长时间训练更加可靠。全面评估使用准确率、精确率、召回率、F1分数宏平均从多角度衡量模型性能。生产级部署通过FastAPI快速将模型封装为RESTful服务支持高并发调用并提供自动生成的API文档。本系统可直接应用于电商平台的商品发布环节帮助商家提升效率、减少错误。读者可根据实际业务需求替换数据集和调整超参数轻松迁移到其他文本分类场景如新闻分类、评论情感分析等。希望本文能为你在NLP工程化落地的道路上提供有价值的参考。如有疑问或改进建议欢迎交流讨论。