In deep learning, convolutional neural networks (CNNs) have long been the dominant architecture for computer vision tasks. In recent years, as demands on model efficiency have grown, a steady stream of convolution variants and optimization techniques has appeared. This article takes a close look at standard convolution, depthwise convolution, and structural re-parameterization, focusing on three fusion techniques: Conv-BN fusion, parallel convolution fusion, and serial convolution fusion.
Standard Convolution and Depthwise Convolution (DWConv)
Standard convolution and depthwise convolution need little introduction; schematics are shown below.

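As a quick reference, the following minimal sketch contrasts the two in PyTorch: a depthwise convolution is simply nn.Conv2d with groups equal to the channel count, so each channel is filtered by its own kernel and the parameter count drops by a factor of the channel count (the layer sizes here are illustrative assumptions, not taken from the figure).

import torch
import torch.nn as nn

C = 32
# Standard convolution: every output channel sees all C input channels
std_conv = nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False)
# Depthwise convolution: groups=C, each channel convolved with its own 3x3 kernel
dw_conv = nn.Conv2d(C, C, kernel_size=3, padding=1, groups=C, bias=False)

x = torch.rand(1, C, 32, 32)
print(std_conv(x).shape, dw_conv(x).shape)            # same output shape
print(sum(p.numel() for p in std_conv.parameters()))  # C*C*3*3 = 9216
print(sum(p.numel() for p in dw_conv.parameters()))   # C*1*3*3 = 288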
Structural Re-parameterization
Structural re-parameterization is a technique that decouples training from inference: a complex structure is used during training and is converted into a simple one for inference, retaining the fitting power of the training-time structure while improving inference efficiency.
The classic work is RepVGG (CVPR 2021), which proposed the parallel re-parameterization structure shown in the schematic below.

During training, each block contains a $3\times3$ branch, a $1\times1$ branch, and an identity mapping, each branch carrying its own BN; per the paper's conclusion, introducing BN increases the model's fitting capacity during training.
At inference time, each serial BN can be merged into its preceding convolution, and the parallel branches can then be merged into a single standard $3\times3$ convolution.
For inference, structural re-parameterization is therefore a "free lunch": it improves model performance at no extra cost [Dr. Ding's answer on Zhihu].
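To make the train/inference decoupling concrete, here is a minimal sketch of a RepVGG-style block (an illustrative simplification, not the official repo code; RepVGGBlock and its attribute names are hypothetical, and the fusion that fills self.fused is exactly the sequence of steps derived in the sections below). It reuses the imports above.

class RepVGGBlock(nn.Module):
    # Minimal sketch: multi-branch at training time, a single conv at inference time
    def __init__(self, channels):
        super().__init__()
        self.conv3 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(channels)
        self.conv1 = nn.Conv2d(channels, channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn_id = nn.BatchNorm2d(channels)  # identity branch is BN only
        self.fused = None                      # filled in by re-parameterization

    def forward(self, x):
        if self.fused is not None:             # inference path: one 3x3 conv
            return self.fused(x)
        # training path: 3x3 + 1x1 + identity, each with its own BN
        return self.bn3(self.conv3(x)) + self.bn1(self.conv1(x)) + self.bn_id(x)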
Conv-BN Fusion
This fuses a convolution layer and the BN layer that follows it into a single convolution layer, reducing computation at inference time.
Fusion Principle
The BN layer computes:
$$y_{bn} = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$
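Substituting the convolution output $x = W * x_{in} + b$ into the BN formula (with $\mu$, $\sigma^2$ the running statistics and $\gamma$, $\beta$ the affine parameters, all per output channel) and collecting terms shows the result is again a convolution:

$$y_{bn} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\left(W * x_{in} + b - \mu\right) + \beta = \frac{\gamma W}{\sqrt{\sigma^2 + \epsilon}} * x_{in} + \frac{\gamma (b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta$$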
The fused convolution weight and bias are therefore:
$$W_{fused} = \gamma \cdot W / \sqrt{\sigma^2 + \epsilon}$$
$$b_{fused} = \gamma \cdot (b - \mu) / \sqrt{\sigma^2 + \epsilon} + \beta$$
Example code
import torch
import torch.nn as nn

def fuse_conv_bn(conv, bn):
    w = conv.weight.data
    b = conv.bias.data if conv.bias is not None else torch.zeros(w.size(0)).to(w.device)
    gamma, beta = bn.weight.data, bn.bias.data
    mean, var = bn.running_mean, bn.running_var
    std = (var + bn.eps).sqrt()
    # Fusion formulas derived above, applied per output channel
    w_fused = w * (gamma / std).reshape(-1, 1, 1, 1)
    b_fused = beta + (gamma / std) * (b - mean)
    return w_fused, b_fused
# Verify equivalence before and after fusion
conv = nn.Conv2d(3, 16, 3, padding=1, bias=True)
bn = nn.BatchNorm2d(16)
conv.eval()
bn.eval()  # eval mode: BN uses running statistics
x = torch.rand(1, 3, 32, 32)
# Direct forward pass
y = bn(conv(x))
# Fused forward pass
w_fused, b_fused = fuse_conv_bn(conv, bn)
conv_fused = nn.Conv2d(3, 16, 3, padding=1, bias=True)
conv_fused.weight.data, conv_fused.bias.data = w_fused, b_fused
y_fused = conv_fused(x)
print('Conv-BN fusion max error:', (y - y_fused).abs().max().item())  # near 0, on the order of 1e-7
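In practice one walks the model and replaces each Conv-BN pair in place. Below is a minimal sketch for nn.Sequential models (fuse_sequential is a hypothetical helper; it assumes every BN directly follows its convolution). PyTorch also ships ready-made utilities such as torch.ao.quantization.fuse_modules.

def fuse_sequential(seq):
    # Hypothetical helper: replace every (Conv2d, BatchNorm2d) pair with one fused Conv2d
    layers = list(seq.children())
    fused, i = [], 0
    while i < len(layers):
        if (i + 1 < len(layers) and isinstance(layers[i], nn.Conv2d)
                and isinstance(layers[i + 1], nn.BatchNorm2d)):
            conv, bn = layers[i], layers[i + 1]
            w, b = fuse_conv_bn(conv, bn)
            new_conv = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                                 conv.stride, conv.padding, conv.dilation, conv.groups, bias=True)
            new_conv.weight.data, new_conv.bias.data = w, b
            fused.append(new_conv)
            i += 2
        else:
            fused.append(layers[i])
            i += 1
    return nn.Sequential(*fused)

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1, bias=False), nn.BatchNorm2d(16), nn.ReLU()).eval()
print(fuse_sequential(model))  # Conv2d -> ReLU, the BN is gone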
Parallel Convolution Fusion
This fuses several parallel convolution branches into a single convolution.
Implementation Principle
Because convolution is linear in its kernel ($W_1 * x + W_2 * x = (W_1 + W_2) * x$), parallel kernels applied to the same input can be merged by addition:
$$W_{fused} = \sum_i W_i$$
$$b_{fused} = \sum_i b_i$$
These formulas assume the kernels share the same spatial size; smaller kernels (e.g. $1\times1$ or the identity) must first be zero-padded to the common size.
Example code (RepVGG structure: conv1 + conv3 + identity)
import torch.nn.functional as F

# Build the identity mapping as a 1x1 convolution kernel
def get_identity_kernel_groupwise(channels, groups):
    per_group = channels // groups
    # Supports grouped convolution: standard convolution within each group
    weight = torch.zeros((channels, per_group, 1, 1))
    for g in range(groups):
        for i in range(per_group):
            weight[g * per_group + i, i, 0, 0] = 1.0
    return weight

# Zero-pad a 1x1 kernel to 3x3
def pad_1x1_to_3x3_groupwise(w):
    return F.pad(w, [1, 1, 1, 1]) if w.shape[2] == 1 and w.shape[3] == 1 else w
C, groups = 8, 8  # channels, number of groups
conv1 = nn.Conv2d(C, C, 1, padding=0, groups=groups, bias=True)
conv3 = nn.Conv2d(C, C, 3, padding=1, groups=groups, bias=True)
x = torch.rand(1, C, 32, 32)
# Direct forward pass
y = conv1(x) + conv3(x) + x
# Fused forward pass
w3, b3 = conv3.weight.data, conv3.bias.data
w1, b1 = pad_1x1_to_3x3_groupwise(conv1.weight.data), conv1.bias.data
w_id = pad_1x1_to_3x3_groupwise(get_identity_kernel_groupwise(C, groups))
b_id = torch.zeros(C)  # the identity branch contributes no bias
w_fused = w3 + w1 + w_id
b_fused = b3 + b1 + b_id
conv_fused = nn.Conv2d(C, C, 3, padding=1, bias=True, groups=groups)
conv_fused.weight.data, conv_fused.bias.data = w_fused, b_fused
y_fused = conv_fused(x)
print('Parallel convolution fusion max error:', (y - y_fused).abs().max().item())  # near 0, on the order of 1e-7
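In the full RepVGG block every branch carries its own BN, so before the parallel merge each branch is first converted into an equivalent (weight, bias) pair using the Conv-BN fusion above; the identity branch, which is a bare BN, folds into the identity kernel the same way. A minimal sketch reusing the helpers above (assuming the identity branch is a plain BatchNorm2d in eval mode):

bn_id = nn.BatchNorm2d(C).eval()  # BN on the identity branch
std = (bn_id.running_var + bn_id.eps).sqrt()
# Scale the identity kernel per output channel, exactly as in fuse_conv_bn
w_id_bn = pad_1x1_to_3x3_groupwise(get_identity_kernel_groupwise(C, groups)) \
          * (bn_id.weight.data / std).reshape(-1, 1, 1, 1)
b_id_bn = bn_id.bias.data - bn_id.weight.data * bn_id.running_mean / std
# (w_id_bn, b_id_bn) can now be summed with the other branches just like (w_id, b_id)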
Serial Convolution Fusion
This fuses several consecutive convolution layers into one equivalent convolution layer.
Implementation Principle
Applying two convolutions in series is equivalent to a single convolution whose kernel is the two kernels convolved with each other:
$$W_{fused} = W_2 * W_1$$
$$b_{fused} = W_2 * b_1 + b_2$$
where $*$ denotes the convolution operation; $W_2 * b_1$ pushes the constant bias map $b_1$ through the second convolution (per output channel, a weighted sum over all of $W_2$'s input channels and taps).
Two caveats when convolving the kernels: tensor shapes must be matched carefully, and since PyTorch's conv2d actually computes cross-correlation, kernels larger than $1\times1$ must be spatially flipped (the .flip(2, 3) calls in the code below). Moreover, the zero padding of intermediate layers makes the fused convolution differ from the serial stack near image borders.
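Before the full three-layer grouped version, here is a minimal two-layer sketch of the kernel-convolution idea (a 1x1 conv followed by a 3x3 conv, no groups; the layer sizes are illustrative assumptions):

c1 = nn.Conv2d(4, 6, 1, bias=True)
c2 = nn.Conv2d(6, 5, 3, padding=1, bias=True)
# Kernel fusion: W12[o, i] = sum_m W2[o, m] * W1[m, i]  (a 1x1 kernel needs no flip)
w12 = F.conv2d(c2.weight.data, c1.weight.data.permute(1, 0, 2, 3))
b12 = (c2.weight.data * c1.bias.data.reshape(1, -1, 1, 1)).sum((1, 2, 3)) + c2.bias.data
cf = nn.Conv2d(4, 5, 3, padding=1, bias=True)
cf.weight.data, cf.bias.data = w12, b12
x = torch.rand(1, 4, 16, 16)
# Interior pixels match exactly; only the 1-pixel border differs because of padding
print((c2(c1(x)) - cf(x)).abs()[:, :, 1:-1, 1:-1].max().item())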
Example code (conv1 expands channels, conv2 extracts features with a 3×3 kernel, conv3 reduces channels)
def fuse_serial_3layer_groups(w1, b1, w2, b2, w3, b3, groups):
    out1_per_group = w1.shape[0] // groups  # conv1 output channels per group
    out2_per_group = w2.shape[0] // groups  # conv2 output channels per group
    out3_per_group = w3.shape[0] // groups  # conv3 output channels per group
    ws, bs = [], []
    # Supports grouped convolution: standard convolution within each group
    for g in range(groups):
        # Slice out this group's weights (along the output-channel dimension)
        w1g = w1[g * out1_per_group: (g + 1) * out1_per_group]
        b1g = b1[g * out1_per_group: (g + 1) * out1_per_group] if b1 is not None else None
        w2g = w2[g * out2_per_group: (g + 1) * out2_per_group]
        b2g = b2[g * out2_per_group: (g + 1) * out2_per_group] if b2 is not None else None
        w3g = w3[g * out3_per_group: (g + 1) * out3_per_group]
        b3g = b3[g * out3_per_group: (g + 1) * out3_per_group] if b3 is not None else None
        # Step 1: fuse conv1 and conv2
        # w1 (1x1) acts as the kernel, so no padding (and no flip) is needed
        w12 = F.conv2d(w2g, w1g.permute(1, 0, 2, 3), bias=None)
        b12 = ((w2g * b1g.reshape(1, -1, 1, 1)).sum((1, 2, 3)) if b1g is not None else 0) \
              + (b2g if b2g is not None else 0)
        # Step 2: fuse conv12 and conv3
        # w3 (1x1) acts as the kernel, so no padding is needed; the flips
        # convert cross-correlation into true convolution and back
        w123 = F.conv2d(w12.flip(2, 3).permute(1, 0, 2, 3), w3g, padding=0, stride=1).flip(2, 3).permute(1, 0, 2, 3)
        b123 = ((w3g * b12.reshape(1, -1, 1, 1)).sum((1, 2, 3)) if isinstance(b12, torch.Tensor) else 0) \
              + (b3g if b3g is not None else 0)
        ws.append(w123)
        bs.append(b123)
    return torch.cat(ws, dim=0), torch.cat(bs, dim=0)
C, groups, gain = 8, 1, 3
conv1 = nn.Conv2d(C, gain * C, 1, padding=0, groups=groups, bias=True)
conv2 = nn.Conv2d(gain * C, gain * C, 3, padding=1, groups=groups, bias=True)  # feature extraction in the expanded space
conv3 = nn.Conv2d(gain * C, C, 1, padding=0, groups=groups, bias=True)
x = torch.rand(1, C, 32, 32)
# Direct forward pass
y = conv3(conv2(conv1(x)))
# Fused forward pass
w1, b1 = conv1.weight.data, conv1.bias.data
w2, b2 = conv2.weight.data, conv2.bias.data
w3, b3 = conv3.weight.data, conv3.bias.data
w_fused, b_fused = fuse_serial_3layer_groups(w1, b1, w2, b2, w3, b3, groups)
conv_fused = nn.Conv2d(C, C, 3, padding=1, bias=True, groups=groups)
conv_fused.weight.data, conv_fused.bias.data = w_fused, b_fused
y_fused = conv_fused(x)
print('Serial convolution fusion max error:', (y - y_fused).abs()[:, :, 4:-4, 4:-4].max().item())  # exclude borders, where intermediate padding causes a mismatch