Conv2d 与 BatchNorm2d的融合计算
May 14, 2020, 10:12 a.m.
read: 2654
感谢知乎大佬的文章,给了我深刻的启发
https://zhuanlan.zhihu.com/p/49329030
原理正如上面的文章所示,上面的文章没有给出Conv和BatchNorm的维度信息,以及一些细节信息。下面给出一些实现细节,适用于BatchNorm2d(Conv2d(x))这种情况,在部署阶段可以节省计算量。只适合部署阶段
参数导出
Conv2d有参数
weight 经过permute到这个形状 因为我一般用这个维度排列 (kernelSize x kernelSize x inputCh x outputCh)
bias (outputCh)
Norm2d有参数
running_var (outputCh)
running_mean (outputCh)
bn_weight (outputCh) 加上bn前置跟上面的CNN参数表示区分
bn_bias (outputCh)
融合
weight = bn_weight / np.sqrt(var + 1e-5) * weight
实际上 bn_weight的维度和weight的维度不一致,因此需要把bn_weight扩展出前三个为1的维度。上式就能计算下去。
bias = (bias - mean) / np.sqrt(var + 1e-5) * bn_weight + bn_bias
以新的weight和bias加入作为CNN的参数即可。