天天看點

PyTorch VGG16 & ResNet50 簡單的遷移學習

import torch
import torch.nn as nn
from torch.utils.data import dataset,dataloader,Dataset,DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50, vgg16

# Load VGG16 with pretrained weights.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour
# of `weights=...`; kept here so the snippet also runs on older torchvision.
vgg = vgg16(pretrained=True)

# Freeze every pretrained parameter (requires_grad=False) so that only layers
# added afterwards are updated during training.
for param in vgg.parameters():
    param.requires_grad = False

# Replace the last classifier layer with a fresh 22-way head. Freshly created
# parameters default to requires_grad=True, so this head is trainable.
# in_features is read from the layer being replaced (4096 for VGG16) instead of
# being hard-coded. The nn.Sequential wrapper is kept so the state_dict keys
# ("classifier.6.0.weight", ...) stay identical to the original snippet's.
vgg.classifier[6] = nn.Sequential(
    nn.Linear(in_features=vgg.classifier[6].in_features, out_features=22)
)

# Show which parameters are trainable. Dumping the raw weight tensors would be
# noise; the name plus requires_grad flag is what verifies the freezing.
for name, p in vgg.named_parameters():
    print(name, p.requires_grad)
           

或者，直接使用以下代碼只更新最後一層的參數（新建立的層預設 requires_grad=True，因此只有在該層也被凍結時才需要這一步）：

# Explicitly (re-)enable gradients for the last classifier layer of `vgg` only.
# The original referenced an undefined name `model`; the model here is `vgg`.
# A freshly created layer already has requires_grad=True, so this loop only
# matters if that layer was frozen along with the rest of the network.
for param in vgg.classifier[6].parameters():
    param.requires_grad = True
           

# ResNet-50: the same recipe as the VGG16 example.
# `resnet50` is imported from torchvision.models at the top of the file; the
# original `models.resnet50` raised NameError because `models` was never
# imported.
resnet = resnet50(pretrained=True)
for param in resnet.parameters():  # freeze all pretrained parameters
    param.requires_grad = False
# Replace the final fully connected layer with a new 100-way head. in_features
# is taken from the layer being replaced. As a freshly created layer its
# parameters default to requires_grad=True, so it is the only trainable part.
resnet.fc = nn.Linear(resnet.fc.in_features, 100)