天天看點

PyTorch:采用sklearn 工具生成這樣的合成資料集+利用PyTorch實作簡單合成資料集上的線性回歸進行資料分析

輸出結果

PyTorch:采用sklearn 工具生成這樣的合成資料集+利用PyTorch實作簡單合成資料集上的線性回歸進行資料分析

核心代碼

#PyTorch:采用sklearn 工具生成這樣的合成資料集+利用PyTorch實作簡單合成資料集上的線性回歸進行資料分析

from sklearn.datasets import make_regression

import seaborn as sns

import pandas as pd

import matplotlib.pyplot as plt

sns.set()

x_train, y_train, W_target = make_regression(n_samples=100, n_features=1, noise=10, coef = True)

df = pd.DataFrame(data = {'X':x_train.ravel(), 'Y':y_train.ravel()})

sns.lmplot(x='X', y='Y', data=df, fit_reg=True)

plt.show()

x_torch = torch.FloatTensor(x_train)

y_torch = torch.FloatTensor(y_train)

y_torch = y_torch.view(y_torch.size()[0], 1)

class LinearRegression(torch.nn.Module):  #定義LR的類。torch.nn庫構模組化型

   #PyTorch 的 nn 庫中有大量有用的子產品,其中一個就是線性子產品。如名字所示,它對輸入執行線性變換,即線性回歸。

   def __init__(self, input_size, output_size):

       super(LinearRegression, self).__init__()

       self.linear = torch.nn.Linear(input_size, output_size)  

   def forward(self, x):

       return self.linear(x)

model = LinearRegression(1, 1)

criterion = torch.nn.MSELoss() #訓練線性回歸,我們需要從 nn 庫中添加合适的損失函數。對于線性回歸,我們将使用 MSELoss()——均方差損失函數

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)#還需要使用優化函數(SGD),并運作與之前示例類似的反向傳播。本質上,我們重複上文定義的 train() 函數中的步驟。

#不能直接使用該函數的原因是我們實作它的目的是分類而不是回歸,以及我們使用交叉熵損失和最大元素的索引作為模型預測。而對于線性回歸,我們使用線性層的輸出作為預測。

for epoch in range(50):

   data, target = Variable(x_torch), Variable(y_torch)

   output = model(data)

   optimizer.zero_grad()

   loss = criterion(output, target)

   loss.backward()

   optimizer.step()

predicted = model(Variable(x_torch)).data.numpy()

#列印出原始資料和适合 PyTorch 的線性回歸

plt.plot(x_train, y_train, 'o', label='Original data')

plt.plot(x_train, predicted, label='Fitted line')

plt.legend()

plt.title(u'Py:PyTorch實作簡單合成資料集上的線性回歸進行資料分析')