用PyTorch复现CycleGAN从零开始手搓一个风格迁移模型附完整代码与调试心得风格迁移一直是计算机视觉领域的热门研究方向而CycleGAN作为其中的佼佼者以其无需配对数据的特性脱颖而出。本文将带你从零开始用PyTorch完整复现CycleGAN并分享在实际编码和调试过程中的关键经验。1. 理解CycleGAN的核心架构CycleGAN的核心思想在于循环一致性——它由两个生成器和两个判别器组成形成一个闭环系统。与传统的GAN不同CycleGAN不需要成对的训练数据这使得它在许多实际应用中更具优势。生成器网络的关键组件下采样卷积层逐步减小特征图尺寸残差块9个残差块构成的核心转换模块上采样层通过转置卷积恢复原始尺寸class GeneratorResNet(nn.Module): def __init__(self, input_shape, num_residual_blocks): super().__init__() channels input_shape[0] out_features 64 model [ nn.ReflectionPad2d(channels), nn.Conv2d(channels, out_features, 7), nn.InstanceNorm2d(out_features), nn.ReLU(inplaceTrue) ] # 下采样 for _ in range(2): out_features * 2 model [ nn.Conv2d(in_features, out_features, 3, stride2, padding1), nn.InstanceNorm2d(out_features), nn.ReLU(inplaceTrue) ] # 残差块 for _ in range(num_residual_blocks): model [ResidualBlock(out_features)] # 上采样 for _ in range(2): out_features // 2 model [ nn.Upsample(scale_factor2), nn.Conv2d(in_features, out_features, 3, stride1, padding1), nn.InstanceNorm2d(out_features), nn.ReLU(inplaceTrue) ] # 输出层 model [ nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh() ] self.model nn.Sequential(*model)2. 数据准备与预处理CycleGAN对数据格式有特定要求。我们需要将不同风格的图像分别放在trainA和trainB文件夹中。预处理步骤包括随机裁剪到256x256像素随机水平翻转增加数据多样性归一化到[-1,1]范围数据集类实现要点class ImageDataset(Dataset): def __init__(self, root, transforms_None, unalignedFalse, modetrain): self.transform transforms.Compose(transforms_) self.unaligned unaligned self.files_A sorted(glob.glob(os.path.join(root, f{mode}A/*.*))) self.files_B sorted(glob.glob(os.path.join(root, f{mode}B/*.*))) def __getitem__(self, index): image_A Image.open(self.files_A[index % len(self.files_A)]) if self.unaligned: image_B Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)]) else: image_B Image.open(self.files_B[index % len(self.files_B)]) if image_A.mode ! RGB: image_A to_rgb(image_A) if image_B.mode ! RGB: image_B to_rgb(image_B) return {A: self.transform(image_A), B: self.transform(image_B)}3. 训练过程中的关键技巧3.1 损失函数配置CycleGAN使用三种主要损失函数损失类型计算公式权重(λ)作用GAN损失MSE1.0使生成图像更真实循环一致性损失L110.0保持内容一致性身份损失L15.0保持颜色分布# 初始化损失函数 criterion_GAN torch.nn.MSELoss() criterion_cycle torch.nn.L1Loss() criterion_identity torch.nn.L1Loss() # 在训练循环中计算总损失 loss_G loss_GAN opt.lambda_cyc * loss_cycle opt.lambda_id * loss_identity3.2 ReplayBuffer的妙用ReplayBuffer是CycleGAN训练中的一个关键技巧它存储之前生成的图像用于判别器训练class ReplayBuffer: def __init__(self, max_size50): self.max_size max_size self.data [] def push_and_pop(self, data): to_return [] for element in data.data: element torch.unsqueeze(element, 0) if len(self.data) self.max_size: self.data.append(element) to_return.append(element) else: if random.uniform(0, 1) 0.5: i random.randint(0, self.max_size - 1) to_return.append(self.data[i].clone()) self.data[i] element else: to_return.append(element) return Variable(torch.cat(to_return))3.3 学习率调度策略采用线性衰减的学习率策略前30个epoch保持恒定之后线性衰减到0class LambdaLR: def __init__(self, n_epochs, offset, decay_start_epoch): self.n_epochs n_epochs self.offset offset self.decay_start_epoch decay_start_epoch def step(self, epoch): return 1.0 - max(0, epoch self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)4. 调试与优化经验分享在实际复现过程中我遇到了几个关键问题及解决方案模式崩溃问题现象生成器总是输出相似的图像解决调整判别器的PatchGAN感受野大小增加判别器的深度训练不稳定现象损失值剧烈波动解决使用较小的学习率(0.0002)并增加批归一化层颜色失真现象生成图像颜色分布异常解决引入身份损失(identity loss)权重设为5.0内存不足现象GPU内存爆满解决减小批处理大小使用梯度累积技巧训练监控建议每100次迭代保存一次生成样本监控四种损失值的变化趋势定期检查生成图像的多样性def sample_images(batches_done): imgs next(iter(val_dataloader)) G_AB.eval() G_BA.eval() real_A Variable(imgs[A]).cuda() fake_B G_AB(real_A) real_B Variable(imgs[B]).cuda() fake_A G_BA(real_B) # 拼接并保存对比图像 image_grid torch.cat((real_A, fake_B, real_B, fake_A), 1) save_image(image_grid, fimages/{opt.dataset_name}/{batches_done}.png, normalizeFalse)5. 完整训练流程实现以下是训练循环的核心代码结构def train(): for epoch in range(opt.epoch, opt.n_epochs): for i, batch in enumerate(dataloader): # 1. 准备真实图像和标签 real_A Variable(batch[A]).cuda() real_B Variable(batch[B]).cuda() valid Variable(torch.ones(real_A.size(0), *D_A.output_shape), requires_gradFalse).cuda() fake Variable(torch.zeros(real_A.size(0), *D_A.output_shape), requires_gradFalse).cuda() # 2. 训练生成器 optimizer_G.zero_grad() loss_G compute_generator_loss(real_A, real_B, valid) loss_G.backward() optimizer_G.step() # 3. 训练判别器A optimizer_D_A.zero_grad() loss_D_A compute_discriminator_loss(real_A, fake_A, D_A, valid, fake) loss_D_A.backward() optimizer_D_A.step() # 4. 训练判别器B optimizer_D_B.zero_grad() loss_D_B compute_discriminator_loss(real_B, fake_B, D_B, valid, fake) loss_D_B.backward() optimizer_D_B.step() # 5. 打印日志和保存样本 if batches_done % opt.sample_interval 0: sample_images(batches_done) # 更新学习率 lr_scheduler_G.step() lr_scheduler_D_A.step() lr_scheduler_D_B.step()6. 测试与应用训练完成后我们可以使用训练好的模型进行风格转换def test(): # 加载训练好的模型 netG_A2B GeneratorResNet(input_shape, opt.n_residual_blocks).cuda() netG_A2B.load_state_dict(torch.load(opt.generator_A2B)) netG_A2B.eval() # 处理测试图像 for i, batch in enumerate(test_dataloader): real_A Variable(input_A.copy_(batch[A])).cuda() fake_B 0.5 * (netG_A2B(real_A).data 1.0) save_image(fake_B, foutput/B/{i1:04d}.png)在实际项目中我发现以下几个参数调整对结果影响最大循环一致性损失权重10.0是一个较好的起点身份损失权重5.0可以较好地保持颜色分布残差块数量9个块在256x256图像上效果最佳InstanceNorm的使用比BatchNorm更适合风格迁移任务经过约50个epoch的训练后模型能够产生令人满意的风格转换效果。在facades数据集上从建筑照片到建筑素描的转换效果尤为出色。