我不熟悉tensorflow。下面的代碼可以成功運作,沒有任何錯誤。在前10行輸出中,計算速度很快,并且輸出(在最後一行中定義)逐行飛行。然而,随着疊代次數的增加,計算變得越來越慢,最終變得無法忍受。是以我想知道是否有任何修改可以加速這個過程。在
以下是對該代碼的簡要說明:
該代碼将單隐層神經網絡應用于資料集。它的目的是找到速率[0]和速率[1]的最佳參數,這兩個參數将影響損耗函數。在訓練的每一步中,一個元組被輸入到模型中,并且立即評估元組的準确性(這種資料在現實世界中以流的形式出現)。在import tensorflow as tf
import numpy as np
n_hidden=50
n_input=37
n_output=2
data_raw=np.genfromtxt(r'data.csv',delimiter=",",dtype=None)
data_info=np.genfromtxt(r'data2.csv',delimiter=",",dtype=None)
def pre_process( tuple):
ans = []
temp = [0 for i in range(24)]
temp[int(tuple[0])] = 1
# np.append(ans,np.array(temp))
ans.extend(temp)
temp = [0 for i in range(7)]
temp[int(tuple[1]) - 1] = 1
ans.extend(temp)
# np.append(ans,np.array(temp))
temp = [0 for i in range(3)]
temp[int(tuple[3])] = 1
ans.extend(temp)
temp = [0 for i in range(2)]
temp[int(tuple[4])] = 1
ans.extend(temp)
ans.extend([int(tuple[5])])
return np.array(ans)
x=tf.placeholder(tf.float32, shape=[1,n_input])
y_=tf.placeholder(tf.float32,shape=[n_output])
y_r=tf.placeholder(tf.float32,shape=[n_output])
W1=tf.Variable(tf.random_uniform([n_input, n_hidden]))
b1=tf.Variable(tf.zeros([n_hidden]))
W2=tf.Variable(tf.zeros([n_hidden,n_output]))
b2=tf.Variable(tf.zeros([n_output]))
logits_1 = tf.matmul(x, W1) + b1
relu_layer= tf.nn.relu(logits_1)
logits_2 = tf.matmul(relu_layer, W2) + b2
correct_prediction = tf.equal(tf.argmax(logits_2,1), tf.argmax(y_,0))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
rate=[0,0]
for i in range(-100,200,10):
rate[0]=i;
for j in range(-100,i,10):
rate[1]=j
loss=tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits_2)*[rate[0],rate[1]])
# loss2=tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=y_r, logits=logits_2)*[rate[2],rate[3]])
# loss=loss1+loss2
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
data_line=1
accur=0
local_local=0
remote_remote=0
local_remote=0
remote_local=0
total=0
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(200):
# print(int(data_raw[data_line][0]),data_info[i][0])
if i>100:
total+=1
if int(data_raw[data_line][0])==data_info[i][0]:
sess.run(train_step,feed_dict={x:pre_process(data_info[i]).reshape(1,-1),y_:[1,0],y_r:[0,1]})
# print(sess.run(logits_2,{x:pre_process(data_info[i]).reshape(1,-1), y_: #[1,0]}))
data_line+=1;
if data_line==len(data_raw):
break
if i>100:
acc=accuracy.eval(feed_dict={x: pre_process(data_info[i]).reshape(1,-1), y_: [1,0], y_r:[0,1]})
local_local+=acc
local_remote+=1-acc
accur+=acc
else:
sess.run(train_step,feed_dict={x:pre_process(data_info[i]).reshape(1,-1),y_:[0,1], y_r:[1,0]})
# print(sess.run(logits_2,{x: pre_process(data_info[i]).reshape(1,-1), y_: #[0,1]}))
if i>100:
acc=accuracy.eval(feed_dict={x: pre_process(data_info[i]).reshape(1,-1), y_: [0,1], y_r:[1,0]})
remote_remote+=acc
remote_local+=1-acc
accur+=acc
print("correctness: (%.3d,%.3d): \t%.2f %.2f %.2f %.2f %.2f" % (rate[0],rate[1],accur/total,local_local/total,local_remote/total,remote_local/total,remote_remote/total))