天天看點

使用tf-idf進行文本分類(模型訓練)

使用tf-idf進行文本分類(模型訓練)

資料集以及預處理請參考上一篇部落格（blog）。

模型訓練(構建模型)

這裡使用 TensorFlow 1.x 版本（程式碼使用 tf.placeholder、tf.Session 等 TF1 API）。

import tensorflow as tf
class LrModel(object):
    """Softmax (multinomial logistic) regression classifier built as a TF1 graph.

    After construction the instance exposes:
        x            -- input placeholder, shape (None, seq_length)
        y_           -- label placeholder, shape (None, config.num_classes)
        loss         -- mean cross-entropy over the fed batch
        train_step   -- one gradient-descent update op
        accuracy     -- fraction of examples whose argmax matches the label's
        _pred_cls    -- predicted class index per example

    Args:
        config: configuration object; this class reads ``config.num_classes``.
        seq_length: number of input features per example.
        learning_rate: step size for plain gradient descent (default 0.5,
            the value previously hard-coded).
    """

    def __init__(self, config, seq_length, learning_rate=0.5):
        self.config = config
        self.seq_length = seq_length
        self.learning_rate = learning_rate
        self.lr()

    def lr(self):
        """Build placeholders, prediction, loss, training op and accuracy."""
        num_classes = self.config.num_classes
        self.x = tf.placeholder(tf.float32, [None, self.seq_length])
        w = tf.Variable(tf.zeros([self.seq_length, num_classes]))
        b = tf.Variable(tf.zeros([num_classes]))
        logits = tf.matmul(self.x, w) + b
        y = tf.nn.softmax(logits)
        self._pred_cls = tf.argmax(y, 1)
        # Per-example label distribution (presumably one-hot; confirm against
        # DataProcess).
        self.y_ = tf.placeholder(tf.float32, [None, num_classes])
        # Cross-entropy computed from the logits rather than as
        # -sum(y_ * log(softmax(logits))): the latter becomes NaN as soon as a
        # softmax output underflows to 0 (log(0) = -inf).
        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=self.y_, logits=logits))
        # cross_entropy is already a scalar mean; no second reduce_mean needed.
        self.loss = cross_entropy
        self.train_step = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(cross_entropy)
        correct_prediction = tf.equal(self._pred_cls, tf.argmax(self.y_, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


           

主函數

import time
from datetime import timedelta
from datahelper.data_process import DataProcess
from config.lr_config import LrConfig
from lr_model import LrModel
import tensorflow as tf

def get_time_dif(start_time):
    """Return the wall-clock time elapsed since ``start_time``.

    Args:
        start_time: a timestamp previously obtained from ``time.time()``.

    Returns:
        A ``timedelta`` rounded to the nearest whole second.
    """
    seconds_elapsed = round(time.time() - start_time)
    return timedelta(seconds=int(seconds_elapsed))
def evaluate(sess, x_, y_, batch_size=128):
    """Evaluate the global ``model`` on a dataset; return (mean loss, accuracy).

    The data is fed in mini-batches through ``data_get.batch_iter`` and the
    per-batch metrics are combined into per-example means.

    Args:
        sess: an active ``tf.Session`` holding the model's variables.
        x_: feature matrix, one row per example.
        y_: label matrix aligned row-for-row with ``x_``.
        batch_size: examples per forward pass (default 128, as before).

    Returns:
        ``(mean_loss, mean_accuracy)`` over all examples in ``x_``.

    Raises:
        ValueError: if ``x_`` is empty (previously a ZeroDivisionError).
    """
    data_len = len(x_)
    if data_len == 0:
        raise ValueError("cannot evaluate on an empty dataset")
    total_loss = 0.0
    total_acc = 0.0
    for batch_xs, batch_ys in data_get.batch_iter(x_, y_, batch_size):
        batch_len = len(batch_xs)
        loss, acc = sess.run([model.loss, model.accuracy],
                             feed_dict={model.x: batch_xs, model.y_: batch_ys})
        # model.loss and model.accuracy are batch means (reduce_mean), so
        # weight each by its batch length to get a per-example mean even when
        # the last batch is shorter.
        total_loss += loss * batch_len
        total_acc += acc * batch_len
    return total_loss / data_len, total_acc / data_len

def get_data():
    """Load the train/test split from ``data_get`` and densify the features.

    Returns:
        ``(X_train, X_test, y_train, y_test, seq_length)`` where
        ``seq_length`` is the number of feature columns in ``X_train``.
    """
    # Load the dataset.
    print("Loading training and validation data...")
    X_train, X_test, y_train, y_test = data_get.provide_data()
    # The feature matrices expose .toarray(), presumably scipy sparse TF-IDF
    # matrices; the model's placeholder is fed dense arrays.
    X_train = X_train.toarray()
    X_test = X_test.toarray()
    # shape[1] is the feature width; unlike len(X_train[0]) it does not raise
    # IndexError when the training split has no rows.
    return X_train, X_test, y_train, y_test, X_train.shape[1]

def train(X_train, X_test, y_train, y_test, require_improvement=1000):
    """Train the global ``model`` with mini-batch updates and early stopping.

    Every ``config.print_per_batch`` batches the model is evaluated on the
    training and test sets and a progress line is printed. Whenever test
    accuracy improves, the session is saved to ``config.lr_save_path``.
    Training runs for ``config.num_epochs`` epochs. It stops earlier once more
    than ``require_improvement`` batches pass without an improvement.

    Args:
        X_train, y_train: training features and labels.
        X_test, y_test: held-out features and labels used for model selection.
        require_improvement: early-stopping patience in batches (default 1000,
            the value previously hard-coded).
    """
    saver = tf.train.Saver()
    print("Training and evaluating...")
    start_time = time.time()
    total_batch = 0  # training batches run so far
    best_acc_val = 0.0  # best test-set accuracy seen so far
    last_improved = 0  # total_batch value at the last improvement
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(config.num_epochs):
            for batch_xs, batch_ys in data_get.batch_iter(X_train, y_train):
                if total_batch % config.print_per_batch == 0:
                    # Training metrics go through the batched evaluate() helper
                    # (as the test metrics already did). This avoids feeding the
                    # whole dense training matrix in a single sess.run call.
                    loss_train, acc_train = evaluate(sess, X_train, y_train)
                    loss_val, acc_val = evaluate(sess, X_test, y_test)

                    if acc_val > best_acc_val:
                        # Keep a checkpoint of the best model so far.
                        best_acc_val = acc_val
                        last_improved = total_batch
                        saver.save(sess=sess, save_path=config.lr_save_path)
                        improve_str = "*"
                    else:
                        improve_str = ""
                    time_dif = get_time_dif(start_time)
                    msg = ('Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%}, '
                           'Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}')
                    print(msg.format(total_batch, loss_train, acc_train, loss_val,
                                     acc_val, time_dif, improve_str))
                sess.run(model.train_step, feed_dict={model.x: batch_xs, model.y_: batch_ys})
                total_batch += 1

                if total_batch - last_improved > require_improvement:
                    # Test accuracy has not improved for too long: stop early.
                    # Returning from inside `with` still closes the session.
                    print("No optimization for a long time, auto-stopping...")
                    return

if __name__ == "__main__":
    # config, data_get and model are read as module globals by the functions
    # above, so they must be bound under exactly these names.
    config = LrConfig()
    data_get = DataProcess(config.dataset_path,
                           config.stopwords_path,
                           config.tfidf_model_save_path)
    loaded = get_data()  # (X_train, X_test, y_train, y_test, seq_length)
    model = LrModel(config, loaded[-1])
    train(*loaded[:-1])
           

繼續閱讀