【神经网络】VGG16、VGG16_bn、VGG19_bn详解以及使用pytorch进行模型预训练[亲测有效]

【神经网络】VGG16、VGG16_bn、VGG19_bn详解以及使用pytorch进行模型预训练[亲测有效]目录一、论文二、模型介绍三、模型预训练一、论文先来看看VGG这篇论文《VeryDeepConvolutionalNetworksforLarge-ScaleImageRecognition》论文下载地址论文中几个模型主要以几下几种方案A、B、C、D、E。目前主要还是采用VGG16和VGG19也就是下图中的分别红框和绿框部分。  二、模型介绍其实通过…

大家好,欢迎来到IT知识分享网。

目录

一、论文

二、模型介绍

三、模型预训练


一、论文

先来看看VGG这篇论文《Very Deep Convolutional Networks for Large-Scale Image Recognition》论文下载地址

论文中几个模型主要以几下几种方案A、B、C、D、E。目前主要还是采用VGG16和VGG19也就是下图中的分别红框绿框部分。 

 【神经网络】VGG16、VGG16_bn、VGG19_bn详解以及使用pytorch进行模型预训练[亲测有效]

二、模型介绍

其实通过上面的表格就已经大致知道模型的框架组成部分了。其实VGG16与VGG19的区别就是前者在三、四、五部分少了一层卷积。这里先附基于pytorch的一些预训练模型预训练模型下载地址

【神经网络】VGG16、VGG16_bn、VGG19_bn详解以及使用pytorch进行模型预训练[亲测有效]

上图可以看出VGG分有无BatchNormalization。这里先介绍一下VGG16_bn的一些内部层结构。

VGG16_bn
序号 层结构 层数 权重
0 conv1-1 1 64x3x3
1 batchnorm  
2 relu1-1  
3 conv1-2 2 64x3x3
4 batchnorm  
5 relu1-2  
6 pool1  
7 conv2-1 3 128x3x3
8 batchnorm  
9 relu2-1  
10 conv2-2 4 128x3x3
11 batchnorm  
12 relu2-2  
13 pool2  
14 conv3-1 5 256x3x3
15 batchnorm  
16 relu3-1  
17 conv3-2 6 256x3x3
18 batchnorm  
19 relu3-2  
20 conv3-3 7 256x3x3
21 batchnorm  
22 relu3-3  
23 pool3  
24 conv4-1 8 512x3x3
25 batchnorm  
26 relu4-1  
27 conv4-2 9 512x3x3
28 batchnorm  
29 relu4-2  
30 conv4-3 10 512x3x3
31 batchnorm  
32 relu4-3  
33 pool4  
34 conv5-1 11 512x3x3
35 batchnorm  
36 relu5-1  
37 conv5-2 12 512x3x3
38 batchnorm  
39 relu5-2  
40 conv5-3 13 512x3x3
41 batchnorm  
42 relu5-3 512x3x3
43 pool5  
44 fc6(4096) 14  
45 relu6    
46 fc7(4096) 15  
47 relu7    
48 fc8(1000) 16  
49 prob(softmax)    

 上表格是VGG16_bn的一些详细层结构,一共有16层(层是指卷积层和全连接层)VGG16则仅仅去掉红色部分的Batch_normalization部分。这里可以看到VGG16_bn的modules共有44个(这里不算全连接层),如果是VGG16则有31个(不算全连接层)。

下图是通过导入VGG16_bn模型在调试过程中的结果,可见与上面是一致。

【神经网络】VGG16、VGG16_bn、VGG19_bn详解以及使用pytorch进行模型预训练[亲测有效]

【神经网络】VGG16、VGG16_bn、VGG19_bn详解以及使用pytorch进行模型预训练[亲测有效]

三、模型预训练

3.1加载整个模型

基于pytorch模型预训练,首先都要导入加载模型。有两种方式,下面一一介绍。

1.采用在线下载,这种一般受网络原因比较慢,不建议。

2.自己先下好预训练模型,从本地加载,这里介绍一下加载预训练模型后后自己提供一张图片进行分类识别。

import torch
import numpy
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
import torchvision.models as models

vgg = models.vgg16_bn()
pre=torch.load('./vgg16_bn-6c64b313.pth')
vgg.load_state_dict(pre)

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],#这是imagenet數據集的均值
                                 std=[0.229, 0.224, 0.225])

tran=transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
])


im='./1.jpg'
im=Image.open(im)
im=tran(im)
im.unsqueeze_(dim=0)
print(im.shape)
# input()
out=vgg(im)
outnp=out.data[0]
ind=int(numpy.argmax(outnp))
print(ind)

from cls import d
print(d[ind])


print(out.shape)


# im.show()

3 主要有几个注意的地方。由于是加载VGG模型的,并提供自己一张图像进行预测,输入就必须符合VGG的格式。

  • VGG模型的图像读入方式采用PIL库所以就得使用PIL库进行读入图片
  • 输入图像的尺寸得必须和VGG保持一致224×224的三通道。(因为全连接层用的是VGG的)
  • 上面采用的normalization归一化方式 几个固定的参数是因为VGG数据的分布,其均值和方差
  • VGG的最终分类的类别是1000类,最终out=vgg(img)是一个1000元素的张量。

 4 查看加載的參數

pre = torch.load('./pretrain/vgg16_bn-6c64b313.pth')
    for key, v in pre.items():
        print(key, v.size())

加載得到的是VGG網絡參數,可以將其輸出查看,這裏只顯示其size

features.0.weight torch.Size([64, 3, 3, 3])
features.0.bias torch.Size([64])
features.1.weight torch.Size([64])
features.1.bias torch.Size([64])
features.1.running_mean torch.Size([64])
features.1.running_var torch.Size([64])
features.3.weight torch.Size([64, 64, 3, 3])
features.3.bias torch.Size([64])
features.4.weight torch.Size([64])
features.4.bias torch.Size([64])
features.4.running_mean torch.Size([64])
features.4.running_var torch.Size([64])
features.7.weight torch.Size([128, 64, 3, 3])
features.7.bias torch.Size([128])
features.8.weight torch.Size([128])
features.8.bias torch.Size([128])
features.8.running_mean torch.Size([128])
features.8.running_var torch.Size([128])
features.10.weight torch.Size([128, 128, 3, 3])
features.10.bias torch.Size([128])
features.11.weight torch.Size([128])
features.11.bias torch.Size([128])
features.11.running_mean torch.Size([128])
features.11.running_var torch.Size([128])
features.14.weight torch.Size([256, 128, 3, 3])
features.14.bias torch.Size([256])
features.15.weight torch.Size([256])
features.15.bias torch.Size([256])
features.15.running_mean torch.Size([256])
features.15.running_var torch.Size([256])
features.17.weight torch.Size([256, 256, 3, 3])
features.17.bias torch.Size([256])
features.18.weight torch.Size([256])
features.18.bias torch.Size([256])
features.18.running_mean torch.Size([256])
features.18.running_var torch.Size([256])
features.20.weight torch.Size([256, 256, 3, 3])
features.20.bias torch.Size([256])
features.21.weight torch.Size([256])
features.21.bias torch.Size([256])
features.21.running_mean torch.Size([256])
features.21.running_var torch.Size([256])
features.24.weight torch.Size([512, 256, 3, 3])
features.24.bias torch.Size([512])
features.25.weight torch.Size([512])
features.25.bias torch.Size([512])
features.25.running_mean torch.Size([512])
features.25.running_var torch.Size([512])
features.27.weight torch.Size([512, 512, 3, 3])
features.27.bias torch.Size([512])
features.28.weight torch.Size([512])
features.28.bias torch.Size([512])
features.28.running_mean torch.Size([512])
features.28.running_var torch.Size([512])
features.30.weight torch.Size([512, 512, 3, 3])
features.30.bias torch.Size([512])
features.31.weight torch.Size([512])
features.31.bias torch.Size([512])
features.31.running_mean torch.Size([512])
features.31.running_var torch.Size([512])
features.34.weight torch.Size([512, 512, 3, 3])
features.34.bias torch.Size([512])
features.35.weight torch.Size([512])
features.35.bias torch.Size([512])
features.35.running_mean torch.Size([512])
features.35.running_var torch.Size([512])
features.37.weight torch.Size([512, 512, 3, 3])
features.37.bias torch.Size([512])
features.38.weight torch.Size([512])
features.38.bias torch.Size([512])
features.38.running_mean torch.Size([512])
features.38.running_var torch.Size([512])
features.40.weight torch.Size([512, 512, 3, 3])
features.40.bias torch.Size([512])
features.41.weight torch.Size([512])
features.41.bias torch.Size([512])
features.41.running_mean torch.Size([512])
features.41.running_var torch.Size([512])
classifier.0.weight torch.Size([4096, 25088])
classifier.0.bias torch.Size([4096])
classifier.3.weight torch.Size([4096, 4096])
classifier.3.bias torch.Size([4096])
classifier.6.weight torch.Size([1000, 4096])
classifier.6.bias torch.Size([1000])

上面的feature最多是42個,不是44個,因爲relu和pool沒有顯示出來,其分別是features.42  feature.43.因爲加載的參數pre裏面包含的內容是參數,而relu操作和池化操作是不需要參數的,也就是模型保存時並沒有保存下來。

3.2加载部分模型

class VGG(nn.Module):
    def __init__(self, weights=False):
        super(VGG, self).__init__()
        if weights is False:
            model = models.vgg19_bn(pretrained=True)

        model = models.vgg19_bn(pretrained=False)
        pre = torch.load(weights)
        model.load_state_dict(pre)
        self.vgg19 = model.features
        
        for param in self.vgg19.parameters():
            param.requires_grad = False

初始化有个参数权重,当为false时,默认网上下载VGG模型,通常网上下载的比较慢不建议,所以直接本地下载好之后再load即可。这里选择了vgg的features部分,全连接部分没有选择,当然也可以索引或者切片选择任何层的 features。

免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://yundeesoft.com/25322.html

(0)

相关推荐

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注微信