天天看點

利用TensorFlow解決線性回歸問題

利用TensorFlow解決線性回歸問題

1.導入必要的庫

import tensorflow as tf

在之前的基礎上,還需要導入TensorFlow的庫。

2.建立一個訓練函數

def train_tf(train_data):

# 1.擷取資料

trainx = [train_d[0] for train_d in train_data] # list

trainy = [train_d[1] for train_d in train_data]

# 2.構造預測的線性回歸函數:y= W * x + b
   W = tf.Variable(tf.random_uniform([1]))  # 從均勻分布中傳回随機值,即[0,1)
   b = tf.Variable(tf.zeros([1]))  # 在一維數組裡放一個值
   y = W * trainx + b

   # 3.判斷假設的函數的好壞
   cost = tf.reduce_mean(tf.square(y - trainy))

   # 4.優化函數
   optimizer = tf.train.AdamOptimizer(0.05)
   train = optimizer.minimize(cost)

   # 5.開始訓練
   with tf.Session() as sess:
         # 初始化所有變量值
         sess.run(tf.global_variables_initializer())
         # 将畫圖模式改為互動模式
         plt.ion()
         for k in range(1000):
               sess.run(train)
               # 構造圖形結構
               # 實時地輸出訓練好的W和b
               if k % 50 == 0:
                     print("第", k ,"步:","cost=", sess.run(cost), "W=", 
           

sess.run(W), “b=”, sess.run(b))

plt.cla() # 清除原有圖像

plt.plot(trainx, trainy, ‘co’, label=‘train data’) # 顯示資料

plt.plot(trainx, sess.run(y), ‘y’, label=‘train result’) # 顯示拟合

plt.pause(0.01)

plt.ioff() # 關閉互動模式

plt.close() # 關閉目前視窗

print(“訓練完成!”)

# 輸出訓練好的W和b

print(“finally_cost=”, sess.run(cost), “finally_W=”, sess.run(W),

“finally_b=”, sess.run(b))

return sess.run(W)[0], sess.run(b)[0]

tf.reduce_mean()函數用于計算張量Tensor沿着指定數軸(Tensor的某一次元)的平均值,主要用于降維或計算結果的平均值。第4步中,用梯度下降算法找最優解,通過梯度下降法為最小化損失函數增加了相關的優化操作。在訓練過程中,先執行個體化一個優化函數,并基于一定的學習率進行梯度優化訓練,如tf.train.AdamOptimizer(),該優化函數是一個尋找全局最優點的優化算法,引入了二次方梯度校正;使用minimize()操作,不僅可以優化及更新訓練的模型參數,也可以為全局步驟(Global Step)計數,函數的參數傳入損失值節點cost,再啟動一個外層的循環,優化器就會按照循環的次數沿着cost最小值的方向優化參數。第5步開始訓練,先初始化所有變量值和操作,打開plt的互動模式,開始訓練并實時顯示拟合的效果。

3.調用訓練函數

W, b = train_tf(data)

predict(50, W, b, data)

完成調用後執行該程式,經過1000步的訓練之後,得到W和b,預測結果如圖所示。

利用TensorFlow解決線性回歸問題

W和b的值如圖所示。由于代碼中每50步列印1次,1000步為0~999,是以圖4-9中隻列印到第950步。

利用TensorFlow解決線性回歸問題

基于TensorFlow實作簡單線性回歸,樣本隻有一個特征值。如果樣本有多個特征值,就需要進行多元線性回歸。

在簡單線性回歸中,橫坐标x§表示的是一個值;如果橫坐标x§對應的是一組向量,則公式的形式就變為,這個公式就是多元線性回歸的公式。如果将該公式降維為簡單線性回歸,就是偏置b,就是權重W。

對于多元線性回歸,将權重W作為一個矩陣處理,輸入的x也必須是一個矩陣,如輸入的x矩陣為,權重W為,那麼的結果就是一個的矩陣,加上偏置b,就可以得到y的值。