扩散模型中的U-Net
U-Net 是一种深度学习网络架构广泛应用于扩散模型如 DDPM、Stable Diffusion的反向去噪过程中。它的核心作用是预测噪声分布从而从噪声数据逐步重建清晰图像。以下我将逐步解释 U-Net 的结构、功能和在扩散模型中的关键机制。U-Net网络最早由Olaf Ronneberger、Philipp Fischer 和 Thomas Brox在2015年发表的论文中提出。这篇开创性的论文名为《U-Net: Convolutional Networks for Biomedical Image Segmentation》。1.U-Net 的基本架构U-Net 起源于医学图像分割其设计包含编码器下采样和解码器上采样路径并引入跳跃连接skip connections来保留空间细节。在扩散模型中U-Net 被适配用于处理时间序列数据U-net 架构最低分辨率下 32x32 像素的示例。每个蓝色框对应一个多通道特征图。通道数显示在框的顶部。 x-y-尺寸位于框的左下边缘。白框代表复制的特征图。箭头表示不同的操作。①conv 3×3, ReLU— 特征提取的基本单元conv 3×33×3 的卷积操作是最常用的局部特征提取器。它对输入特征图的每一个 3×3 邻域做加权求和用来提取边缘、纹理、几何模式等局部信息。ReLU:线性整流函数让所有值均大于等于0是跟在卷积后面的非线性激活函数。没有它多层卷积堆叠后也只是一个线性变换表达能力会大幅受限。在 U-Net 中两个 3×3 卷积 ReLU 串联组成一个“卷积块”反复出现在编码器和解码器中是整个网络的基本工作单元。②copy and crop— 跳跃连接融合高低层信息这是 U-Net 最精妙的设计之一。copy把编码器中某一层的特征图直接拷贝出来crop因为编码器曾经过 max pooling 或 valid padding特征图的尺寸可能与解码器中对应层的特征图不完全一致所以需要裁剪到相同大小然后它们被拼接concatenate在通道维度上与解码器特征图一起送入后续卷积。这样做的目的高层特征包含强语义信息“这是什么”但空间细节粗糙低层特征保留细粒度的空间细节边缘、位置跳跃连接让解码器在恢复分辨率时既能看到高层语义又能参照低层的精确空间结构对边缘保持、精细结构还原至关重要。③max pool 2×2— 下采样扩大感受野max pool 2×22×2 的最大池化在每个 2×2 窗口内取最大值作为输出。在 U-Net 的编码器下采样路径中每次 max pooling 之后跟一个卷积块形成层级式的特征抽象。这也是“从局部逐步走向全局”的关键一步。④up-conv 2×2— 上采样恢复空间分辨率up-conv 2×2也叫转置卷积transposed convolution或反卷积负责把特征图的分辨率放大。它会将输入特征图的每个点映射到一个 2×2 区域上使长宽各扩大一倍。在 U-Net 的解码器上采样路径中每一次 up-conv 将特征图放大后接着再与对应编码器层的特征图进行跳跃连接copy and crop然后继续用卷积块处理。没有 up-conv解码器就无法把压缩后的高层语义逐步恢复回原始分辨率也就没法做逐像素的稠密预测如去噪、分割。⑤conv 1×1— 输出映射调整通道数conv 1×1卷积核大小为 1×1 的卷积不改变空间分辨率只改变通道数。常出现在网络的最后用来将多通道的特征图映射到目标输出维度。2.编码器Encoder通过卷积层逐步降低特征图的分辨率提取高层抽象特征。模型结构的左半部分是编码器。编码器开始先做一次双卷积增加原始图像的通道数随后重复 4 次下采样双卷积。下采样用 2×2 最大池化将分辨率减半双卷积用两次 3×3 卷积padding1, stride1在该尺度上提取特征。随着网络层数的加深U 形的下降特征图的空间尺寸越来越小但特征的“语义”级别越来越高。网络从关注像素级的细节逐渐转变为理解“这里可能是一个细胞”、“那是一块背景”等高层概念。给出U-Net编码器的代码仅供参考import torch import torch.nn as nn class UNetEncoder(nn.Module): def __init__(self, in_channels3, init_features64): U-Net编码器模块 :param in_channels: 输入通道数 (默认3-RGB图像) :param init_features: 初始特征图数量 (默认64) super(UNetEncoder, self).__init__() features init_features # 定义4个下采样阶段 self.encoder1 self._block(in_channels, features, nameenc1) self.pool1 nn.MaxPool2d(kernel_size2, stride2) self.encoder2 self._block(features, features * 2, nameenc2) self.pool2 nn.MaxPool2d(kernel_size2, stride2) self.encoder3 self._block(features * 2, features * 4, nameenc3) self.pool3 nn.MaxPool2d(kernel_size2, stride2) self.encoder4 self._block(features * 4, features * 8, nameenc4) self.pool4 nn.MaxPool2d(kernel_size2, stride2) # 最底层瓶颈层 self.bottleneck self._block(features * 8, features * 16, namebottleneck) def _block(self, in_channels, features, name): 构建基础卷积块 :return: 包含两个卷积层的序列 return nn.Sequential( nn.Conv2d( in_channelsin_channels, out_channelsfeatures, kernel_size3, padding1, biasFalse ), nn.BatchNorm2d(num_featuresfeatures), nn.ReLU(inplaceTrue), nn.Conv2d( in_channelsfeatures, out_channelsfeatures, kernel_size3, padding1, biasFalse ), nn.BatchNorm2d(num_featuresfeatures), nn.ReLU(inplaceTrue) ) def forward(self, x): 前向传播过程 :return: 各层特征图 瓶颈层输出 # 第一层 enc1 self.encoder1(x) enc1_pool self.pool1(enc1) # 第二层 enc2 self.encoder2(enc1_pool) enc2_pool self.pool2(enc2) # 第三层 enc3 self.encoder3(enc2_pool) enc3_pool self.pool3(enc3) # 第四层 enc4 self.encoder4(enc3_pool) enc4_pool self.pool4(enc4) # 瓶颈层 bottleneck self.bottleneck(enc4_pool) # 返回各层特征图用于跳跃连接和瓶颈层输出 return [enc1, enc2, enc3, enc4], bottleneck # 使用示例 if __name__ __main__: # 创建编码器 (输入3通道图像) encoder UNetEncoder(in_channels3) # 模拟输入 (batch_size4, 3通道, 256x256图像) input_tensor torch.randn(4, 3, 256, 256) # 前向传播 features, bottleneck encoder(input_tensor) # 打印输出尺寸 print(编码器输出:) for i, feat in enumerate(features): print(f层{i1}特征图尺寸: {feat.shape}) print(f瓶颈层输出尺寸: {bottleneck.shape})关键组件说明卷积块每个编码阶段包含两个3×3卷积层每层后接批量归一化和ReLU激活下采样每个阶段后使用2×2最大池化步长2进行特征图降维特征图扩展每下采样一次特征图数量翻倍64→128→256→512瓶颈层位于编码器末端特征图512→1024输出特征图尺寸变化假设输入256×256层级输出尺寸 (C×H×W)说明输入3×256×256原始输入编码器164×256×256无下采样池化164×128×128第一次下采样编码器2128×128×128池化2128×64×64第二次下采样编码器3256×64×64池化3256×32×32第三次下采样编码器4512×32×32池化4512×16×16第四次下采样瓶颈层1024×16×16最深层特征此实现完全遵循U-Net原始架构输出的各层特征图可用于后续解码器的跳跃连接操作。3.解码器Decoder通过转置卷积层逐步上采样特征图恢复空间分辨率。解码器整体结构与编码器对称从最深层特征 x5出发连续做 4 次上采样拼接双卷积每上采样一次H、W 乘 2同时通道数逐层下降最终回到与 x1 相同尺度再用 1×1 卷积输出类别.给出U-Net解码器的代码:包含上采样和跳跃连接结构。U-Net解码器负责将低分辨率特征图逐步恢复为高分辨率分割结果import torch import torch.nn as nn class UNetDecoder(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() # 上采样模块转置卷积 卷积块 self.up1 UpBlock(in_channels[0], in_channels[1]) self.up2 UpBlock(in_channels[1], in_channels[2]) self.up3 UpBlock(in_channels[2], in_channels[3]) self.up4 UpBlock(in_channels[3], in_channels[4]) # 最终输出层 self.final_conv nn.Conv2d(in_channels[4], out_channels, kernel_size1) def forward(self, x, skip_connections): x: 编码器输出的瓶颈特征图 skip_connections: 编码器各阶段的特征图列表 (浅-深顺序) # 反向使用跳跃连接深-浅 x self.up1(x, skip_connections[3]) # 连接最深层的跳跃特征 x self.up2(x, skip_connections[2]) x self.up3(x, skip_connections[1]) x self.up4(x, skip_connections[0]) # 连接最浅层的跳跃特征 return self.final_conv(x) class UpBlock(nn.Module): 上采样模块转置卷积 特征拼接 双卷积 def __init__(self, in_channels, out_channels): super().__init__() # 转置卷积实现2倍上采样 self.upsample nn.ConvTranspose2d( in_channels, out_channels, kernel_size2, stride2 ) # 特征拼接后的卷积处理 self.conv_block DoubleConv( out_channels * 2, # 拼接后通道翻倍 out_channels ) def forward(self, x, skip): # 上采样 x self.upsample(x) # 跳跃连接的特征拼接通道维度 x torch.cat([x, skip], dim1) return self.conv_block(x) class DoubleConv(nn.Module): 双卷积模块Conv2d - ReLU - Conv2d - ReLU def __init__(self, in_channels, out_channels): super().__init__() self.layers nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size3, padding1), nn.ReLU(inplaceTrue), nn.Conv2d(out_channels, out_channels, kernel_size3, padding1), nn.ReLU(inplaceTrue) ) def forward(self, x): return self.layers(x)使用说明初始化参数# in_channels: 各上采样层输入通道数 [瓶颈层, 深层, 中层, 浅层] # out_channels: 分割类别数 decoder UNetDecoder(in_channels[1024, 512, 256, 128], out_channels3)前向传播# bottleneck: 编码器输出的瓶颈特征图 (形状: [B, 1024, H/16, W/16]) # skips: 编码器4个阶段的特征图列表 (顺序: 浅-深) output decoder(bottleneck, skips)结构特点4次上采样操作恢复原始分辨率通过torch.cat实现跳跃连接每个上采样块包含转置卷积和特征融合最终使用$1\times1$卷积输出分割结果此实现遵循标准U-Net架构通过跳跃连接融合不同尺度的特征信息有效提升分割精度。整体结构U-Net 形似 U 形编码器和解码器对称中间有瓶颈层。在扩散模型中输入通常是噪声图像和对应时间步输出是预测的噪声。4.时间步长嵌入扩散模型的反向过程依赖于时间步表示噪声添加的程度。U-Net 通过时间嵌入time embedding将注入网络时间步首先被编码为高维向量通常使用正弦位置编码其中是嵌入维度例如 128 或 256。这个嵌入向量通过全连接层调整后添加到每个卷积块的激活中确保网络能区分不同噪声水平。5.在扩散模型中的作用在扩散模型的反向去噪过程中U-Net 预测当前噪声输入噪声图像和时间步。输出预测噪声目标是最小化与真实噪声的均方误差 其中是累积噪声调度参数。去噪过程通过迭代应用 U-Net从纯噪声逐步生成图像 这里和是噪声调度参数。6.简化代码示例以下是 PyTorch 实现的简化 U-Net 模型用于扩散模型仅展示核心部分import torch import torch.nn as nn import torch.nn.functional as F class TimeEmbedding(nn.Module): def __init__(self, embed_dim): super().__init__() self.embed_dim embed_dim # 正弦位置编码 self.proj nn.Linear(embed_dim, embed_dim) # 调整嵌入维度 def forward(self, t): # 生成正弦嵌入 half_dim self.embed_dim // 2 emb torch.log(torch.tensor(10000.0)) / (half_dim - 1) emb torch.exp(-emb * torch.arange(half_dim, devicet.device)) emb t[:, None] * emb[None, :] emb torch.cat([torch.sin(emb), torch.cos(emb)], dim-1) emb self.proj(emb) return emb class UNetBlock(nn.Module): def __init__(self, in_channels, out_channels, time_embed_dim): super().__init__() self.conv1 nn.Conv2d(in_channels, out_channels, kernel_size3, padding1) self.conv2 nn.Conv2d(out_channels, out_channels, kernel_size3, padding1) self.time_proj nn.Linear(time_embed_dim, out_channels) # 时间嵌入投影 def forward(self, x, t_emb): # 添加时间嵌入到激活 t_emb self.time_proj(t_emb)[:, :, None, None] # 调整形状匹配特征图 x F.relu(self.conv1(x)) x x t_emb # 注入时间信息 x F.relu(self.conv2(x)) return x class UNet(nn.Module): def __init__(self, in_channels3, out_channels3, time_embed_dim128): super().__init__() # 编码器路径 self.enc1 UNetBlock(in_channels, 64, time_embed_dim) self.enc2 UNetBlock(64, 128, time_embed_dim) # 解码器路径简化无跳跃连接 self.dec1 UNetBlock(128, 64, time_embed_dim) self.final_conv nn.Conv2d(64, out_channels, kernel_size1) self.time_embed TimeEmbedding(time_embed_dim) def forward(self, x, t): t_emb self.time_embed(t) # 时间嵌入 # 编码器 x1 self.enc1(x, t_emb) x2 F.max_pool2d(x1, 2) # 下采样 x3 self.enc2(x2, t_emb) # 解码器简化上采样 x4 F.interpolate(x3, scale_factor2, modenearest) # 上采样 x5 self.dec1(x4, t_emb) return self.final_conv(x5) # 输出预测噪声代码说明TimeEmbedding模块处理时间步嵌入。UNetBlock是基础块包含卷积层和时间嵌入注入。完整 U-Net 通常有多个下采样/上采样层和跳跃连接此处省略以简化。使用时输入噪声图像x和时间步t输出预测噪声。7.实际应用与优化性能U-Net 在扩散模型中高效因为跳跃连接保留细节适合图像生成任务。变体如 DDPM 使用残差块Stable Diffusion 结合注意力机制处理高分辨率。训练技巧使用 Adam 优化器学习率调度并在大规模数据集如 ImageNet上预训练。本文来源于网络学习后通过个人总结等完成感谢各位前辈的总结如有不妥或有误的地方欢迎大家来讨论批评指正。