Advertisement

pytorch调用预训练模型

阅读量:

1.pytorch提供以下模型

  • AlexNet: AlexNet variant from the “One weird trick” paper.
  • VGG: VGG-11, VGG-13, VGG-16, VGG-19 (with and without batch normalization)
  • ResNet: ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152
  • SqueezeNet: SqueezeNet 1.0, and SqueezeNet 1.1

2.模型结构加载

复制代码
 import torchvision.models as models

    
 resnet18 = models.resnet18(pretrained=False)
    
 #只加载模型结构
    
    
    
    
    python
    
    

3.模型结构和预训练参数加载

复制代码
 import torchvision.models as models

    
 resnet18 = models.resnet18(pretrained=True)
    
 #加载模型结构,同时加载预训练模型参数
    
    
    
    
    python
    
    

4.加载一部分预训练模型参数

模型可能是一些经典的模型改掉一部分,比如一般算法中提取特征的网络常见的会直接使用vgg16的features extraction部分,也就是在训练的时候可以直接加载已经在imagenet上训练好的预训练参数,这种方式实现如下:

复制代码
 net = UNet()

    
 vgg16 = models.vgg16(pretrained=True) #加载网络结构和预训练模型
    
 #static_dict()返回包含模块所有状态的字典
    
 pretrained_dict = vgg16.state_dict()  #返回内置预训练vgg模块的字典
    
 model_dict = net.state_dict()  #返回我们自己model的字典
    
  
    
 #------------------------最关键的三步------------------------------------------
    
 # 1. filter out unnecessary keys,也就是说从内置模块中删除掉我们不需要的字典
    
 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    
  
    
 # 2. overwrite entries in the existing state dict,利用pretrained_dict更新现有的model_dict
    
 model_dict.update(pretrained_dict)
    
  
    
 # 3. load the new state dict,更新模型,加载我们真正需要的state_dict
    
 model.load_state_dict(model_dict)
    
  
    
 #也就是在网络中state_dict部分,属于vgg16的,替换成vgg16预训练模型里的参数(代码里的k:v for k,v in pretrained_dict.items() if k in model_dict),其他保持不变。
    
    
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-19/NQLqJFc98Z0y7utVSw6rMfnEhx2W.png)

5. 微调经典网络

复制代码
 import torchvision.models as models

    
 import torch.nn as nn
    
  
    
 vgg16 = models.vgg16(pretrained=True)
    
 vgg16.features[0]=nn.Conv2d(4, 64, 3, 1, 1)
    
 #这里相当于改变了vgg的第一层
    
    
    
    
    python
    
    

6.修改经典网络

这个比上面微调修改的地方要多一些,但是想介绍一下这样的修改方式。

先简单介绍一下我需要需改的部分,在vgg16的基础模型下,每一个卷积都要加一个dropout层,并将ReLU激活函数换成PReLU,最后两层的Pooling层stride改成1。直接上代码:

复制代码
 def feature_layer():

    
     layers = []
    
     pool1 = ['4', '9', '16']
    
     pool2 = ['23', '30']
    
     vgg16 = models.vgg16(pretrained=True).features
    
     for name, layer in vgg16._modules.items():
    
     if isinstance(layer, nn.Conv2d):
    
         layers += [layer, nn.Dropout2d(0.5), nn.PReLU()]
    
     elif name in pool1:
    
         layers += [layer]
    
     elif name == pool2[0]:
    
         layers += [nn.MaxPool2d(2, 1, 1)]
    
     elif name == pool2[1]:
    
         layers += [nn.MaxPool2d(2, 1, 0)]
    
     else:
    
         continue
    
     features = nn.Sequential(*layers)
    
     #feat3 = features[0:24]
    
     return features
    
    
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-19/7yYfdlEhL680NKikBpVuXQmOZS4j.png)

大概的思路就是,创建一个新的网络(layers列表), 遍历vgg16里每一层,如果遇到卷积层(if isinstance(layer, nn.Conv2d)就先把该层(Conv2d)保持原样加进去,随后增加一个dropout层,再加一个PReLU层。然后如果遇到最后两层pool,就修改响应参数加进去,其他的pool正常加载。 最后将这个layers列表转成网络的nn.Sequential的形式,最后返回features。然后再你的新的网络层就可以用以下方式来加载:

复制代码
 class SNet(nn.Module):

    
     def __init__(self):
    
     super(SNet, self).__init__()
    
     self.features = feature_layer()
    
     def forward(self, x):
    
     x = self.features(x)
    
     return x
    
    
    
    
    python
    
    

参考链接

<>

<>

<>

全部评论 (0)

还没有任何评论哟~