文章目录
-
- Saving and Loading Models
-
- Before
- 关于 state_dict
- 保存/加载 state_dict
- 保存/加载 整个模型
- 保存/加载 Checkpoint 以及恢复训练
- 使用来自不同模型的参数 Warmstarting 模型
- 跨 GPU 和 CPU 保存和加载模型
Saving and Loading Models
参考翻译 SAVING AND LOADING MODELS
Before
三个核心函数:
-
torch.save
将序列化对象保存到磁盘
该函数使用
pickle
进行序列化
包括 models,tensors,dictionaries 等所有类型对象都可以使用该函数保存
-
反序列化加载到内存torch.load
-
使用反序列化的torch.nn.Module.load_state_dict
加载模型参数字典state_dict
关于 state_dict
torch.nn.Module
的可学习化参数(learnable parameters, eg. weights 和 bias) 都包含在模型的参数中(保存在
model.parameters()
)
state_dict
只是一个
Python dictionary
对象,它将每个层映射到它的参数张量
需要注意的是,只有拥有可学习(learnable parameters)参数的神经网络层(eg. convolutional layers, linear layers)和注册的缓存(batchnorm的running_mean)才在
state_dict
中有条目。
优化器对象(
torch.optim
)也有
state_dict
,其中包含优化器状态的相关信息以及使用的超参数。
因为
state_dict
是
Python dictionaries
,所以它们可以很容易的保存,替换,更新,添加。
print("Model's state_dict")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Optimizer's state_dict")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
保存/加载 state_dict
Save:
Load:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
在预测之前需要调用来设置
model.eval()
和
dropout
为
batch normalization layers
evaluation
模式。
一般使用
或
.pt
后缀保存模型文件
pth
保存/加载 整个模型
Save:
Load:
model = torch.load(PATH)
model.eval()
使用该方式保存模型,缺点是序列化的数据依赖于特定的类和额外的数据结构。而
pickle
不保存模型的类本身,而是保存包含这个类的文件的位置。
因此,如果进行代码重构的话,会出现问题。
保存/加载 Checkpoint 以及恢复训练
在训练中,每隔 M 个 epoch 保存一次模型,避免训练中断,以恢复模型
Save:
torch.save({
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict:" optimizer.state_dict(),
"loss": loss,
...
}, PATH)
Load:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]
model.eval()
# or train after break
modle.train()
一般使用 .tar
文件后缀保存 checkpoint 文件
使用来自不同模型的参数 Warmstarting 模型
Save:
Load:
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
迁移学习
设置
strict=False
来忽略不匹配的键值
除此以外,也可以只加载某些匹配的神经网络层的参数
跨 GPU 和 CPU 保存和加载模型
-
在 GPU 上保存,在 CPU 上加载
Save:
Load:
device = torch.device("cpu")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH), map_location=device)
在 CPU 上调用 GPU 上训练的模型,传递参数 map_location=device
,则重新将张量动态的映射到 CPU 上
-
在 GPU 上保存,在 GPU 上调用
Save:
Load:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
注意,这里需要将模型输入的其他张量
调用
调用会返回一个在 GPU 上
input = input.to(device)
input
的新的拷贝
而不会重写
,所以需要重新赋值
input