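Residual Net Paper Notes
Since its introduction in 2015, the residual network (Residual Net, ResNet) has won first place in ILSVRC by a wide margin thanks to its excellent performance, and it has since been applied successfully in many fields.
1. Problems with Traditional Deep Networks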
$$\mathbf{x}_{25}=f_{1}(\mathbf{x}),\qquad \mathbf{out}=f_{2}(\mathbf{x}_{25})$$
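2. Residual Structure and Residual Networks
2.1 What Is a Residual
Suppose some k-layer network $F$ is currently the best network. We can then construct a deeper network whose extra final layers are just identity mappings of the k-th layer output of $F$; this deeper network achieves exactly the same result as $F$.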
$$H(\mathbf{x})=F(\mathbf{x})+\mathbf{x},\qquad F(\mathbf{x})=H(\mathbf{x})-\mathbf{x}$$
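If the shortcut carries $\mathbf{x}$ as an identity mapping, the layers only need to learn the residual $F(\mathbf{x})$ as the non-identity part.
Here, the residual is the difference between the direct mapping $H(\mathbf{x})$ and the shortcut connection $\mathbf{x}$, i.e. $F(\mathbf{x})$.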
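To make the argument above concrete, here is a minimal plain-Python sketch (no framework), assuming a toy linear layer acting on a list of floats; `linear`, `plain_layer` and `residual_layer` are hypothetical names. With all weights at zero, the plain layer outputs zeros, while the residual layer returns its input unchanged: an identity mapping is obtained simply by driving the residual $F(\mathbf{x})$ to zero.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear(x: Vector, W: Matrix) -> Vector:
    """Toy layer: matrix-vector product W @ x."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]


def plain_layer(x: Vector, W: Matrix) -> Vector:
    """Plain layer: the full mapping H(x) must be learned directly."""
    return linear(x, W)


def residual_layer(x: Vector, W: Matrix) -> Vector:
    """Residual layer: H(x) = F(x) + x, so only F(x) = H(x) - x is learned."""
    return [f + xi for f, xi in zip(linear(x, W), x)]


x = [1.0, -2.0, 3.0]
W_zero = [[0.0] * 3 for _ in range(3)]
print(plain_layer(x, W_zero))     # [0.0, 0.0, 0.0]
print(residual_layer(x, W_zero))  # [1.0, -2.0, 3.0]
```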
2.2 Residual Block
Based on this, we design the structure of a residual block as follows:
$$\mathbf{y}_{l}=h(\mathbf{x}_{l})+\mathcal{F}(\mathbf{x}_{l},\mathcal{W}_{l}),\qquad \mathbf{x}_{l+1}=f(\mathbf{y}_{l})$$
where $h(\mathbf{x}_{l})=\mathbf{x}_{l}$ is an identity mapping and $f=\mathrm{ReLU}$. If $f$ is also taken to be an identity mapping, then $\mathbf{x}_{l+1}\equiv\mathbf{y}_{l}$ and the block simplifies to:
$$\mathbf{x}_{l+1}=\mathbf{x}_{l}+\mathcal{F}(\mathbf{x}_{l},\mathcal{W}_{l})$$
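The general form $\mathbf{y}_l=h(\mathbf{x}_l)+\mathcal{F}$, $\mathbf{x}_{l+1}=f(\mathbf{y}_l)$ can be written as a small function that takes $h$, $f$ and the residual branch as arguments. This is a minimal plain-Python sketch on lists of floats. `residual_block` and the `half` residual branch are hypothetical, chosen only to show how the choice of $f$ changes the output. With $f=\mathrm{ReLU}$, negative entries are clipped. With $f$ as the identity, the block reduces to $\mathbf{x}_{l+1}=\mathbf{x}_l+\mathcal{F}(\mathbf{x}_l,\mathcal{W}_l)$.

```python
from typing import Callable, List

Vector = List[float]


def identity(v: Vector) -> Vector:
    return list(v)


def relu(v: Vector) -> Vector:
    return [max(0.0, a) for a in v]


def residual_block(x: Vector,
                   residual_fn: Callable[[Vector], Vector],
                   h: Callable[[Vector], Vector] = identity,
                   f: Callable[[Vector], Vector] = relu) -> Vector:
    """y_l = h(x_l) + F(x_l, W_l);  x_{l+1} = f(y_l)."""
    y = [a + b for a, b in zip(h(x), residual_fn(x))]
    return f(y)


def half(v: Vector) -> Vector:
    """Hypothetical residual branch: scale every element by 0.5."""
    return [0.5 * a for a in v]


x = [2.0, -4.0]
print(residual_block(x, half))              # f = ReLU:     [3.0, 0.0]
print(residual_block(x, half, f=identity))  # f = identity: [3.0, -6.0], i.e. x + F(x)
```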
2.3 Basic Modules: BasicBlock and Bottleneck
BasicBlock(
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
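For a rough sense of why the Bottleneck exists, the sketch below counts the learnable parameters of the two blocks by hand in plain Python. It makes three assumptions:

- The BasicBlock has 64 input and 64 output channels.
- The Bottleneck has planes=64 and expansion=4, and its input already has 256 channels, so neither block has a downsample branch.
- BatchNorm running statistics are excluded, because they are buffers rather than parameters.

The helper names are hypothetical.

```python
def conv_params(in_ch: int, out_ch: int, k: int) -> int:
    """Weights of a bias-free k x k convolution with groups=1."""
    return in_ch * out_ch * k * k


def bn_params(ch: int) -> int:
    """Learnable affine parameters (weight and bias) of a BatchNorm2d layer."""
    return 2 * ch


def basic_block_params(planes: int) -> int:
    # conv3x3 -> BN -> ReLU -> conv3x3 -> BN
    return 2 * conv_params(planes, planes, 3) + 2 * bn_params(planes)


def bottleneck_params(planes: int, expansion: int = 4) -> int:
    # conv1x1 (reduce) -> conv3x3 -> conv1x1 (expand), each followed by BN
    out_ch = planes * expansion
    return (conv_params(out_ch, planes, 1) + bn_params(planes)
            + conv_params(planes, planes, 3) + bn_params(planes)
            + conv_params(planes, out_ch, 1) + bn_params(out_ch))


print(basic_block_params(64))  # 73984  (64 -> 64 channels)
print(bottleneck_params(64))   # 70400  (256 -> 256 channels)
```

Under these assumptions the Bottleneck handles four times as many input and output channels with slightly fewer parameters. The reason is that its expensive 3×3 convolution runs on the reduced 64-channel width.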
2.4 ResNet Design
2.4.1 Connecting the Identity Mapping and the Residual
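When the identity mapping $\mathbf{x}$ and the residual $F(\mathbf{x})$ have different dimensions, they can be matched in two ways:
- Pad the identity mapping $\mathbf{x}$ with zeros to expand its dimensions
- Apply a 1×1 convolution on the shortcut to downsample it and match the dimensions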
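The first option can be sketched without any framework. This minimal, parameter-free shortcut operates on a feature map stored as nested lists `[channels][height][width]`. It assumes that spatial downsampling keeps every `stride`-th row and column; that is one common way to implement it, not necessarily the paper's exact code. `zero_pad_shortcut` is a hypothetical name.

```python
from typing import List

FeatureMap = List[List[List[float]]]  # [channels][height][width]


def zero_pad_shortcut(x: FeatureMap, out_channels: int, stride: int = 2) -> FeatureMap:
    """Parameter-free shortcut: subsample spatially by `stride`,
    then append all-zero channels until there are `out_channels`."""
    sub = [[row[::stride] for row in ch[::stride]] for ch in x]
    h, w = len(sub[0]), len(sub[0][0])
    extra = out_channels - len(x)
    zeros = [[[0.0] * w for _ in range(h)] for _ in range(extra)]
    return sub + zeros


# 2 channels of 4x4 -> 4 channels of 2x2
x = [[[float(c * 16 + i * 4 + j) for j in range(4)] for i in range(4)] for c in range(2)]
out = zero_pad_shortcut(x, out_channels=4)
print(len(out), len(out[0]), len(out[0][0]))  # 4 2 2
print(out[0])  # [[0.0, 2.0], [8.0, 10.0]]
print(out[3])  # [[0.0, 0.0], [0.0, 0.0]]
```

The second option corresponds to the `downsample` branch (`conv1x1` followed by a norm layer) that `_make_layer` builds in the torchvision code analyzed below.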
3. Forward/Backward Propagation
3.1 Forward Propagation
For a plain (non-residual) network, simplify each layer to the toy mapping
$$F(x,w)=xw$$
Stacking layers without shortcuts gives
$$x_{L}=F(x_{L-1},w_{L-1})=F(F(x_{L-2},w_{L-2}),w_{L-1})=\cdots=x_{1}\prod_{i=1}^{L-1}w_{i}$$
For a residual network, each layer adds its input back to its output:
$$x_{2}=x_{1}+F(x_{1},w_{1})$$
$$x_{3}=x_{2}+F(x_{2},w_{2})=x_{1}+F(x_{1},w_{1})+F(x_{2},w_{2})$$
$$\cdots$$
$$x_{L}=x_{1}+\sum_{i=1}^{L-1}F(x_{i},w_{i})$$
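Both unrolled forms can be checked numerically with the toy layer $F(x,w)=xw$. This is a minimal sketch, assuming ten layers that all use the hypothetical weight $w=0.1$. The unrolled plain network equals $x_1\prod_i w_i$, and the residual network's output equals $x_1$ plus the sum of all residual terms.

```python
import math


def F(x: float, w: float) -> float:
    """Toy layer from the notes: F(x, w) = x * w."""
    return x * w


def plain_forward(x1: float, ws: list) -> float:
    """x_{i+1} = F(x_i, w_i): no shortcut."""
    x = x1
    for w in ws:
        x = F(x, w)
    return x


def residual_forward(x1: float, ws: list):
    """x_{i+1} = x_i + F(x_i, w_i); also return every residual term F(x_i, w_i)."""
    x, residuals = x1, []
    for w in ws:
        r = F(x, w)
        residuals.append(r)
        x = x + r
    return x, residuals


ws = [0.1] * 10  # hypothetical weights w_1, ..., w_{L-1} (L = 11)
x1 = 1.0
xL_plain = plain_forward(x1, ws)
xL_res, residuals = residual_forward(x1, ws)
print(xL_plain, x1 * math.prod(ws))  # same value, about 1e-10
print(xL_res, x1 + sum(residuals))   # equal up to rounding, about 2.594
```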
3.2 Backpropagation
Let the shallow network be $g(x)$; after adding layers it becomes $f(g(x))$. By the chain rule,
$$\frac{\partial f(g(x))}{\partial x}=\frac{\partial f(g(x))}{\partial g(x)}\frac{\partial g(x)}{\partial x}$$
With a residual connection the output becomes $f(g(x))+g(x)$, and the gradient gains an extra term that does not pass through $f$:
$$\frac{\partial\big(f(g(x))+g(x)\big)}{\partial x}=\frac{\partial f(g(x))}{\partial g(x)}\frac{\partial g(x)}{\partial x}+\frac{\partial g(x)}{\partial x}$$
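The extra $\partial g/\partial x$ term can be checked numerically. This minimal sketch uses the hypothetical scalar functions $g(x)=x^2$ and $f(u)=0.01u^3$, chosen so that $f$ alone has a small derivative. It compares the analytic chain-rule values against central finite differences.

```python
def g(x: float) -> float:   # hypothetical "shallow network"
    return x ** 2


def dg(x: float) -> float:
    return 2 * x


def f(u: float) -> float:   # hypothetical "added layers"
    return 0.01 * u ** 3


def df(u: float) -> float:
    return 0.03 * u ** 2


def numeric_grad(fn, x: float, eps: float = 1e-6) -> float:
    """Central finite difference approximation of fn'(x)."""
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)


x = 1.5
plain = df(g(x)) * dg(x)             # chain rule through f only
residual = df(g(x)) * dg(x) + dg(x)  # plus the dg/dx term from the shortcut
print(plain, numeric_grad(lambda t: f(g(t)), x))           # about 0.4556 twice
print(residual, numeric_grad(lambda t: f(g(t)) + g(t), x))  # about 3.4556 twice
```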
Now take the gradient of the loss with respect to the output $x_l$ of layer $l$. For the plain network, $x_L=x_l\prod_{i=l}^{L-1}w_i$, so
$$\frac{\partial Loss}{\partial x_{l}}=\frac{\partial Loss}{\partial x_{L}}\frac{\partial x_{L}}{\partial x_{l}}=\frac{\partial Loss}{\partial x_{L}}\prod_{i=l}^{L-1}w_{i}$$
For the residual network, $x_L=x_l+\sum_{i=l}^{L-1}F(x_i,w_i)$, so
$$\frac{\partial Loss}{\partial x_{l}}=\frac{\partial Loss}{\partial x_{L}}\frac{\partial x_{L}}{\partial x_{l}}=\frac{\partial Loss}{\partial x_{L}}\left(1+\frac{\partial}{\partial x_{l}}\sum_{i=l}^{L-1}F(x_{i},w_{i})\right)$$
In the plain network the gradient is a product of many factors and can easily vanish or explode. In the residual network, the constant term 1 passes $\partial Loss/\partial x_L$ directly back to $x_l$, so the gradient is unlikely to vanish even when the weights are small.
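The two gradient expressions can be compared numerically on the toy scalar layer $F(x,w)=xw$. This minimal sketch assumes ten blocks that all use the hypothetical weight $w=0.1$. For the plain network, $\partial x_L/\partial x_l$ is the product of the weights. For the residual network, each block contributes a factor $1+w_i$, and the expanded product equals $1+\partial\big(\sum_i F(x_i,w_i)\big)/\partial x_l$.

```python
def plain_forward(x: float, ws) -> float:
    for w in ws:
        x = x * w        # x_{i+1} = F(x_i, w_i) = x_i * w_i
    return x


def residual_forward(x: float, ws) -> float:
    for w in ws:
        x = x + x * w    # x_{i+1} = x_i + F(x_i, w_i)
    return x


def numeric_grad(fn, x: float, eps: float = 1e-6) -> float:
    """Central finite difference approximation of fn'(x)."""
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)


ws = [0.1] * 10  # hypothetical weights w_l, ..., w_{L-1}
x_l = 1.0

g_plain = numeric_grad(lambda t: plain_forward(t, ws), x_l)
g_res = numeric_grad(lambda t: residual_forward(t, ws), x_l)
print(g_plain)  # about 1e-10: prod(w_i), shrinks further as depth grows
print(g_res)    # about 2.594: prod(1 + w_i) = 1 + d(sum F)/dx_l
```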
4. Code Analysis
import torch
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
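The `conv3x3` helper above exposes two ways of controlling how the kernel samples the input:
- The first uses stride: for downsampling, set stride=2, so the kernel moves two pixels at each step and the output becomes smaller.
- The second uses dilation: with dilation=2, the kernel samples every other pixel, i.e. its taps are spaced one pixel apart, which enlarges the receptive field. https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md has vivid visual demonstrations of both.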
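The effect of stride and dilation on the output size can be checked with the output-size formula from the PyTorch `nn.Conv2d` documentation. This minimal sketch assumes a 56×56 input, which is the spatial size entering `layer1` for a 224×224 image. `conv_out_size` and `effective_kernel` are hypothetical helper names. Because `conv3x3` sets `padding=dilation`, a stride-1 3×3 convolution keeps the size for any dilation, while stride=2 halves it.

```python
def conv_out_size(size: int, kernel_size: int, stride: int = 1,
                  padding: int = 0, dilation: int = 1) -> int:
    """Output length of one spatial dimension (formula from the nn.Conv2d docs)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def effective_kernel(kernel_size: int, dilation: int) -> int:
    """Span covered by a dilated kernel whose taps are `dilation` pixels apart."""
    return dilation * (kernel_size - 1) + 1


print(conv_out_size(56, 3, stride=1, padding=1, dilation=1))  # 56: size kept
print(conv_out_size(56, 3, stride=2, padding=1, dilation=1))  # 28: stride halves it
print(conv_out_size(56, 3, stride=1, padding=2, dilation=2))  # 56: dilation keeps it
print(effective_kernel(3, 2))                                 # 5: 3x3 kernel spans 5x5
```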
Basic Block
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        # residual branch F(x): conv -> BN -> ReLU -> conv -> BN
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # shortcut: project x when its shape differs from the branch output
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
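The `width` line in `Bottleneck.__init__` lets the same class serve wider and grouped variants. This small sketch evaluates that expression for a few settings, and `bottleneck_width` is a hypothetical helper. My understanding is that torchvision's `resnext50_32x4d` passes `groups=32, width_per_group=4` and `wide_resnet50_2` passes `width_per_group=128`. Treat those mappings as an assumption. In every case, the block's output still has `planes * expansion` channels.

```python
def bottleneck_width(planes: int, base_width: int = 64, groups: int = 1) -> int:
    """Same expression as Bottleneck.__init__: channel width of the 3x3 conv."""
    return int(planes * (base_width / 64.)) * groups


print(bottleneck_width(64))                           # 64: standard bottleneck
print(bottleneck_width(64, base_width=128))           # 128: wider 3x3 conv
print(bottleneck_width(64, base_width=4, groups=32))  # 128: 32 groups of 4 channels
```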
class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)
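To see what `_make_layer` does to shapes, the following plain-Python sketch replays its bookkeeping (without dilation) for a 224×224 input. It assumes the usual stage configurations: `[2, 2, 2, 2]` with `BasicBlock` for ResNet-18 and `[3, 4, 6, 3]` with `Bottleneck` for ResNet-50. `trace_resnet` is a hypothetical helper, not part of torchvision. Each row gives five values:

- the stage
- its number of blocks
- the output channels
- the output spatial size
- whether the first block gets a `downsample` shortcut

```python
def trace_resnet(layers, expansion, image_size=224):
    """Replay the shape bookkeeping of ResNet.__init__ / _make_layer (no dilation).

    Returns one tuple per stage:
    (stage, num_blocks, out_channels, spatial_size, first_block_has_downsample).
    """
    size = (image_size + 2 * 3 - 7) // 2 + 1  # conv1: 7x7, stride 2, padding 3
    size = (size + 2 * 1 - 3) // 2 + 1        # maxpool: 3x3, stride 2, padding 1
    inplanes = 64
    rows = []
    stage_planes = (64, 128, 256, 512)
    for stage, (planes, blocks) in enumerate(zip(stage_planes, layers), start=1):
        stride = 1 if stage == 1 else 2       # layer1 uses stride 1, layer2-4 use 2
        has_downsample = stride != 1 or inplanes != planes * expansion
        size = (size - 1) // stride + 1       # 3x3 conv with padding 1 (or 1x1 shortcut)
        inplanes = planes * expansion         # what _make_layer stores in self.inplanes
        rows.append((stage, blocks, inplanes, size, has_downsample))
    return rows


for row in trace_resnet([2, 2, 2, 2], expansion=1):  # ResNet-18 style
    print(row)
for row in trace_resnet([3, 4, 6, 3], expansion=4):  # ResNet-50 style
    print(row)
```

ResNet-50's `layer1` needs a downsample branch even with stride 1, because `self.inplanes` (64) differs from `planes * block.expansion` (256). The final channel count also explains `nn.Linear(512 * block.expansion, num_classes)`.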
5. Identity Mapping
Consider a simple modification that scales the shortcut by $\lambda_l$:
$$\mathbf{x}_{l+1}=\lambda_{l}\mathbf{x}_{l}+\mathcal{F}(\mathbf{x}_{l},\mathcal{W}_{l})$$
Applying it recursively gives
$$\mathbf{x}_{L}=\Big(\prod_{i=l}^{L-1}\lambda_{i}\Big)\mathbf{x}_{l}+\sum_{i=l}^{L-1}\hat{\mathcal{F}}(\mathbf{x}_{i},\mathcal{W}_{i})$$
where $\hat{\mathcal{F}}$ absorbs the scalars $\lambda$ into the residual functions. Hence
$$\frac{\partial Loss}{\partial \mathbf{x}_{l}}=\frac{\partial Loss}{\partial \mathbf{x}_{L}}\frac{\partial \mathbf{x}_{L}}{\partial \mathbf{x}_{l}}=\frac{\partial Loss}{\partial \mathbf{x}_{L}}\Big(\prod_{i=l}^{L-1}\lambda_{i}+\frac{\partial}{\partial \mathbf{x}_{l}}\sum_{i=l}^{L-1}\hat{\mathcal{F}}(\mathbf{x}_{i},\mathcal{W}_{i})\Big)$$
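As the expression shows, when the $\lambda_i$ are greater than 1, the product grows exponentially with depth and the gradients explode. When they are smaller than 1, the product shrinks exponentially and the gradients vanish.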
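The exponential effect of the shortcut scale can be seen numerically. This minimal sketch assumes the same constant $\lambda$ in each of 100 blocks, a hypothetical depth. It evaluates only the $\prod\lambda_i$ term of the gradient, not the residual term.

```python
def shortcut_gain(lam: float, num_blocks: int) -> float:
    """prod_{i=l}^{L-1} lambda_i for a constant lambda over num_blocks blocks."""
    gain = 1.0
    for _ in range(num_blocks):
        gain *= lam
    return gain


for lam in (0.9, 1.0, 1.1):
    print(lam, shortcut_gain(lam, 100))
# 0.9 -> about 2.7e-5 (vanishing), 1.0 -> 1.0, 1.1 -> about 1.4e4 (exploding)
```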
6. Analysis of Shortcut Connections
- Original: $\mathbf{x}_{l+1}=\mathbf{x}_{l}+\mathcal{F}(\mathbf{x}_{l},\mathcal{W}_{l})$
- Constant scaling: $\mathbf{x}_{l+1}=\lambda_{1}\mathbf{x}_{l}+\lambda_{2}\mathcal{F}(\mathbf{x}_{l},\mathcal{W}_{l})$
- Exclusive gating: $\mathbf{x}_{l+1}=(1-g(\mathbf{x}_{l}))\mathbf{x}_{l}+g(\mathbf{x}_{l})\mathcal{F}(\mathbf{x}_{l},\mathcal{W}_{l})$
- Shortcut-only gating: $\mathbf{x}_{l+1}=(1-g(\mathbf{x}_{l}))\mathbf{x}_{l}+\mathcal{F}(\mathbf{x}_{l},\mathcal{W}_{l})$
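The remaining variants are the 1×1 conv shortcut and the dropout shortcut.
The experiments show that the original residual connection performs best. The performance of exclusive gating depends strongly on how the gate's bias is initialized.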
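A scalar sketch of the shortcut variants listed above may help compare them. It assumes a sigmoid gate whose weight is set to 0, so that $g=\sigma(b_g)$ depends only on a hypothetical gate bias `b_g`. It also uses made-up values for $\mathbf{x}_l$ and $\mathcal{F}$, plus hypothetical constants $\lambda_1=\lambda_2=0.5$. The sketch only shows how the bias alone moves exclusive gating between halving both paths and passing the shortcut almost unchanged. It says nothing about training results.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def original(x: float, F: float) -> float:
    return x + F


def constant_scaling(x: float, F: float, lam1: float = 0.5, lam2: float = 0.5) -> float:
    return lam1 * x + lam2 * F


def exclusive_gating(x: float, F: float, g: float) -> float:
    return (1 - g) * x + g * F


def shortcut_only_gating(x: float, F: float, g: float) -> float:
    return (1 - g) * x + F


x, F = 2.0, 0.4          # hypothetical shortcut input and residual output
for b_g in (0.0, -6.0):  # hypothetical gate bias; gate weight taken as 0
    g = sigmoid(b_g)
    print(b_g, round(g, 4), exclusive_gating(x, F, g), shortcut_only_gating(x, F, g))
print(original(x, F), constant_scaling(x, F))
```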
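7. Residual Blocks with Different Structures
- In (b), the BN layer after the addition means $\mathbf{x}_{l+1}=f(\mathbf{y}_{l})$ is no longer an identity mapping. BN alters the signal passing through the shortcut, which impedes information propagation and hurts the performance of the residual network.
- In (c), the ReLU at the end of the residual branch makes the residual's output non-negative. However, a residual function should be able to take values in $(-\infty,+\infty)$. The non-negative restriction weakens its representational ability and, empirically, hurts model performance.
- In (d) and (e), the authors adopt a pre-activation design.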
In the original design, $f$ affects both paths of the next residual block:
$$\mathbf{y}_{l+1}=f(\mathbf{y}_{l})+\mathcal{F}(f(\mathbf{y}_{l}),\mathcal{W}_{l+1})$$
Pre-activation makes $f$ affect only the residual path, leaving the identity mapping untouched:
$$\mathbf{y}_{l+1}=\mathbf{y}_{l}+\hat{\mathcal{F}}(f(\mathbf{y}_{l}),\mathcal{W}_{l+1})$$
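The difference between the two orderings already shows up with scalars. This minimal sketch uses $f=\mathrm{ReLU}$ and a hypothetical residual branch that always outputs 0. In the original ordering, the ReLU after the addition clips a negative value on the shortcut path. In the pre-activation ordering, the identity path carries that value through unchanged, even across several blocks.

```python
def relu(v: float) -> float:
    return max(0.0, v)


def original_block(x: float, residual) -> float:
    """x_{l+1} = f(x_l + F(x_l)), with f = ReLU applied after the addition."""
    return relu(x + residual(x))


def preact_block(x: float, residual) -> float:
    """x_{l+1} = x_l + F_hat(f(x_l)): ReLU only enters the residual branch."""
    return x + residual(relu(x))


def zero_branch(v: float) -> float:
    """Hypothetical residual branch that always outputs 0."""
    return 0.0


x0 = -1.5
print(original_block(x0, zero_branch))  # 0.0: the shortcut signal is clipped by ReLU
print(preact_block(x0, zero_branch))    # -1.5: the identity path is untouched

x = x0
for _ in range(5):                      # stack several pre-activation blocks
    x = preact_block(x, zero_branch)
print(x)                                # still -1.5
```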
