前面的文章已經介紹,将短線個股挖掘問題轉化為深度學習處理的分類問題,并且已經完成訓練,将訓練得到的模型儲存到本地。本文将記錄如何使用Keras加載模型并進行預測的過程。
結果預測
首先,找到訓練模型儲存的目錄,加載模型:
# 加載模型
loaded_model = keras.models.load_model('./model/{}'.format(stk_code))
然後,讀入資料,将資料轉化為字典類型作為預測所使用的輸入字典,鍵為特征的索引,值為tensor。我們使用了220個特征,索引值依次為0至219。
# 讀入資料
data_file = './baostock/prediction_data_pre/{}.csv'.format(stk_code)
in_df = pd.read_csv(data_file)
# 預測用的輸入字典
temp_dict = {}
# 将資料導入輸入字典
for i in range(in_df.shape[1]):
temp_dict[i] = in_df['{}'.format(i)].tolist()
input_dict = {name: tf.convert_to_tensor(value) for name, value in temp_dict.items()}
接着,調用模型的predict方法進行預測,将預測結果儲存到清單results中。
# 進行預測
predictions = loaded_model.predict(input_dict)
results = []
for i in range(in_df.shape[0]):
results.append(predictions[i][0])
然後,我們在未來用于回測的資料後添加一列predict_result,并儲存到本地。這樣backtrader就可以通過加載本地檔案,完成基于深度學習的回測。
# 輸出到檔案
data_file = './baostock/data_ext/{}.csv'.format(stk_code)
out_df = pd.read_csv(data_file)
out_df = out_df[(out_df['date'] > '2017-12-31') & (out_df['date'] <= '2020-06-30')]
out_df['predict_result'] = results
out_df.to_csv('./baostock/predict_results/{}res.csv'.format(stk_code), index = False)
最後,還是記得在每隻股票完成預測後,清理記憶體,以防記憶體被刷爆。
# 清理記憶體
backend.clear_session()
以上就完成了加載本地模型進行預測的過程,完整代碼如下。下一篇文章将記錄如果使用預測結果,進行多股回測。
import tensorflow as tf
import numpy as np
import pandas as pd
import os
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend
stk_code_file = './stk_data/dp_stock_list.csv'
stk_list = pd.read_csv(stk_code_file)['code'].tolist()
for stk_code in stk_list:
print('processing {} ...'.format(stk_code))
# 加載模型
loaded_model = keras.models.load_model('./model/{}'.format(stk_code))
# 讀入資料
data_file = './baostock/prediction_data_pre/{}.csv'.format(stk_code)
in_df = pd.read_csv(data_file)
# 預測用的輸入字典
temp_dict = {}
# 将資料導入輸入字典
for i in range(in_df.shape[1]):
temp_dict[i] = in_df['{}'.format(i)].tolist()
input_dict = {name: tf.convert_to_tensor(value) for name, value in temp_dict.items()}
# 進行預測
predictions = loaded_model.predict(input_dict)
results = []
for i in range(in_df.shape[0]):
results.append(predictions[i][0])
# 輸出到檔案
data_file = './baostock/data_ext/{}.csv'.format(stk_code)
out_df = pd.read_csv(data_file)
out_df = out_df[(out_df['date'] > '2017-12-31') & (out_df['date'] <= '2020-06-30')]
out_df['predict_result'] = results
out_df.to_csv('./baostock/predict_results/{}res.csv'.format(stk_code), index = False)
# 清理記憶體
backend.clear_session()
歡迎大家關注、點贊、轉發、留言,感謝支援!
為了便于互相交流學習,已建微信群,感興趣的讀者請加微信。