PyTorch: Saving and Loading Models

How to save a trained model

Save:

torch.save(model.state_dict(), PATH)
           

Loading a model

Load:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()  # put dropout and batch-norm layers into evaluation mode before inference
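
The save and load steps above can be tried end to end with a small sketch. `TinyNet`, `roundtrip_state_dict`, and the file name `tiny_net.pt` are hypothetical names introduced only for this illustration; `TinyNet` stands in for `TheModelClass`, and a temporary directory stands in for `PATH`.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for TheModelClass."""

    def __init__(self, in_features=4, out_features=2):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.fc(x)


def roundtrip_state_dict(model, path):
    """Save the model's state_dict to path, then load it into a fresh TinyNet."""
    torch.save(model.state_dict(), path)
    restored = TinyNet()  # the architecture must be rebuilt before loading weights
    restored.load_state_dict(torch.load(path))
    restored.eval()
    return restored


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_net.pt")  # hypothetical file name
    original = TinyNet()
    restored = roundtrip_state_dict(original, path)

# Both models should now produce identical outputs for the same input.
x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.equal(original.eval()(x), restored(x))
```

Only the parameters are stored in the file, which is why the model object has to be constructed first and the weights loaded into it afterwards.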
           

1. Save on GPU, Load on CPU

Save:

torch.save(model.state_dict(), PATH)
           

Load:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
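
To see what `map_location` does, the following sketch inspects the device of each loaded tensor. It is an illustration under assumptions: `nn.Linear(4, 2)` stands in for the real model, the file name `linear.pt` is made up, and the model is moved to the GPU before saving only if one is available, so the snippet also runs on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical stand-in for the real model
if torch.cuda.is_available():
    model = model.to("cuda")  # save from the GPU when one exists

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "linear.pt")  # hypothetical file name
    torch.save(model.state_dict(), path)
    # map_location remaps every stored tensor onto the CPU while loading
    state = torch.load(path, map_location=torch.device("cpu"))

# Record which device each loaded tensor ended up on.
devices = {name: tensor.device.type for name, tensor in state.items()}
```

Regardless of where the checkpoint was saved, every entry in `devices` should be `cpu` after loading this way.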
           

2. Save on GPU, Load on GPU

Save:

torch.save(model.state_dict(), PATH)
           

Load:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
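
The reminder about input tensors can be sketched as follows. This is a minimal example, not the original author's code: `nn.Linear(4, 2)` is a placeholder model, and the device falls back to the CPU when CUDA is unavailable so the snippet runs anywhere.

```python
import torch
import torch.nn as nn

# Fall back to the CPU when no GPU is present (an assumption for portability).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 2)  # hypothetical stand-in for the loaded model
model.to(device)         # moves the module's parameters in place
model.eval()

batch = torch.randn(3, 4)  # created on the CPU by default
batch = batch.to(device)   # Tensor.to() is not in-place, so keep the returned tensor

with torch.no_grad():
    out = model(batch)
```

Calling `batch.to(device)` without reassigning the result would leave `batch` on the CPU, and the forward pass on a GPU model would then raise a device mismatch error.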
           

3. Save on CPU, Load on GPU

Save:

torch.save(model.state_dict(), PATH)
           

Load:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
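
A small sketch of this case, checking that the restored parameters land on the requested device. The names `target`, `cpu_model`, and `cpu_linear.pt` are hypothetical, and the target falls back to `"cpu"` when CUDA is not available so the example still runs.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Use the first GPU when present; otherwise stay on the CPU (portability assumption).
target = "cuda:0" if torch.cuda.is_available() else "cpu"

cpu_model = nn.Linear(4, 2)  # hypothetical model, created and saved on the CPU
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "cpu_linear.pt")  # hypothetical file name
    torch.save(cpu_model.state_dict(), path)
    # map_location accepts a device string and places the loaded tensors there
    state = torch.load(path, map_location=target)

model = nn.Linear(4, 2)
model.load_state_dict(state)
model.to(torch.device(target))  # the module itself must also be moved

param_device = next(model.parameters()).device
```

`load_state_dict` copies values into the module's existing parameters, so the module stays on its original device until `model.to(...)` is called.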
           
