天天看點

Pytorch 模型的加載與儲存

pytorch的模型和參數是分開的,可以分别儲存或加載模型和參數。

1、直接儲存模型

# 儲存模型
torch.save(model, 'model.pth')
# 加載模型
model = torch.load('model.pth')
           

2、分别加載模型的結構和參數

# 儲存模型參數
torch.save(model.state_dict(), 'model.pth')
# 加載模型參數
model.load_state_dict(torch.load('model.pth')
           

CPU模型加載GPU參數

通過DataParalle使用多GPU

model=DataParalle(model)
#儲存參數
torch.save(model.module.state_dict(), 'model.pth')
           

自己習慣用的代碼段

# 判斷gpu是否可用
use_cuda = torch.cuda.is_available()
# 是否使用多gpu
use_multi_gpu = True
# 預設加載的cpu的參數
model.load_state_dict(torch.load('model.pth')

if use_cuda:
	model = model.cuda()
if use_multi_gpu:
	model = DataParalle(model)

# 儲存模型參數(一般儲存cpu的參數比較好)
if use_multi_gpu:
	torch.save(model.cpu().module.state_dict(), 'model.pth')
else:
	torch.save(model.cpu().state_dict(), 'model.pth')
           

3、pytorch預訓練模型

加載預訓練模型和參數

隻加載模型,不加載預訓練參數

# 加載模型
resnet18 = models.resnet18(pretrained=False)
# 加載預先下載下傳好的預訓練模型參數
resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))
           

加載部分預訓練模型

resnet152 = models.resnet152(pretrained=True)
pretrained_dict = resnet152.state_dict()
"""加載torchvision中的預訓練模型和參數後通過state_dict()方法提取參數
   也可以直接從官方model_zoo下載下傳:
   pretrained_dict = model_zoo.load_url(model_urls['resnet152'])"""
model_dict = model.state_dict()
# 将pretrained_dict裡不屬于model_dict的鍵剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新現有的model_dict
model_dict.update(pretrained_dict)
# 加載我們真正需要的state_dict
model.load_state_dict(model_dict)
           

【參考資料】

1、PyTorch學習:加載模型和參數

2、PyTorch使用cpu調用gpu訓練的模型

繼續閱讀