Saving and Loading Models in PyTorch

How to save a trained model

Save:

torch.save(model.state_dict(), PATH)
           

How to load the model

Load:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
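
The save and load steps above can be tried end to end with a small round trip. This is a minimal sketch under stated assumptions: `TinyNet` is a hypothetical stand-in for `TheModelClass`, and the file name inside a temporary directory stands in for `PATH`. Only the `state_dict` (the learned parameters) is written to disk, so the model class must be constructed again before the weights are loaded into it.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for TheModelClass (not from the original post)."""

    def __init__(self, in_features=4, out_features=2):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.fc(x)


def save_and_reload(model, path):
    """Save only the parameters, then rebuild the model and load them back."""
    torch.save(model.state_dict(), path)
    restored = TinyNet()                        # same architecture as the saved model
    restored.load_state_dict(torch.load(path))  # copy the saved parameters in
    restored.eval()                             # put dropout/batch-norm layers in inference mode
    return restored


model = TinyNet()
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny_net.pt")  # assumed file name, stands in for PATH
    restored = save_and_reload(model, path)

# The restored model should produce the same outputs as the original.
x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(model(x), restored(x))
```

Calling `model.eval()` matters for models that contain dropout or batch normalization layers; this tiny linear model behaves the same in both modes, so the call is shown only to mirror the pattern above.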
           

1. Save on GPU, Load on CPU

Save:

torch.save(model.state_dict(), PATH)
           

Load:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
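
To see what `map_location` does, the sketch below saves a `state_dict` and checks where the loaded tensors end up. It is an illustration under assumptions: it uses a single `nn.Linear` layer instead of `TheModelClass`, an in-memory buffer instead of `PATH`, and it falls back to saving from the CPU when no GPU is present, so it also runs on a CPU-only machine.

```python
import io

import torch
import torch.nn as nn

# Save from the GPU when one exists; otherwise fall back to the CPU (assumption for portability).
source_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 2).to(source_device)

buffer = io.BytesIO()  # in-memory stand-in for PATH
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# map_location remaps every stored tensor onto the CPU while loading.
state_dict = torch.load(buffer, map_location=torch.device("cpu"))
devices = {name: tensor.device.type for name, tensor in state_dict.items()}

cpu_model = nn.Linear(4, 2)
cpu_model.load_state_dict(state_dict)
```

Without `map_location`, tensors saved from a GPU are restored to the device they were saved from, which fails on a machine that has no CUDA support.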
           

2. Save on GPU, Load on GPU

Save:

torch.save(model.state_dict(), PATH)
           

Load:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
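
The reminder about input tensors is easy to overlook, so here is a short sketch of it. It assumes a single `nn.Linear` layer in place of the loaded model and falls back to the CPU when CUDA is unavailable. Note the asymmetry: `model.to(device)` moves the module's parameters in place, while `tensor.to(device)` returns a new tensor that has to be assigned back.

```python
import torch
import torch.nn as nn

# Use the GPU when available; otherwise fall back to the CPU (assumption for portability).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 2)  # stand-in for a model whose weights were just loaded
model.to(device)         # moves the module's parameters in place
model.eval()

batch = torch.randn(3, 4)
batch = batch.to(device)  # returns a new tensor, so the result must be reassigned

with torch.no_grad():
    output = model(batch)
```

If the reassignment is skipped while the model sits on the GPU, the forward pass raises a device mismatch error because the input is still a CPU tensor.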
           

3. Save on CPU, Load on GPU

Save:

torch.save(model.state_dict(), PATH)
           

Load:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
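
The CPU-to-GPU case can be wrapped in a small helper that picks the target device at run time. This is a sketch, not a fixed recipe: `load_onto` is a hypothetical helper name, the model is a single `nn.Linear` layer, an in-memory buffer replaces `PATH`, and the target falls back to `"cpu"` when no GPU is present so the example stays runnable anywhere.

```python
import io

import torch
import torch.nn as nn


def load_onto(model, file_obj, target):
    """Hypothetical helper: load a state_dict directly onto `target`, then move the model there."""
    state_dict = torch.load(file_obj, map_location=target)
    model.load_state_dict(state_dict)
    model.to(target)
    return model


# Save a model that lives on the CPU.
cpu_model = nn.Linear(4, 2)
buffer = io.BytesIO()  # in-memory stand-in for PATH
torch.save(cpu_model.state_dict(), buffer)
buffer.seek(0)

# Choose the GPU index you want; fall back to the CPU when CUDA is unavailable.
target = "cuda:0" if torch.cuda.is_available() else "cpu"
gpu_model = load_onto(nn.Linear(4, 2), buffer, target)

param_devices = {p.device.type for p in gpu_model.parameters()}
weights_match = torch.equal(gpu_model.weight.detach().cpu(), cpu_model.weight.detach())
```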
           
