Pytorch图像去噪实战十四条件扩散模型图像去噪让Diffusion根据带噪图恢复干净图一、问题场景普通Diffusion能生成图但不能直接修复指定图片前面我们实现了 DDPM 和 DDIM。但如果你仔细看会发现之前的采样方式是从纯噪声开始生成图像这更像是生成任务。而真实图像去噪任务通常是给定一张带噪图输出它对应的干净图也就是说我们不是要随机生成图片而是要修复指定图片。这时普通无条件Diffusion就不够用了需要引入条件扩散模型 Conditional Diffusion二、条件扩散去噪的核心思想普通Diffusion输入x_t, t条件Diffusion输入x_t, noisy_condition, t其中x_t扩散过程中的 noisy clean imagenoisy_condition真实带噪图t时间步模型学习predict noise from x_t with condition也就是让模型在反向去噪时参考原始带噪图。三、为什么需要condition如果没有condition模型生成的是随机干净图不一定和输入图片内容一致。加入condition后模型知道图像结构是什么边缘在哪里文字位置在哪里物体轮廓在哪里因此它可以围绕输入图像做恢复而不是凭空生成。四、工程结构conditional_diffusion_denoise/ ├── data/ │ └── train/ ├── models/ │ └── conditional_unet.py ├── diffusion/ │ └── ddpm.py ├── dataset.py ├── train.py ├── infer.py └── utils.py五、数据集构造训练时我们有 clean 图然后人工加噪得到 condition。importosimportrandomimporttorchfromPILimportImagefromtorch.utils.dataimportDatasetimporttorchvision.transformsastransformsclassConditionalDenoiseDataset(Dataset):def__init__(self,root_dir,image_size64):self.paths[os.path.join(root_dir,name)fornameinos.listdir(root_dir)ifname.lower().endswith((.jpg,.png,.jpeg))]self.transformtransforms.Compose([transforms.Resize((image_size,image_size)),transforms.ToTensor()])def__len__(self):returnlen(self.paths)def__getitem__(self,index):cleanImage.open(self.paths[index]).convert(L)cleanself.transform(clean)sigmarandom.choice([15,25,35,50])noisetorch.randn_like(clean)*sigma/255.0noisy_conditiontorch.clamp(cleannoise,0.0,1.0)returnnoisy_condition,clean六、条件UNet模型核心改动非常简单把 x_t 和 noisy_condition 在通道维度拼接。如果是灰度图x_t: 1通道 condition: 1通道 concat后: 2通道models/conditional_unet.pyimporttorchimporttorch.nnasnnclassTimeEmbedding(nn.Module):def__init__(self,dim):super().__init__()self.netnn.Sequential(nn.Linear(1,dim),nn.SiLU(),nn.Linear(dim,dim))defforward(self,t):tt.float().view(-1,1)/1000.0returnself.net(t)classResBlock(nn.Module):def__init__(self,in_channels,out_channels,time_dim):super().__init__()self.conv1nn.Conv2d(in_channels,out_channels,3,padding1)self.conv2nn.Conv2d(out_channels,out_channels,3,padding1)self.time_projnn.Linear(time_dim,out_channels)self.shortcutnn.Identity()ifin_channels!out_channels:self.shortcutnn.Conv2d(in_channels,out_channels,1)self.actnn.SiLU()defforward(self,x,t_emb):hself.act(self.conv1(x))timeself.time_proj(t_emb).view(x.size(0),-1,1,1)hhtime hself.conv2(self.act(h))returnhself.shortcut(x)classConditionalUNet(nn.Module):def__init__(self,image_channels1,base64,time_dim128):super().__init__()self.time_mlpTimeEmbedding(time_dim)in_channelsimage_channels*2self.down1ResBlock(in_channels,base,time_dim)self.down2ResBlock(base,base*2,time_dim)self.poolnn.MaxPool2d(2)self.midResBlock(base*2,base*2,time_dim)self.upnn.ConvTranspose2d(base*2,base,2,2)self.up_blockResBlock(base*2,base,time_dim)self.outnn.Conv2d(base,image_channels,3,padding1)defforward(self,xt,condition,t):t_embself.time_mlp(t)xtorch.cat([xt,condition],dim1)d1self.down1(x,t_emb)d2self.down2(self.pool(d1),t_emb)midself.mid(d2,t_emb)uself.up(mid)utorch.cat([u,d1],dim1)uself.up_block(u,t_emb)returnself.out(u)七、训练代码importtorchfromtorch.utils.dataimportDataLoaderfromdatasetimportConditionalDenoiseDatasetfromdiffusion.ddpmimportDDPMfrommodels.conditional_unetimportConditionalUNetdeftrain():devicetorch.device(cudaiftorch.cuda.is_available()elsecpu)datasetConditionalDenoiseDataset(data/train,image_size64)loaderDataLoader(dataset,batch_size16,shuffleTrue,num_workers4)modelConditionalUNet().to(device)diffusionDDPM(timesteps1000,beta_start1e-4,beta_end0.02,devicedevice)optimizertorch.optim.AdamW(model.parameters(),lr2e-4)criteriontorch.nn.MSELoss()forepochinrange(1,101):model.train()total_loss0forcondition,cleaninloader:conditioncondition.to(device)cleanclean.to(device)ttorch.randint(0,diffusion.timesteps,(clean.size(0),),devicedevice)xt,noisediffusion.q_sample(clean,t)pred_noisemodel(xt,condition,t)losscriterion(pred_noise,noise)optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)optimizer.step()total_lossloss.item()print(fEpoch{epoch}, Loss:{total_loss/len(loader):.6f})ifepoch%100:torch.save(model.state_dict(),fconditional_diffusion_epoch_{epoch}.pth)if__name____main__:train()八、推理代码推理时输入一张真实 noisy image 作为 condition。importtorchfromPILimportImageimporttorchvision.transformsastransformsimporttorchvision.utilsasvutilsfromdiffusion.ddpmimportDDPMfrommodels.conditional_unetimportConditionalUNettorch.no_grad()definfer():devicetorch.device(cudaiftorch.cuda.is_available()elsecpu)modelConditionalUNet().to(device)model.load_state_dict(torch.load(conditional_diffusion_epoch_100.pth,map_locationdevice))model.eval()diffusionDDPM(timesteps1000,beta_start1e-4,beta_end0.02,devicedevice)imgImage.open(test_noisy.png).convert(L)transformtransforms.Compose([transforms.Resize((64,64)),transforms.ToTensor()])conditiontransform(img).unsqueeze(0).to(device)xtorch.randn_like(condition)fortinreversed(range(diffusion.timesteps)):batch_ttorch.full((1,),t,devicedevice,dtypetorch.long)pred_noisemodel(x,condition,batch_t)betadiffusion.betas[t]alphadiffusion.alphas[t]alpha_bardiffusion.alpha_bars[t]x(1/torch.sqrt(alpha))*(x-(beta/torch.sqrt(1-alpha_bar))*pred_noise)ift0:xxtorch.sqrt(beta)*torch.randn_like(x)xtorch.clamp(x,0.0,1.0)vutils.save_image(x.cpu(),conditional_denoised.png)if__name____main__:infer()九、为什么条件图不能直接作为初始x很多人第一次写条件扩散时会想直接从 noisy image 开始反向去噪不就行了但标准条件扩散里反向过程的变量 x 是目标 clean 的扩散状态而 noisy image 是条件信息。两者角色不同x当前正在生成的 clean image 状态condition引导恢复的输入图如果混在一起模型训练和推理分布会不一致。十、和普通UNet去噪相比有什么优势普通UNetnoisy - clean条件Diffusionnoise state noisy condition - clean distribution优势在于更适合复杂噪声可以生成更自然细节对强噪声恢复潜力更高缺点也明显训练更慢推理更慢工程复杂度更高十一、踩坑记录坑1condition没有拼接进模型如果模型只输入 xt 和 t那就是无条件生成不是图像去噪。坑2condition和clean尺寸不一致训练时 condition 和 clean 必须尺寸一致。建议在 dataset 中统一 resize。坑3采样太慢条件Diffusion同样有1000步采样问题。建议后续结合DDIM。十二、适合收藏总结条件Diffusion去噪流程从clean构造noisy condition对clean执行扩散加噪模型输入 xt condition t模型预测noise推理时用condition引导反向去噪避坑清单condition必须输入模型clean和condition尺寸一致x和condition角色不要混推理成本较高建议结合DDIM加速十三、优化建议可以继续做条件DDIM采样加强UNet结构使用Restormer作为条件网络支持RGB图像用真实噪声数据微调结尾总结条件扩散模型把Diffusion从“随机生成图像”推进到“指定图像恢复”。它的核心价值是既保留扩散模型强大的生成能力又让模型受输入带噪图约束。如果你要把Diffusion用于真正的图像去噪任务条件扩散是必须掌握的一步。下一篇预告Pytorch图像去噪实战十五彩色RGB图像去噪实战从灰度模型升级到真实图片处理