pytorch的模型和參數是分開的,可以分别儲存或加載模型和參數。
1、直接儲存模型
# 儲存模型
torch.save(model, 'model.pth')
# 加載模型
model = torch.load('model.pth')
2、分别加載模型的結構和參數
# 儲存模型參數
torch.save(model.state_dict(), 'model.pth')
# 加載模型參數
model.load_state_dict(torch.load('model.pth')
CPU模型加載GPU參數
通過DataParalle使用多GPU
model=DataParalle(model)
#儲存參數
torch.save(model.module.state_dict(), 'model.pth')
自己習慣用的代碼段
自己習慣用的代碼段
# 判斷gpu是否可用
use_cuda = torch.cuda.is_available()
# 是否使用多gpu
use_multi_gpu = True
# 預設加載的cpu的參數
model.load_state_dict(torch.load('model.pth')
if use_cuda:
model = model.cuda()
if use_multi_gpu:
model = DataParalle(model)
# 儲存模型參數(一般儲存cpu的參數比較好)
if use_multi_gpu:
torch.save(model.cpu().module.state_dict(), 'model.pth')
else:
torch.save(model.cpu().state_dict(), 'model.pth')
3、pytorch預訓練模型
加載預訓練模型和參數
隻加載模型,不加載預訓練參數
# 加載模型
resnet18 = models.resnet18(pretrained=False)
# 加載預先下載下傳好的預訓練模型參數
resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))
加載部分預訓練模型
resnet152 = models.resnet152(pretrained=True)
pretrained_dict = resnet152.state_dict()
"""加載torchvision中的預訓練模型和參數後通過state_dict()方法提取參數
也可以直接從官方model_zoo下載下傳:
pretrained_dict = model_zoo.load_url(model_urls['resnet152'])"""
model_dict = model.state_dict()
# 将pretrained_dict裡不屬于model_dict的鍵剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新現有的model_dict
model_dict.update(pretrained_dict)
# 加載我們真正需要的state_dict
model.load_state_dict(model_dict)
【參考資料】
1、PyTorch學習:加載模型和參數
2、PyTorch使用cpu調用gpu訓練的模型