天天看點

TF之NN:利用神經網絡系統自動學習散點(二次函數+noise+優化修正)輸出結果可視化(matplotlib動態示範)

輸出結果

TF之NN:利用神經網絡系統自動學習散點(二次函數+noise+優化修正)輸出結果可視化(matplotlib動態示範)

代碼設計

import tensorflow as tf

import numpy as np

import matplotlib.pyplot as plt

def add_layer(inputs, in_size, out_size, activation_function=None):
    """Append one fully connected layer to the graph and return its output.

    Args:
        inputs: Tensor fed into the layer; it is matrix-multiplied by an
            ``[in_size, out_size]`` weight matrix.
        in_size: Number of input features (rows of the weight matrix).
        out_size: Number of units in the layer (columns of the weight matrix).
        activation_function: Optional callable applied to the affine
            output. ``None`` leaves the layer linear.

    Returns:
        ``activation_function(inputs @ W + b)``, or the bare affine result
        when no activation is given.
    """
    # Weights start from random normal samples; biases start at a constant 0.1.
    layer_weights = tf.Variable(tf.random_normal([in_size, out_size]))
    layer_biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    affine = tf.matmul(inputs, layer_weights) + layer_biases

    if activation_function is None:
        return affine
    return activation_function(affine)

# Training set: 300 evenly spaced inputs in [-1, 1], shaped as a column
# vector (300, 1) so each row is one single-feature sample.
x_data = np.linspace(-1, 1, 300).reshape(-1, 1)

# Gaussian noise (mean 0, standard deviation 0.05), one value per sample.
noise = np.random.normal(loc=0, scale=0.05, size=x_data.shape)

# Targets follow y = x^2 - 0.5, perturbed by the noise above.
y_data = np.square(x_data) - 0.5 + noise

# Placeholders for the network inputs and targets: one feature per sample,
# with the batch dimension left open (None).
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

# 1 -> 10 -> 1 network: a ReLU hidden layer followed by a linear output layer.
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
prediction = add_layer(l1, 10, 1, activation_function=None)

# The error between prediction and real data: squared error summed over the
# output dimension of each sample, then averaged over the batch.
# `axis` is used in place of the deprecated `reduction_indices` keyword.
loss = tf.reduce_mean(
    tf.reduce_sum(tf.square(ys - prediction), axis=1)
)

# Plain gradient descent with a learning rate of 0.1.
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# Important step: build the op that initialises every tf.Variable created
# so far, then run it once in a new session before any training happens.
# NOTE(review): the session is not closed anywhere in the visible code.
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# Scatter the noisy training samples once; the fitted curve is overlaid on
# the same axes while the network trains.
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(x_data, y_data)

# Switch to interactive mode before showing the window so that the figure
# can be redrawn from the training loop.
plt.ion()
plt.show()

# The title never changes, so set it once instead of on every redraw.
plt.title('Matplotlib,NN,Efficient learning,Approach,Quadratic function --Jason Niu')

# Artists of the most recently drawn prediction curve; None until the first
# curve has been plotted.
lines = None

for i in range(1000):
    # One gradient-descent step on the full data set.
    sess.run(train_step, feed_dict={xs: x_data, ys: y_data})

    # Only refresh the visualisation every 50 steps.
    if i % 50 != 0:
        continue

    # Drop the previous curve so that only the latest fit is visible.
    # Artist.remove() is used instead of ax.lines.remove(...): on recent
    # Matplotlib releases Axes.lines cannot be mutated, and the former
    # blanket `except Exception: pass` hid that failure, leaving every old
    # curve on the plot.
    if lines is not None:
        lines[0].remove()

    prediction_value = sess.run(prediction, feed_dict={xs: x_data})

    # Plot the current prediction as a thick red line.
    lines = ax.plot(x_data, prediction_value, 'r-', lw=5)

    # Give the GUI event loop time to redraw the figure.
    plt.pause(0.1)