實驗原理:
線性回歸是用來度量變量間關系的統計技術。該算法的實作并不複雜,但可以适用于很多情形。正是因為這些原因,以線性回歸作為開始學習TensorFlow的開始。
不管在兩個變量(簡單回歸)或多個變量(多元回歸)情形下,線性回歸都是對一個依賴變量,多個獨立變量xi,一個随機值b間的關系模組化。利用TensorFlow實作一個簡單的線性回歸模型:分析一些代碼基礎及說明如何在學習過程中調用各種重要元件,比如cost function或梯度下降算法
運作代碼:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior() #保證placer
import numpy as np
import matplotlib.pyplot as plt
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
#設定訓練參數
learning_rate=0.01
training_epochs=1000
display_step=50
# 訓練資料
train_X=np.asarray([3.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167,7.042,10.791,5.313,7.997,5.654,9.27,3.1])
train_Y=np.asarray([1.7,2.76,2.09,3.19,1.694,1.573,3.366,2.596,2.53,1.221,2.827,3.465,1.65,2.904,2.42,2.94,1.3])
n_samples=train_X.shape[0]
#構造計算圖
X=tf.placeholder("float")
Y=tf.placeholder("float")
#設定模型的初始權重
W=tf.Variable(np.random.randn(),name="weight")
b=tf.Variable(np.random.randn(),name='bias')
#構造線性回歸模型
pred=tf.add(tf.multiply(X,W),b)
#損失函數,即均方差
cost=tf.reduce_sum(tf.pow(pred-Y,2))/(2*n_samples)
#使用梯度下降法求最小值,即最優解
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
#初始化全部變量
init=tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
#調用會話對象sess的run方法,運作計算圖,即開始訓練模型
for epoch in range(training_epochs):
for(x,y) in zip(train_X,train_Y):
sess.run(optimizer,feed_dict={X:x,Y:y})
#Display logs per epoch step
if (epoch+1) % display_step==0:
c=sess.run(cost,feed_dict={X:train_X,Y:train_Y})
print("Epoch:",'%04d'%(epoch+1),"cost=","{:.9f}".format(c),"W=",sess.run(W),"b=",sess.run(b))
#訓練模型的代價函數。
training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
print("Train cost=", training_cost, "W=", sess.run(W), "b=", sess.run(b))
plt.plot(train_X,train_Y,'ro',label='Original data')
plt.plot(train_X,sess.run(W)*train_X+sess.run(b),label="Fitting line")
plt.legend()
plt.show()
運作結果:
![](https://img.laitimes.com/img/9ZDMuAjOiMmIsIjOiQnIsISPrdEZwZ1Rh5WNXp1bwNjW1ZUba9VZwlHdsATOfd3bkFGazxCMx8VesATMfhHLlN3XnxCMwEzX0xiRGZkRGZ0Xy9GbvNGLpZTY1EmMZVDUSFTU4VFRR9Fd4VGdsYTMfVmepNHLrJXYtJXZ0F2dvwVZnFWbp1zczV2YvJHctM3cv1Ce-cmbw5SZ4AzYhBDZxUmZxYjM1QGN1cTZmhDMyEDNkhjYhV2Mj9CXzAzLchDMxIDMy8CXn9Gbi9CXzV2Zh1WavwVbvNmLvR3YxUjL4M3Lc9CX6MHc0RHaiojIsJye.png)