Lightweight Design: Depthwise Convolution and Structural Re-parameterization



In deep learning, convolutional neural networks (CNNs) have long been the dominant architecture for computer vision tasks. In recent years, as efficiency requirements have risen, a steady stream of convolution variants and optimization techniques has emerged. This article takes a close look at standard convolution, depthwise convolution, and structural re-parameterization, focusing on three fusion techniques: conv-BN fusion, parallel convolution fusion, and serial convolution fusion.


Standard Convolution and Depthwise Convolution (DWConv)

Standard convolution and depthwise convolution need little introduction. In a depthwise convolution (DWConv), each input channel is convolved with its own single-channel kernel (i.e. the number of groups equals the number of channels), so parameters and FLOPs drop by roughly a factor of the channel count compared with a standard convolution.
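
As a quick illustration, here is a minimal sketch in PyTorch (the channel count of 64 is an arbitrary choice): a depthwise convolution is just an nn.Conv2d whose groups argument equals the channel count.

import torch.nn as nn

C = 64
conv_std = nn.Conv2d(C, C, 3, padding=1)             # standard 3x3 convolution
conv_dw = nn.Conv2d(C, C, 3, padding=1, groups=C)    # depthwise 3x3 convolution

n_std = sum(p.numel() for p in conv_std.parameters())   # C*C*3*3 + C = 36928
n_dw = sum(p.numel() for p in conv_dw.parameters())     # C*1*3*3 + C = 640
print(n_std, n_dw)   # parameters shrink by roughly a factor of C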

Structural Re-parameterization

Structural re-parameterization is a training-inference decoupling technique: a complex structure is used during training and converted into a simple, mathematically equivalent one for inference, retaining the training-time accuracy gains while keeping inference efficient.

The classic example is the parallel re-parameterization structure proposed in RepVGG (CVPR 2021).

During training, each block contains a $3\times3$ branch, a $1\times1$ branch, and an identity branch, each followed by BN; according to the paper's findings, the BN layers increase the model's fitting capacity during training.

At test time, each serial conv-BN pair can be merged into a single convolution, and the parallel branches can then be merged into one standard $3\times3$ convolution.

For inference, structural re-parameterization is a "free lunch": the accuracy gained from the training-time structure comes at zero inference cost [Dr. Ding's Zhihu answer].
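
To make the training-time structure concrete, below is a minimal sketch of a RepVGG-style block (the class name RepBlock and the stride-1, equal-channel setting are illustrative assumptions; as in the paper, each branch carries its own BN):

import torch.nn as nn

class RepBlock(nn.Module):
    # Training-time block: 3x3 conv-BN + 1x1 conv-BN + identity BN, summed
    def __init__(self, channels):
        super().__init__()
        self.conv3 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(channels)
        self.conv1 = nn.Conv2d(channels, channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn_id = nn.BatchNorm2d(channels)   # the identity branch is BN only

    def forward(self, x):
        return self.bn3(self.conv3(x)) + self.bn1(self.conv1(x)) + self.bn_id(x)

At deployment time, each branch's conv-BN pair is first collapsed via the conv-BN fusion below, and the resulting kernels are then merged via parallel fusion; both steps are worked through in the remaining sections.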

Conv-BN Fusion

Fuse a convolution layer and the BN layer that follows it into a single convolution, cutting inference-time computation.

Fusion principle

A BN layer computes:

$$y_{bn} = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

Substituting the convolution output $Wx + b$ for $x$ and regrouping terms gives the fused weight and bias:
$$W_{fused} = \gamma \cdot W / \sqrt{\sigma^2 + \epsilon}$$
$$b_{fused} = \gamma \cdot (b - \mu) / \sqrt{\sigma^2 + \epsilon} + \beta$$

Example code
import torch
import torch.nn as nn

def fuse_conv_bn(conv, bn):
    # Convolution weight; substitute a zero bias if the conv has none
    w = conv.weight
    b = conv.bias if conv.bias is not None else torch.zeros(w.size(0), device=w.device)

    # BN affine parameters and running statistics
    gamma, beta = bn.weight, bn.bias
    mean, var = bn.running_mean, bn.running_var
    std = (var + bn.eps).sqrt()

    # Fusion formulas: W' = (gamma/std) * W, b' = beta + (gamma/std) * (b - mean)
    w_fused = w * (gamma / std).reshape(-1, 1, 1, 1)
    b_fused = beta + (gamma / std) * (b - mean)

    return w_fused, b_fused
  
# Verify that the fused conv matches conv + BN
conv = nn.Conv2d(3, 16, 3, padding=1, bias=True)
bn = nn.BatchNorm2d(16)
conv.eval()
bn.eval()   # eval mode: BN uses its running statistics
x = torch.rand(1, 3, 32, 32)

# Direct forward pass
y = bn(conv(x))

# Fused forward pass
w_fused, b_fused = fuse_conv_bn(conv, bn)
conv_fused = nn.Conv2d(3, 16, 3, padding=1, bias=True)
conv_fused.weight.data, conv_fused.bias.data = w_fused, b_fused
y_fused = conv_fused(x)

print('Conv-BN fusion max error:', (y - y_fused).abs().max().item())    # ~1e-7, numerically zero

Parallel Convolution Fusion

Fuse several parallel convolution branches into a single convolution.

Implementation principle

Because convolution is linear, parallel kernels applied to the same input can simply be summed:
$$W_{fused} = \sum_i W_i$$
$$b_{fused} = \sum_i b_i$$

The formulas assume all kernels have the same spatial size; smaller kernels (e.g. $1\times1$, or the identity) must first be zero-padded to the common size.

Example code (RepVGG-style structure: conv1 + conv3 + identity)
import torch
import torch.nn as nn
import torch.nn.functional as F

# Build the identity mapping as a 1x1 kernel
# (supports grouped convolution: identity within each group)
def get_identity_kernel_groupwise(channels, groups):
    per_group = channels // groups
    weight = torch.zeros((channels, per_group, 1, 1))
    for g in range(groups):
        for i in range(per_group):
            # Output channel g*per_group+i passes through input channel i of its group
            weight[g*per_group+i, i, 0, 0] = 1.0
    return weight

# Zero-pad a 1x1 kernel to 3x3
def pad_1x1_to_3x3_groupwise(w):
    return F.pad(w, [1, 1, 1, 1]) if w.shape[2] == 1 and w.shape[3] == 1 else w

C, groups = 8, 8   # channel count, group count (groups == C, i.e. depthwise)
conv1 = nn.Conv2d(C, C, 1, padding=0, groups=groups, bias=True)
conv3 = nn.Conv2d(C, C, 3, padding=1, groups=groups, bias=True)
x = torch.rand(1, C, 32, 32)

# Direct forward pass
y = conv1(x) + conv3(x) + x

# Fused forward pass: pad every branch to 3x3, then sum weights and biases
w3, b3 = conv3.weight.data, conv3.bias.data
w1, b1 = pad_1x1_to_3x3_groupwise(conv1.weight.data), conv1.bias.data
w_id = pad_1x1_to_3x3_groupwise(get_identity_kernel_groupwise(C, groups))
b_id = torch.zeros(C)   # the identity branch contributes no bias
w_fused = w3 + w1 + w_id
b_fused = b3 + b1 + b_id
conv_fused = nn.Conv2d(C, C, 3, padding=1, bias=True, groups=groups)
conv_fused.weight.data, conv_fused.bias.data = w_fused, b_fused
y_fused = conv_fused(x)
print('Parallel fusion max error:', (y - y_fused).abs().max().item())    # ~1e-7, numerically zero

Serial Convolution Fusion

Fuse several consecutive convolution layers into one equivalent convolution.

Implementation principle

Applying two convolutions in sequence is equivalent to a single convolution whose kernel is the convolution of the two kernels:
$$W_{fused} = W_2 * W_1$$
$$b_{fused} = W_2 * b_1 + b_2$$

where $*$ denotes convolution; $W_2 * b_1$ means pushing the constant bias map $b_1$ through the second convolution.

Mind the spatial sizes when convolving the kernels. Also note that the original stack zero-pads the intermediate feature map while the fused convolution zero-pads the input, so when the inner layers carry biases the two outputs differ near the border.
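
Before the grouped three-layer version below, a minimal bias-free two-layer sketch (an assumed 1x1-then-3x3 stack, groups = 1) shows the weight formula directly:

import torch
import torch.nn as nn
import torch.nn.functional as F

c_in, c_mid, c_out = 4, 8, 4
convA = nn.Conv2d(c_in, c_mid, 1, bias=False)               # W1 (1x1)
convB = nn.Conv2d(c_mid, c_out, 3, padding=1, bias=False)   # W2 (3x3)

# W_fused = W2 * W1: treat W2 as the "input" and the permuted 1x1 W1 as the kernel
w_fused = F.conv2d(convB.weight, convA.weight.permute(1, 0, 2, 3))   # (c_out, c_in, 3, 3)

x = torch.rand(1, c_in, 16, 16)
y = convB(convA(x))
y_fused = F.conv2d(x, w_fused, padding=1)
print((y - y_fused).abs().max().item())   # ~1e-7, numerically zero

Without biases the match is exact even at the borders; once b1 is non-zero, the original stack zero-pads an intermediate map that the fused convolution cannot reproduce exactly, which is why the full example below crops the border when checking equality.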

Example code (conv1 channel expansion - conv3 feature extraction - conv1 channel reduction)
def fuse_serial_3layer_groups(w1, b1, w2, b2, w3, b3, groups):
    # Output channels per group of each layer (dim 0 of a weight is its out-channel dim)
    c1_per_group = w1.shape[0] // groups
    c2_per_group = w2.shape[0] // groups
    c3_per_group = w3.shape[0] // groups

    ws, bs = [], []
    # Supports grouped convolution: within each group it is a standard convolution
    for g in range(groups):
        # Slice out this group's weights and biases
        w1g = w1[g*c1_per_group: (g+1)*c1_per_group]
        b1g = b1[g*c1_per_group: (g+1)*c1_per_group] if b1 is not None else None
        w2g = w2[g*c2_per_group: (g+1)*c2_per_group]
        b2g = b2[g*c2_per_group: (g+1)*c2_per_group] if b2 is not None else None
        w3g = w3[g*c3_per_group: (g+1)*c3_per_group]
        b3g = b3[g*c3_per_group: (g+1)*c3_per_group] if b3 is not None else None

        # Step 1: fuse conv1 and conv2
        # using the 1x1 w1 as the kernel needs no padding
        w12 = F.conv2d(w2g, w1g.permute(1, 0, 2, 3), bias=None)
        b12 = ((w2g * b1g.reshape(1, -1, 1, 1)).sum((1, 2, 3)) if b1g is not None else 0) + (b2g if b2g is not None else 0)

        # Step 2: fuse conv12 and conv3
        # using the 1x1 w3 as the kernel needs no padding; the flips turn F.conv2d's
        # cross-correlation into true convolution (a no-op here since w3 is 1x1)
        w123 = F.conv2d(w12.flip(2, 3).permute(1, 0, 2, 3), w3g, padding=0, stride=1).flip(2, 3).permute(1, 0, 2, 3)
        b123 = ((w3g * b12.reshape(1, -1, 1, 1)).sum((1, 2, 3)) if isinstance(b12, torch.Tensor) else 0) + (b3g if b3g is not None else 0)

        ws.append(w123)
        bs.append(b123)

    return torch.cat(ws, dim=0), torch.cat(bs, dim=0)
  
C, groups, gain = 8, 1, 3   # channels, groups, expansion factor
conv1 = nn.Conv2d(C, gain*C, 1, padding=0, groups=groups, bias=True)
conv2 = nn.Conv2d(gain*C, gain*C, 3, padding=1, groups=groups, bias=True)   # feature extraction in the expanded dimension
conv3 = nn.Conv2d(gain*C, C, 1, padding=0, groups=groups, bias=True)
x = torch.rand(1, C, 32, 32)

# Direct forward pass
y = conv3(conv2(conv1(x)))

# Fused forward pass
w1, b1 = conv1.weight.data, conv1.bias.data
w2, b2 = conv2.weight.data, conv2.bias.data
w3, b3 = conv3.weight.data, conv3.bias.data
w_fused, b_fused = fuse_serial_3layer_groups(w1, b1, w2, b2, w3, b3, groups)
conv_fused = nn.Conv2d(C, C, 3, padding=1, bias=True, groups=groups)
conv_fused.weight.data, conv_fused.bias.data = w_fused, b_fused
y_fused = conv_fused(x)
print('Serial fusion max error:', (y - y_fused).abs()[:, :, 4:-4, 4:-4].max().item())    # crop the border to sidestep padding effects
