
Saving and loading torch models

There are two approaches:

'''
    Approach 1
    Save the entire model; the saved file is larger because the whole module is pickled
'''
torch.save(net, "../model/model.pkl")
net = torch.load("../model/model.pkl")

'''
    Approach 2
    Save only the model parameters (state_dict); the model must be constructed
    before its parameters can be loaded
'''
torch.save(net.state_dict(), "params.pkl")
net.load_state_dict(torch.load("params.pkl"))
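
For approach 2, restoring a model means rebuilding the same network first and then loading the saved parameters into it. A minimal sketch, assuming the Net(13, 1) model defined in the full code below and a params.pkl file written as above:

# Rebuild the architecture, then load the saved parameters into it
restored_net = Net(13, 1)
restored_net.load_state_dict(torch.load("params.pkl"))
restored_net.eval()  # switch to evaluation mode before running inference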
           

Below is the code for the complete workflow.

import torch

# data: load the housing dataset
import numpy as np
import re  # regular expression

ff = open("../housing.data").readlines()
data = []
for item in ff:
    out = re.sub(r"\s{2,}", " ", item).strip()  # collapse repeated whitespace; .strip() removes leading/trailing spaces
    print(out)
    data.append(out.split(" "))

data = np.array(data).astype(float)  # np.float was removed from NumPy; use the built-in float
# print(data.shape)  # (506, 14)

Y = data[:, -1]   # last column: house price (target)
X = data[:, :-1]  # first 13 columns: features
'''
print(Y.shape)
print(X.shape)
'''

Y_train = Y[0:496]
X_train = X[0:496, ...]

Y_test = Y[496:]
X_test = X[496:, ...]
# print(Y_train.shape)
# print(X_train.shape)
# print(Y_test.shape)
# print(X_test.shape)

# net: network definition
class Net(torch.nn.Module):
    def __init__(self, n_feature, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, 10)
        self.predict = torch.nn.Linear(10, n_output)

    def forward(self, x):
        _out = self.hidden(x)
        _out = torch.relu(_out)
        _out = self.predict(_out)
        return _out


net = Net(13, 1)

# loss: loss function
loss_func = torch.nn.MSELoss()

# optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)

# training
for i in range(10000):
    x_data = torch.tensor(X_train, dtype=torch.float32)
    y_data = torch.tensor(Y_train, dtype=torch.float32)
    pred = net(x_data)  # call the module directly rather than net.forward()
    pred = torch.squeeze(pred)
    loss = loss_func(pred, y_data)*0.001
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("ite:{}, loss:{}".format(i, loss))
    print(pred[0:10])
    print(y_data[0:10])
# test: evaluate on the held-out test set each iteration
    x_data = torch.tensor(X_test, dtype=torch.float32)
    y_data = torch.tensor(Y_test, dtype=torch.float32)
    pred = net(x_data)
    pred = torch.squeeze(pred)
    loss_test = loss_func(pred, y_data)*0.001
    print("ite:{}, loss_test:{}".format(i, loss_test))

'''
    Model saving
    Approach 1
    Save the entire model; the saved file is larger because the whole module is pickled
'''
# torch.save(net, "../model/model.pkl")
# net = torch.load("../model/model.pkl")

'''
    Approach 2
    Save only the model parameters (state_dict); the model must be constructed
    before its parameters can be loaded
'''
# torch.save(net.state_dict(), "params.pkl")
# net.load_state_dict(torch.load("params.pkl"))
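
One way to sanity-check either approach is to restore into a second network and confirm that its predictions match those of the trained one. A minimal sketch, assuming the torch.save(net.state_dict(), "params.pkl") line above was actually run:

# Restore the parameters into a fresh network and compare predictions
net2 = Net(13, 1)
net2.load_state_dict(torch.load("params.pkl"))
net2.eval()
with torch.no_grad():
    x = torch.tensor(X_test, dtype=torch.float32)
    assert torch.allclose(net(x), net2(x))  # both networks should agree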
           
