天天看点

PyTorch笔记之模型保存和加载

文章目录

    • Saving and Loading Models
      • Before
      • 关于 state_dict
      • 保存/加载 state_dict
      • 保存/加载 整个模型
      • 保存/加载 Checkpoint 以及恢复训练
      • 使用来自不同模型的参数 Warmstarting 模型
      • 跨 GPU 和 CPU 保存和加载模型

Saving and Loading Models

参考翻译 SAVING AND LOADING MODELS

Before

三个核心函数:

  • torch.save

    将序列化对象保存到磁盘

    该函数使用

    pickle

    进行序列化

    包括 models,tensors,dictionaries 等所有类型对象都可以使用该函数保存

  • torch.load

    反序列化加载到内存
  • torch.nn.Module.load_state_dict

    使用反序列化的

    state_dict

    加载模型参数字典

关于 state_dict

torch.nn.Module

的可学习化参数(learnable parameters, eg. weights 和 bias) 都包含在模型的参数中(保存在

model.parameters()

state_dict

只是一个

Python dictionary

对象,它将每个层映射到它的参数张量

需要注意的是,只有拥有可学习(learnable parameters)参数的神经网络层(eg. convolutional layers, linear layers)和注册的缓存(batchnorm的running_mean)才在

state_dict

中有条目。

优化器对象(

torch.optim

)也有

state_dict

,其中包含优化器状态的相关信息以及使用的超参数。

因为

state_dict

Python dictionaries

,所以它们可以很容易的保存,替换,更新,添加。

print("Model's state_dict")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Optimizer's state_dict")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
           

保存/加载 state_dict

Save:

Load:

model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH))
model.eval()
           
在预测之前需要调用

model.eval()

来设置

dropout

batch normalization layers

evaluation

模式。

一般使用

.pt

pth

后缀保存模型文件

保存/加载 整个模型

Save:

Load:

model = torch.load(PATH)
model.eval()
           
使用该方式保存模型,缺点是序列化的数据依赖于特定的类和额外的数据结构。而

pickle

不保存模型的类本身,而是保存包含这个类的文件的位置。

因此,如果进行代码重构的话,会出现问题。

保存/加载 Checkpoint 以及恢复训练

在训练中,每隔 M 个 epoch 保存一次模型,避免训练中断,以恢复模型

Save:

torch.save({
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict:" optimizer.state_dict(),
    "loss": loss,
    ...
    }, PATH)
           

Load:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]

model.eval()
# or train after break
modle.train()
           
一般使用

.tar

文件后缀保存 checkpoint 文件

使用来自不同模型的参数 Warmstarting 模型

Save:

Load:

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
           

迁移学习

设置

strict=False

来忽略不匹配的键值

除此以外,也可以只加载某些匹配的神经网络层的参数

跨 GPU 和 CPU 保存和加载模型

  • 在 GPU 上保存,在 CPU 上加载

    Save:

Load:

device = torch.device("cpu")
model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH), map_location=device)
           
在 CPU 上调用 GPU 上训练的模型,传递参数

map_location=device

,则重新将张量动态的映射到 CPU 上
  • 在 GPU 上保存,在 GPU 上调用

    Save:

Load:

device = torch.device("cuda")

model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH))
model.to(device)
           

注意,这里需要将模型输入的其他张量

调用

input = input.to(device)

调用会返回一个在 GPU 上

input

的新的拷贝

而不会重写

input

,所以需要重新赋值

继续阅读