天天看点

深度学习--第15篇: Pytorch保存和加载模型参数参考博客1. 保存模型和参数2. 仅保存参数3. 加载pytorch预训练模型

Pytorch保存和加载模型参数

  • 参考博客
  • 1. 保存模型和参数
  • 2. 仅保存参数
  • 3. 加载pytorch预训练模型
    • 3.1 加载预训练模型和参数
    • 3.2 只加载模型不加载预训练参数

参考博客

参考博客: https://blog.csdn.net/lscelory/article/details/81482586

pytorch的模型和参数是分开的,可以分别保存或加载模型和参数。

pytorch有两种模型保存方式:

  • 保存整个神经网络的的结构信息和模型参数信息,save的对象是网络net
  • 只保存神经网络的训练模型参数,save的对象是net.state_dict()

对应两种保存模型的方式,pytorch也有两种加载模型的方式。对应第一种保存方式,加载模型时通过torch.load(’.pth’)直接初始化新的神经网络对象;对应第二种保存方式,需要首先导入对应的网络,再通过net.load_state_dict(torch.load(’.pth’))完成模型参数的加载。

在网络比较大的时候,第一种方法会花费较多的时间。

1. 保存模型和参数

  • 保存模型
# 将网络结构和模型参数都保存起来,在测试时可以直接加载,不需要初始化网络结构
torch.save(model, path)

参数:
	model: 训练的网络
	pth: 保存的路径(包含文件名,后缀名以.pth .pkl 等结尾)

实例:
torch.save(model, os.path.join('.', 'lenet.pth')) # 保存模型结构和参数
           
  • 加载模型
# 直接加载模型文件,不需要初始化网络结构
model = torch.load(path)

参数:
	model: 加载后的网络
	pth: 保存模型文件的路径(包含文件名,后缀名以.pth .pkl 等结尾)

实例:
model = torch.load(os.path.join('.', 'lenet.pth'))# 加载模型结构和参数
           

2. 仅保存参数

  • 保存参数
# 将lenet模型储存为lenet.pth, 注意保存的仅仅是网络模型的状态信息参数字典, 加载是需要初始化网络模型
torch.save(net.state_dict(), os.path.join('.', 'lenet.pth'))
           
  • 加载参数
# 加载lenet,模型存放在lenet.pth, 加载之前要确认网络模型已初始化完成
model = torch.load(os.path.join('.', 'lenet.pth'))
net.load_state_dict(model)
           

3. 加载pytorch预训练模型

3.1 加载预训练模型和参数

import torchvision
AlexNet = torchvision.models.alexnet(pretrained=True) # 加载预训练模型AlexNet和参数

resnet18 = torchvision.models.resnet18(pretrained=True)
           

3.2 只加载模型不加载预训练参数

import torchvision
AlexNet = torchvision.models.alexnet(pretrained=False) # 加载预训练模型AlexNet

# 导入模型结构
ResNet18 = models.resnet18(pretrained=False)
# 加载预先下载好的预训练参数到resnet18
ResNet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))
           

继续阅读