Text Classification with tf-idf (Model Training)
For the dataset and the preprocessing steps, see the previous blog post.
Model Training (Building the Model)
The code below targets the TensorFlow 1.x API.
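If you only have TensorFlow 2 installed, the 1.x-style graph code in this post can still run through the compatibility module. A minimal shim (an addition for convenience, not part of the original code):

# Run TF1-style graph code on a TensorFlow 2 installation
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # restore placeholders, sessions, and graph mode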
import tensorflow as tf

class LrModel(object):
    def __init__(self, config, seq_length):
        self.config = config
        self.seq_length = seq_length
        self.lr()

    def lr(self):
        # Input: a batch of tf-idf vectors, one per document
        self.x = tf.placeholder(tf.float32, [None, self.seq_length])

        # Softmax (multinomial logistic) regression: y = softmax(xW + b)
        w = tf.Variable(tf.zeros([self.seq_length, self.config.num_classes]))
        b = tf.Variable(tf.zeros([self.config.num_classes]))
        y = tf.nn.softmax(tf.matmul(self.x, w) + b)
        self._pred_cls = tf.argmax(y, 1)  # predicted class index

        # One-hot ground-truth labels
        self.y_ = tf.placeholder(tf.float32, [None, self.config.num_classes])

        # Cross-entropy loss, averaged over the batch
        cross_entropy = tf.reduce_mean(-tf.reduce_sum(self.y_ * tf.log(y), axis=1))
        self.loss = cross_entropy
        self.train_step = tf.train.GradientDescentOptimizer(0.5).minimize(self.loss)

        # Accuracy: fraction of predictions that match the labels
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(self.y_, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
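One caveat: computing tf.log(y) on the softmax output can produce NaN once a predicted probability underflows to zero. A numerically safer variant (a sketch, not the original code) keeps the logits and lets TensorFlow fuse the softmax with the log; inside lr(), the loss could instead be built as:

# Numerically stable alternative: feed raw logits to the fused op
logits = tf.matmul(self.x, w) + b
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.y_, logits=logits))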
Main script
import time
from datetime import timedelta
from datahelper.data_process import DataProcess
from config.lr_config import LrConfig
from lr_model import LrModel
import tensorflow as tf
def get_time_dif(start_time):
    """Return the time elapsed since start_time."""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))
def evaluate(sess, x_, y_):
    """Evaluate average loss and accuracy on a dataset (e.g. the test set)."""
    data_len = len(x_)
    batch_eval = data_get.batch_iter(x_, y_, 128)
    total_loss = 0.0
    total_acc = 0.0
    for batch_xs, batch_ys in batch_eval:
        batch_len = len(batch_xs)
        loss, acc = sess.run([model.loss, model.accuracy],
                             feed_dict={model.x: batch_xs, model.y_: batch_ys})
        # Weight each batch by its size so the overall average is exact
        total_loss += loss * batch_len
        total_acc += acc * batch_len
    return total_loss / data_len, total_acc / data_len
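batch_iter belongs to the DataProcess class built in the previous post, so its body is not shown here. For reference, a minimal generator matching the signature this script uses might look like the following (a hypothetical sketch, not the actual implementation):

import numpy as np

def batch_iter(x, y, batch_size=64):
    """Yield (x, y) mini-batches after shuffling the data once."""
    data_len = len(x)
    indices = np.random.permutation(data_len)
    x_shuffled, y_shuffled = x[indices], np.asarray(y)[indices]
    for start in range(0, data_len, batch_size):
        end = min(start + batch_size, data_len)
        yield x_shuffled[start:end], y_shuffled[start:end]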
def get_data():
    # Load the tf-idf features and labels prepared in the previous post
    print("Loading training and validation data...")
    X_train, X_test, y_train, y_test = data_get.provide_data()
    # provide_data returns sparse matrices; densify them for feed_dict
    X_train = X_train.toarray()
    X_test = X_test.toarray()
    return X_train, X_test, y_train, y_test, len(X_train[0])
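The .toarray() calls are needed because scipy sparse matrices cannot be fed to a placeholder directly. A quick illustration (assuming provide_data wraps sklearn's TfidfVectorizer, as in the previous post):

from sklearn.feature_extraction.text import TfidfVectorizer

docs = ["a toy document", "another toy document"]
tfidf = TfidfVectorizer().fit_transform(docs)
print(type(tfidf))            # a scipy.sparse CSR matrix
print(type(tfidf.toarray()))  # numpy.ndarray, usable in feed_dict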
def train(X_train, X_test, y_train, y_test):
    # Configure the Saver used to checkpoint the best model
    saver = tf.train.Saver()
    # Train the model
    print("Training and evaluating...")
    start_time = time.time()
    total_batch = 0              # total number of batches processed
    best_acc_val = 0.0           # best validation accuracy so far
    last_improved = 0            # batch at which the last improvement happened
    require_improvement = 1000   # stop early after 1000 batches without improvement
    flag = False
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(config.num_epochs):
            batch_train = data_get.batch_iter(X_train, y_train)
            for batch_xs, batch_ys in batch_train:
                if total_batch % config.print_per_batch == 0:
                    loss_train, acc_train = sess.run([model.loss, model.accuracy],
                                                     feed_dict={model.x: X_train, model.y_: y_train})
                    loss_val, acc_val = evaluate(sess, X_test, y_test)
                    if acc_val > best_acc_val:
                        # Save the best result so far
                        best_acc_val = acc_val
                        last_improved = total_batch
                        saver.save(sess=sess, save_path=config.lr_save_path)
                        improve_str = "*"
                    else:
                        improve_str = ""
                    time_dif = get_time_dif(start_time)
                    msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%}, ' \
                          + 'Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
                    print(msg.format(total_batch, loss_train, acc_train,
                                     loss_val, acc_val, time_dif, improve_str))
                sess.run(model.train_step, feed_dict={model.x: batch_xs, model.y_: batch_ys})
                total_batch += 1
                if total_batch - last_improved > require_improvement:
                    # Validation accuracy has stalled; stop training early
                    print("No optimization for a long time, auto-stopping...")
                    flag = True
                    break
            if flag:
                break
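Once training finishes, the checkpoint written to config.lr_save_path holds the best weights. A sketch of how it could be restored for prediction in a fresh script (the graph must be rebuilt with the same seq_length; this snippet is an addition, not from the original post):

def predict(texts_tfidf):
    """Restore the best checkpoint and return predicted class indices."""
    model = LrModel(config, seq_length)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, config.lr_save_path)
        return sess.run(model._pred_cls, feed_dict={model.x: texts_tfidf})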
if __name__ == "__main__":
    config = LrConfig()
    data_get = DataProcess(config.dataset_path, config.stopwords_path, config.tfidf_model_save_path)
    X_train, X_test, y_train, y_test, seq_length = get_data()
    model = LrModel(config, seq_length)
    train(X_train, X_test, y_train, y_test)
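LrConfig comes from the previous post's config module. For readers following along, these are the fields this script reads from it (the values shown are placeholders, not the original settings):

class LrConfig(object):
    num_classes = 10                              # number of target categories
    num_epochs = 100                              # training epochs
    print_per_batch = 100                         # evaluate every N batches
    dataset_path = "data/dataset.txt"             # placeholder path
    stopwords_path = "data/stopwords.txt"         # placeholder path
    tfidf_model_save_path = "model/tfidf.pkl"     # placeholder path
    lr_save_path = "checkpoints/best_validation"  # checkpoint prefix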