1. The Paper
Let's start with the VGG paper, "Very Deep Convolutional Networks for Large-Scale Image Recognition" (paper download link).
The paper evaluates several configurations, labeled A, B, C, D, and E. The two in common use today are VGG16 and VGG19, i.e., configurations D and E in the paper's configuration table.
2. Model Overview
The configuration table already gives a rough picture of the architecture. The only difference between VGG16 and VGG19 is that VGG16 has one fewer convolutional layer in each of blocks 3, 4, and 5; the sketch below verifies the layer counts. PyTorch pretrained checkpoints for these models are available (pretrained model download link).
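As a quick check (a sketch; it only instantiates the architectures, so no weights are downloaded), you can count the convolutional layers of each variant:

import torch.nn as nn
import torchvision.models as models

for name, ctor in [('vgg16', models.vgg16), ('vgg19', models.vgg19)]:
    net = ctor()  # architecture only, no pretrained weights
    n_conv = sum(isinstance(m, nn.Conv2d) for m in net.features)
    print(name, n_conv)  # vgg16 -> 13, vgg19 -> 16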
As shown above, torchvision provides each VGG variant both with and without batch normalization (the _bn suffix). The following table details the internal layer structure of VGG16_bn.
VGG16_bn:

| Index | Layer | Weight layer # | Filters |
| --- | --- | --- | --- |
| 0 | conv1-1 | 1 | 64x3x3 |
| 1 | batchnorm | | |
| 2 | relu1-1 | | |
| 3 | conv1-2 | 2 | 64x3x3 |
| 4 | batchnorm | | |
| 5 | relu1-2 | | |
| 6 | pool1 | | |
| 7 | conv2-1 | 3 | 128x3x3 |
| 8 | batchnorm | | |
| 9 | relu2-1 | | |
| 10 | conv2-2 | 4 | 128x3x3 |
| 11 | batchnorm | | |
| 12 | relu2-2 | | |
| 13 | pool2 | | |
| 14 | conv3-1 | 5 | 256x3x3 |
| 15 | batchnorm | | |
| 16 | relu3-1 | | |
| 17 | conv3-2 | 6 | 256x3x3 |
| 18 | batchnorm | | |
| 19 | relu3-2 | | |
| 20 | conv3-3 | 7 | 256x3x3 |
| 21 | batchnorm | | |
| 22 | relu3-3 | | |
| 23 | pool3 | | |
| 24 | conv4-1 | 8 | 512x3x3 |
| 25 | batchnorm | | |
| 26 | relu4-1 | | |
| 27 | conv4-2 | 9 | 512x3x3 |
| 28 | batchnorm | | |
| 29 | relu4-2 | | |
| 30 | conv4-3 | 10 | 512x3x3 |
| 31 | batchnorm | | |
| 32 | relu4-3 | | |
| 33 | pool4 | | |
| 34 | conv5-1 | 11 | 512x3x3 |
| 35 | batchnorm | | |
| 36 | relu5-1 | | |
| 37 | conv5-2 | 12 | 512x3x3 |
| 38 | batchnorm | | |
| 39 | relu5-2 | | |
| 40 | conv5-3 | 13 | 512x3x3 |
| 41 | batchnorm | | |
| 42 | relu5-3 | | |
| 43 | pool5 | | |
| 44 | fc6 (4096) | 14 | |
| 45 | relu6 | | |
| 46 | fc7 (4096) | 15 | |
| 47 | relu7 | | |
| 48 | fc8 (1000) | 16 | |
| 49 | prob (softmax) | | |
The table above lists VGG16_bn's layers in detail: there are 16 weighted layers in total (counting only convolutional and fully connected layers). Plain VGG16 is identical except that it omits the batch-normalization layers. Counting modules, VGG16_bn's features container holds 44 entries (excluding the classifier), while plain VGG16's holds 31.
Importing the VGG16_bn model and stepping through it in a debugger shows exactly the structure above.
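If you'd rather not attach a debugger, a two-line check works just as well (a sketch):

import torchvision.models as models

print(len(models.vgg16_bn().features))  # 44: 13 conv + 13 bn + 13 relu + 5 pool
print(len(models.vgg16().features))     # 31: 13 conv + 13 relu + 5 pool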
3. Using the Pretrained Model
3.1 Loading the entire model
To use a pretrained model in PyTorch, the first step is always to load it. There are two ways, introduced one by one below.
1. Online download. This is usually slow because of network conditions and is not recommended (a one-line sketch follows below).
2. Download the checkpoint yourself first and load it from disk. The example below loads the pretrained model and then classifies a user-supplied image.
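For completeness, method 1 is a one-liner (a sketch; pretrained=True tells torchvision to fetch the checkpoint over the network):

import torchvision.models as models

vgg = models.vgg16_bn(pretrained=True)  # downloads the weights on first use

Method 2, loading from disk and classifying a local image: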
import torch
import numpy
from PIL import Image
from torchvision import transforms
import torchvision.models as models

# Build the VGG16_bn architecture, then load the locally downloaded weights
vgg = models.vgg16_bn()
pre = torch.load('./vgg16_bn-6c64b313.pth')
vgg.load_state_dict(pre)
vgg.eval()  # important: puts BatchNorm (and Dropout) into inference mode

# VGG expects a 224x224 RGB image normalized with the ImageNet statistics
tran = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet per-channel mean
                         std=[0.229, 0.224, 0.225]),  # ImageNet per-channel std
])

im = Image.open('./1.jpg')
im = tran(im)
im.unsqueeze_(dim=0)  # add a batch dimension: (3,224,224) -> (1,3,224,224)
print(im.shape)

with torch.no_grad():  # inference only, no gradients needed
    out = vgg(im)      # (1, 1000) class scores
ind = int(numpy.argmax(out.data[0]))  # index of the highest-scoring class
print(ind)

from cls import d  # cls.py: a local mapping from class index to label name
print(d[ind])
print(out.shape)
3. A few points deserve attention. Because we load the VGG model and feed it an image of our own, the input must match VGG's expected format:
- VGG's reference pipeline reads images with the PIL library, so open the image with PIL.
- The input must be a 224×224 three-channel image, because we reuse VGG's fully connected layers, which fix the input size.
- The fixed normalization constants above are the per-channel mean and standard deviation of the ImageNet data VGG was trained on.
- VGG predicts 1000 classes, so out = vgg(im) is a tensor of 1000 scores; see the sketch after this list for converting them to probabilities.
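Note that torchvision's VGG ends at the fc8 logits; the prob (softmax) row from the table is not part of the module, so apply it yourself if you want probabilities. A minimal sketch, reusing out from the code above:

import torch.nn.functional as F

probs = F.softmax(out, dim=1)        # (1, 1000) class probabilities
top_p, top_i = probs.topk(5, dim=1)  # five most likely ImageNet classes
for p, i in zip(top_p[0].tolist(), top_i[0].tolist()):
    print(i, round(p, 4))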
4. Inspecting the loaded parameters
pre = torch.load('./pretrain/vgg16_bn-6c64b313.pth')
for key, v in pre.items():  # the state dict maps parameter names to tensors
    print(key, v.size())
What torch.load returns here is the VGG network's parameters (a state dict), which you can print to inspect; below only each tensor's size is shown:
features.0.weight torch.Size([64, 3, 3, 3])
features.0.bias torch.Size([64])
features.1.weight torch.Size([64])
features.1.bias torch.Size([64])
features.1.running_mean torch.Size([64])
features.1.running_var torch.Size([64])
features.3.weight torch.Size([64, 64, 3, 3])
features.3.bias torch.Size([64])
features.4.weight torch.Size([64])
features.4.bias torch.Size([64])
features.4.running_mean torch.Size([64])
features.4.running_var torch.Size([64])
features.7.weight torch.Size([128, 64, 3, 3])
features.7.bias torch.Size([128])
features.8.weight torch.Size([128])
features.8.bias torch.Size([128])
features.8.running_mean torch.Size([128])
features.8.running_var torch.Size([128])
features.10.weight torch.Size([128, 128, 3, 3])
features.10.bias torch.Size([128])
features.11.weight torch.Size([128])
features.11.bias torch.Size([128])
features.11.running_mean torch.Size([128])
features.11.running_var torch.Size([128])
features.14.weight torch.Size([256, 128, 3, 3])
features.14.bias torch.Size([256])
features.15.weight torch.Size([256])
features.15.bias torch.Size([256])
features.15.running_mean torch.Size([256])
features.15.running_var torch.Size([256])
features.17.weight torch.Size([256, 256, 3, 3])
features.17.bias torch.Size([256])
features.18.weight torch.Size([256])
features.18.bias torch.Size([256])
features.18.running_mean torch.Size([256])
features.18.running_var torch.Size([256])
features.20.weight torch.Size([256, 256, 3, 3])
features.20.bias torch.Size([256])
features.21.weight torch.Size([256])
features.21.bias torch.Size([256])
features.21.running_mean torch.Size([256])
features.21.running_var torch.Size([256])
features.24.weight torch.Size([512, 256, 3, 3])
features.24.bias torch.Size([512])
features.25.weight torch.Size([512])
features.25.bias torch.Size([512])
features.25.running_mean torch.Size([512])
features.25.running_var torch.Size([512])
features.27.weight torch.Size([512, 512, 3, 3])
features.27.bias torch.Size([512])
features.28.weight torch.Size([512])
features.28.bias torch.Size([512])
features.28.running_mean torch.Size([512])
features.28.running_var torch.Size([512])
features.30.weight torch.Size([512, 512, 3, 3])
features.30.bias torch.Size([512])
features.31.weight torch.Size([512])
features.31.bias torch.Size([512])
features.31.running_mean torch.Size([512])
features.31.running_var torch.Size([512])
features.34.weight torch.Size([512, 512, 3, 3])
features.34.bias torch.Size([512])
features.35.weight torch.Size([512])
features.35.bias torch.Size([512])
features.35.running_mean torch.Size([512])
features.35.running_var torch.Size([512])
features.37.weight torch.Size([512, 512, 3, 3])
features.37.bias torch.Size([512])
features.38.weight torch.Size([512])
features.38.bias torch.Size([512])
features.38.running_mean torch.Size([512])
features.38.running_var torch.Size([512])
features.40.weight torch.Size([512, 512, 3, 3])
features.40.bias torch.Size([512])
features.41.weight torch.Size([512])
features.41.bias torch.Size([512])
features.41.running_mean torch.Size([512])
features.41.running_var torch.Size([512])
classifier.0.weight torch.Size([4096, 25088])
classifier.0.bias torch.Size([4096])
classifier.3.weight torch.Size([4096, 4096])
classifier.3.bias torch.Size([4096])
classifier.6.weight torch.Size([1000, 4096])
classifier.6.bias torch.Size([1000])
The features indices above stop at 41 rather than running to 43, because the final ReLU and max-pool, which are features.42 and features.43, do not appear. The loaded pre contains only parameters (plus the batch-norm running statistics), and ReLU and pooling operations have no parameters, so nothing is saved for them in the checkpoint. The sketch below confirms this by enumerating the modules.
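A quick confirmation (a sketch; it only instantiates the architecture):

import torchvision.models as models

vgg = models.vgg16_bn()
for idx, layer in enumerate(vgg.features):
    print(idx, layer.__class__.__name__)

The output ends with 41 BatchNorm2d, 42 ReLU, 43 MaxPool2d; the last two are parameter-free and therefore absent from the state dict.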
3.2 Loading part of the model
import torch
import torch.nn as nn
import torchvision.models as models

class VGG(nn.Module):
    def __init__(self, weights=False):
        super(VGG, self).__init__()
        if weights is False:
            # No local checkpoint given: let torchvision download it
            model = models.vgg19_bn(pretrained=True)
        else:
            # `weights` is the path to a locally downloaded state dict
            model = models.vgg19_bn(pretrained=False)
            pre = torch.load(weights)
            model.load_state_dict(pre)
        # Keep only the convolutional part and freeze its parameters
        self.vgg19 = model.features
        for param in self.vgg19.parameters():
            param.requires_grad = False
The constructor takes a weights argument. When it is False, the VGG checkpoint is downloaded from the internet, which is usually slow and not recommended; instead, download the checkpoint beforehand and pass its path so it can be loaded locally. Here only VGG's features part is kept and the fully connected classifier is discarded; of course, you can also index or slice features to select any subset of layers, as sketched below.
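For example, a minimal sketch of keeping only the earlier feature layers (the cut-off index 27 here is an arbitrary illustration, not anything prescribed by VGG):

import torch.nn as nn
import torchvision.models as models

model = models.vgg19_bn(pretrained=True)
# keep the first 27 modules of `features` and discard the rest
partial = nn.Sequential(*list(model.features.children())[:27])
print(partial)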