ConvNeXt网络架构解析、代码实现与实战训练
1. ConvNeXt网络架构解析ConvNeXt是2022年由Facebook AI Research和UC Berkeley联合提出的纯卷积神经网络架构。这个网络的设计初衷很有意思 - 研究人员想看看如果让传统的卷积神经网络(CNN)借鉴一些Transformer的成功经验能达到什么样的效果。结果出人意料ConvNeXt不仅超越了传统的ResNet甚至在某些方面比当时火热的Swin Transformer表现更好。我第一次看到这个网络结构时最惊讶的是它竟然只用了标准的卷积操作没有使用任何注意力机制却能达到与Transformer相当的性能。在ImageNet 22K数据集上最大的ConvNeXt-XL模型甚至达到了87.8%的top-1准确率。1.1 宏观设计改进ConvNeXt的基础是ResNet50但做了几个关键改进。首先是调整了stage的计算比例。在原始的ResNet中四个stage的block堆叠次数是(3,4,6,3)而ConvNeXt参考Swin Transformer的比例改成了(3,3,9,3)。这个改动虽然简单但效果立竿见影准确率从78.8%提升到了79.4%。另一个重要改变是stem层的设计。传统ResNet使用一个stride2的7x7卷积加最大池化层而ConvNeXt采用了更接近ViT的patchify思路使用一个卷积核大小为4、步距为4的卷积层。这个改动让准确率又提升了0.1个百分点。1.2 微观结构优化在微观层面ConvNeXt借鉴了ResNeXt的设计采用了深度可分离卷积(depthwise convolution)。这种卷积的特点是分组数等于输入通道数大大减少了计算量。实测下来这个改动让准确率从79.5%跃升到80.5%。另一个巧妙的设计是采用了反瓶颈结构(inverted bottleneck)这与MobileNetV2的思路类似。简单来说就是在block中间使用更宽的通道数(通常是输入通道的4倍)而输入输出层保持较窄的通道。同时ConvNeXt还调整了卷积层的位置将depthwise conv模块上移这样改动后准确率又提升了0.1%。2. ConvNeXt代码实现详解现在我们来深入看看ConvNeXt的PyTorch实现代码。我建议你边看边动手实践这样理解会更深刻。2.1 基础模块实现ConvNeXt的核心是它的Block模块我们先来看这部分代码class Block(nn.Module): def __init__(self, dim, drop_rate0., layer_scale_init_value1e-6): super().__init__() self.dwconv nn.Conv2d(dim, dim, kernel_size7, padding3, groupsdim) self.norm LayerNorm(dim, eps1e-6, data_formatchannels_last) self.pwconv1 nn.Linear(dim, 4 * dim) self.act nn.GELU() self.pwconv2 nn.Linear(4 * dim, dim) self.gamma nn.Parameter(layer_scale_init_value * torch.ones((dim,))) self.drop_path DropPath(drop_rate) if drop_rate 0. else nn.Identity() def forward(self, x): shortcut x x self.dwconv(x) x x.permute(0, 2, 3, 1) # [N, C, H, W] - [N, H, W, C] x self.norm(x) x self.pwconv1(x) x self.act(x) x self.pwconv2(x) if self.gamma is not None: x self.gamma * x x x.permute(0, 3, 1, 2) # [N, H, W, C] - [N, C, H, W] x shortcut self.drop_path(x) return x这个Block有几个关键点值得注意使用了7x7的大卷积核这是为了与Swin Transformer的窗口大小对齐采用了通道最后的格式(channels_last)来应用LayerNorm使用线性层来实现1x1卷积这种实现方式在PyTorch中往往更快引入了Layer Scale机制这是一个可学习的缩放参数2.2 完整网络结构完整的ConvNeXt网络由多个stage组成每个stage包含若干个上述Block。下面是网络的定义代码class ConvNeXt(nn.Module): def __init__(self, in_chans3, num_classes1000, depths[3, 3, 9, 3], dims[96, 192, 384, 768], drop_path_rate0., layer_scale_init_value1e-6, head_init_scale1.): super().__init__() self.downsample_layers nn.ModuleList() stem nn.Sequential( nn.Conv2d(in_chans, dims[0], kernel_size4, stride4), LayerNorm(dims[0], eps1e-6, data_formatchannels_first) ) self.downsample_layers.append(stem) for i in range(3): downsample_layer nn.Sequential( LayerNorm(dims[i], eps1e-6, data_formatchannels_first), nn.Conv2d(dims[i], dims[i1], kernel_size2, stride2) ) self.downsample_layers.append(downsample_layer) self.stages nn.ModuleList() dp_rates [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] cur 0 for i in range(4): stage nn.Sequential( *[Block(dimdims[i], drop_ratedp_rates[cur j], layer_scale_init_valuelayer_scale_init_value) for j in range(depths[i])] ) self.stages.append(stage) cur depths[i] self.norm nn.LayerNorm(dims[-1], eps1e-6) self.head nn.Linear(dims[-1], num_classes) self.apply(self._init_weights) self.head.weight.data.mul_(head_init_scale) self.head.bias.data.mul_(head_init_scale) def _init_weights(self, m): if isinstance(m, (nn.Conv2d, nn.Linear)): nn.init.trunc_normal_(m.weight, std0.2) nn.init.constant_(m.bias, 0) def forward_features(self, x): for i in range(4): x self.downsample_layers[i](x) x self.stages[i](x) return self.norm(x.mean([-2, -1])) # global average pooling def forward(self, x): x self.forward_features(x) x self.head(x) return x这段代码有几个关键设计使用4x4 stride4的卷积作为stem层每个stage之间使用LayerNorm2x2卷积进行下采样采用了随机深度(drop path)正则化技术使用全局平均池化和LayerNorm作为最后的特征处理3. 实战训练ConvNeXt模型理论讲得再多不如实际动手训练一次。下面我将带你完整走一遍ConvNeXt的训练流程使用的是花卉分类数据集。3.1 数据准备与增强首先我们需要准备数据并进行适当的增强。这里我使用了PyTorch的transforms模块from torchvision import transforms img_size 224 data_transform { train: transforms.Compose([ transforms.RandomResizedCrop(img_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]), val: transforms.Compose([ transforms.Resize(int(img_size * 1.143)), transforms.CenterCrop(img_size), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) }这里有几个要点训练时使用随机裁剪和水平翻转增强数据验证时使用中心裁剪使用ImageNet的均值和标准差进行归一化验证时先将图像放大到256x256再中心裁剪到224x2243.2 模型训练技巧ConvNeXt的训练有一些特别的技巧我整理了几个最重要的优化器选择使用AdamW优化器这是Transformer模型常用的优化器学习率调度采用余弦退火学习率并带有warmup阶段权重衰减设置较大的权重衰减(0.05)来防止过拟合标签平滑可以使用标签平滑技术提升模型泛化能力下面是优化器设置的代码示例from torch import optim from utils import get_params_groups pg get_params_groups(model, weight_decayargs.wd) optimizer optim.AdamW(pg, lrargs.lr, weight_decayargs.wd) lr_scheduler create_lr_scheduler(optimizer, len(train_loader), args.epochs, warmupTrue)3.3 训练过程监控训练过程中我们可以使用TensorBoard来监控各项指标from torch.utils.tensorboard import SummaryWriter writer SummaryWriter(log_dirsave_path /flower_experiment) # 在训练循环中添加 writer.add_scalar(train_loss, train_loss, epoch) writer.add_scalar(train_acc, train_acc, epoch) writer.add_scalar(val_loss, val_loss, epoch) writer.add_scalar(val_acc, val_acc, epoch) writer.add_scalar(learning_rate, optimizer.param_groups[0][lr], epoch)这样我们可以实时查看损失和准确率的变化曲线方便调整超参数。4. 模型评估与预测训练完成后我们需要评估模型性能并进行实际预测。4.1 模型评估评估过程相对简单主要是计算模型在验证集上的准确率torch.no_grad() def evaluate(model, data_loader, device, epoch): model.eval() accu_num torch.zeros(1).to(device) accu_loss torch.zeros(1).to(device) sample_num 0 for step, data in enumerate(data_loader): images, labels data sample_num images.shape[0] pred model(images.to(device)) pred_classes torch.max(pred, dim1)[1] accu_num torch.eq(pred_classes, labels.to(device)).sum() loss loss_function(pred, labels.to(device)) accu_loss loss return accu_loss.item() / (step 1), accu_num.item() / sample_num4.2 单张图像预测对于单张图像的预测我们需要确保输入图像经过与训练时相同的预处理def predict(image_path, model, transform, class_indices, device): img Image.open(image_path).convert(RGB) img transform(img).unsqueeze(0).to(device) model.eval() with torch.no_grad(): output torch.squeeze(model(img)) predict torch.softmax(output, dim0) predict_cla torch.argmax(predict).item() return class_indict[str(predict_cla)], predict[predict_cla].item()4.3 批量预测实际应用中我们往往需要批量预测大量图像def batch_predict(img_path_list, model, transform, batch_size8): model.eval() all_preds [] with torch.no_grad(): for i in range(0, len(img_path_list), batch_size): batch_paths img_path_list[i:ibatch_size] batch_images [transform(Image.open(p).convert(RGB)) for p in batch_paths] batch_tensor torch.stack(batch_images).to(device) outputs model(batch_tensor) preds torch.argmax(outputs, dim1) all_preds.extend(preds.cpu().numpy()) return all_preds在实际项目中ConvNeXt-Tiny模型在花卉分类数据集上训练20个epoch就能达到92%以上的准确率表现相当不错。如果增加训练轮数或使用更大的模型准确率还能进一步提升。