兩種方式
'''
第一種方式
模型整體儲存,占用空間會比較大
'''
torch.save(net, "../model/model.pkl")
torch.load("")
'''
第二種方式
儲存模型參數,在加載模型參數之前,必須先建立模型
'''
torch.save(net.state_dict(), "params.pkl")
net.load_state_dict(torch.load("path_of_model_state_dict"))
下面是整個流程的代碼
import torch
# data 資料加載
import numpy as np
import re # regular expression
ff = open("../housing.data").readlines()
data = []
for item in ff:
out = re.sub(r"\s{2,}", " ", item).strip() # .strip()可去掉字元串前後的空格
print(out)
data.append(out.split(" "))
data = np.array(data).astype(np.float)
# print(data.shape) # (506, 14)
Y = data[:, -1]
X = data[:, :-1]
'''
print(Y.shape)
print(X.shape)
'''
Y_train = Y[0:496]
X_train = X[0:496, ...]
Y_test = Y[496:]
X_test = X[496:, ...]
# print(Y_train.shape)
# print(X_train.shape)
# print(Y_test.shape)
# print(X_test.shape)
# net 網絡定義
class Net(torch.nn.Module):
def __init__(self, n_feature, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, 10)
self.predict = torch.nn.Linear(10, n_output)
def forward(self, x):
_out = self.hidden(x)
_out = torch.relu(_out)
_out = self.predict(_out)
return _out
net = Net(13, 1)
# loss 損失函數定義
loss_func = torch.nn.MSELoss()
# optimizer 優化器定義
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
# training 開始訓練
for i in range(10000):
x_data = torch.tensor(X_train, dtype=torch.float32)
y_data = torch.tensor(Y_train, dtype=torch.float32)
pred = net.forward(x_data)
pred = torch.squeeze(pred)
loss = loss_func(pred, y_data)*0.001
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("ite:{}, loss:{}".format(i, loss))
print(pred[0:10])
print(y_data[0:10])
# test 夾雜測試
x_data = torch.tensor(X_test, dtype=torch.float32)
y_data = torch.tensor(Y_test, dtype=torch.float32)
pred = net.forward(x_data)
pred = torch.squeeze(pred)
loss_test = loss_func(pred, y_data)*0.001
print("ite:{}, loss_test:{}".format(i, loss_test))
'''
模型儲存
第一種方式
模型整體儲存,占用空間會比較大
'''
# torch.save(net, "../model/model.pkl")
# torch.load("")
'''
第二種方式
儲存模型參數,在加載模型參數之前,必須先建立模型
'''
# torch.save(net.state_dict(), "params.pkl")
# net.load_state_dict()