- 初始化網絡（建立 LSTM 單元與初始狀態）
def get_init_cell(batch_size, rnn_size, num_layers=1):
    """Create a stacked LSTM cell and its zero initial state.

    Args:
        batch_size: Batch size for the initial state; the caller in this file
            passes a dynamic dimension taken from ``tf.shape(...)``.
        rnn_size: Number of units in each LSTM layer.
        num_layers: Number of stacked LSTM layers. The layer count was lost
            from this copy of the source; 1 is an assumed default -- TODO
            confirm against the intended model.

    Returns:
        A ``(cell, initial_state)`` tuple: the ``MultiRNNCell`` and its
        float32 zero state passed through ``tf.identity`` under the name
        ``'initial_state'``.
    """
    # Build a distinct cell object per layer. ``[lstm] * n`` would place the
    # *same* cell object in every layer, which later TF 1.x releases reject
    # (and earlier ones silently share weights across layers).
    layers = [tf.contrib.rnn.BasicLSTMCell(rnn_size) for _ in range(num_layers)]
    cell = tf.contrib.rnn.MultiRNNCell(layers)
    initial_state = cell.zero_state(batch_size, tf.float32)
    # Give the state a fixed name so it can be retrieved from the graph by name.
    initial_state = tf.identity(initial_state, name='initial_state')
    return (cell, initial_state)
- 輸入的詞嵌入（embedding）
def get_embed(input_data, vocab_size, embed_dim, init_scale=1.0):
    """Look up trainable embeddings for integer word ids.

    Args:
        input_data: Integer tensor of word ids.
        vocab_size: Number of rows in the embedding matrix.
        embed_dim: Length of each embedding vector.
        init_scale: The embedding matrix is initialised uniformly in
            ``[-init_scale, init_scale)``. The bounds were lost from this copy
            of the source; 1.0 is an assumed default -- TODO confirm.

    Returns:
        The embedded input: ``input_data``'s shape with ``embed_dim`` appended.
    """
    embedding = tf.Variable(
        tf.random_uniform((vocab_size, embed_dim), -init_scale, init_scale))
    return tf.nn.embedding_lookup(embedding, input_data)
- 執行 RNN（用 dynamic_rnn 沿整個序列展開）
def build_rnn(cell, inputs):
    """Unroll ``cell`` over ``inputs`` using ``tf.nn.dynamic_rnn``.

    No initial state is passed, so ``dynamic_rnn`` starts from a float32
    zero state.

    Returns:
        ``(outputs, final_state)``, where ``final_state`` is passed through
        ``tf.identity`` under the name ``'final_state'``.
    """
    rnn_outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
    # Fixed name so the final state can be fetched from the graph by name.
    named_state = tf.identity(last_state, name='final_state')
    return rnn_outputs, named_state
- 整個網絡（詞嵌入 → RNN → 全連接輸出層）
def build_nn(cell, rnn_size, input_data, vocab_size, embed_dim):
    """Build the network: embedding -> RNN -> linear output layer.

    Args:
        cell: RNN cell to unroll (e.g. the one from ``get_init_cell``).
        rnn_size: Not used by this function; kept for interface compatibility.
        input_data: Integer tensor of word ids (presumably
            ``(batch_size, seq_length)`` -- confirm with ``get_inputs``).
        vocab_size: Vocabulary size; also the number of output logits.
        embed_dim: Embedding vector length.

    Returns:
        A ``(logits, final_state)`` tuple; ``logits`` has ``vocab_size``
        units in its last dimension.
    """
    embedded_input = get_embed(input_data, vocab_size, embed_dim)
    outputs, final_state = build_rnn(cell, embedded_input)
    # Logits must be unbounded. fully_connected defaults to a ReLU activation,
    # which would clamp every negative logit to 0 before softmax/loss.
    logits = tf.contrib.layers.fully_connected(
        outputs, vocab_size, activation_fn=None)
    return (logits, final_state)
- Build the training graph
from tensorflow.contrib import seq2seq
train_graph = tf.Graph()
with train_graph.as_default():
    vocab_size = len(int_to_vocab)
    input_text, targets, lr = get_inputs()
    # Dynamic shape of the input batch. Index 0 is the batch size and index 1
    # the sequence length: sequence_loss requires targets/weights shaped
    # [batch_size, sequence_length], so input_text is assumed to match.
    input_data_shape = tf.shape(input_text)
    cell, initial_state = get_init_cell(input_data_shape[0], rnn_size)
    logits, final_state = build_nn(cell, rnn_size, input_text, vocab_size, embed_dim)

    # Probabilities for generating words
    probs = tf.nn.softmax(logits, name='probs')

    # Loss function: every time step of every sequence is weighted equally.
    cost = seq2seq.sequence_loss(
        logits,
        targets,
        tf.ones([input_data_shape[0], input_data_shape[1]]))

    # Optimizer
    optimizer = tf.train.AdamOptimizer(lr)

    # Gradient clipping. The clip bounds were lost from this copy of the
    # source; +/-1.0 is assumed here -- TODO confirm.
    # Variables with no gradient (grad is None) are skipped.
    gradients = optimizer.compute_gradients(cost)
    capped_gradients = [(tf.clip_by_value(grad, -1., 1.), var)
                        for grad, var in gradients if grad is not None]
    train_op = optimizer.apply_gradients(capped_gradients)
Ending