天天看點

python調用tensorflow.keras搭建長短記憶型網絡(LSTM)——以預測股票收盤價為例

程式調用tensorflow.keras搭建了一個簡單長短記憶型網絡(LSTM),以上證指數為例,對資料進行标準化處理,輸入5天的\'收盤價\', \'最高價\', \'最低價\',\'開盤價\',輸出1天的\'收盤價\',利用訓練集訓練網絡後,輸出測試集的MAE

目錄

  • 程式簡介
  • 程式/資料集下載下傳
  • 代碼分析

程式簡介

程式調用tensorflow.keras搭建了一個簡單長短記憶型網絡(LSTM),以上證指數為例,對資料進行标準化處理,輸入5天的\'收盤價\', \'最高價\', \'最低價\',\'開盤價\',輸出1天的\'收盤價\',利用訓練集訓練網絡後,輸出測試集的MAE

長短記憶型網絡(LSTM)

:是一種改進之後的循環神經網絡,可以解決RNN無法處理長距離的依賴的問題。

python調用tensorflow.keras搭建長短記憶型網絡(LSTM)——以預測股票收盤價為例

程式/資料集下載下傳

點選進入下載下傳位址

python調用tensorflow.keras搭建長短記憶型網絡(LSTM)——以預測股票收盤價為例

代碼分析

導入子產品、路徑

# -*- coding: utf-8 -*-
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.layers import Input,Dense,LSTM,GRU,BatchNormalization
from tensorflow.keras.layers import PReLU
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import mean_absolute_error as MAE
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import pandas as pd
import numpy as np
import os

#用來正常顯示中文标簽
plt.rcParams[\'font.sans-serif\']=[\'SimHei\'] 
#用來正常顯示負号
plt.rcParams[\'axes.unicode_minus\']=False
#路徑目錄
baseDir = \'\'#目前目錄
staticDir = os.path.join(baseDir,\'Static\')#靜态檔案目錄
resultDir = os.path.join(baseDir,\'Result\')#結果檔案目錄
           

讀取資料,檢視5行

#讀取資料
data = pd.read_csv(staticDir+\'/000001.csv\',encoding=\'gbk\').iloc[-100:,:]
data = data.set_index([\'日期\'])
data.head()
           
股票代碼 名稱 收盤價 最高價 最低價 開盤價 前收盤 漲跌額 漲跌幅 成交量 成交金額
日期
2019/9/16 \'000001 上證指數 3030.7544 3042.9284 3020.0495 3041.9220 3031.2351 -0.4807 -0.0159 221878959 2.37E+11
2019/9/17 \'000001 上證指數 2978.1178 3023.7109 2970.5704 3023.7109 3030.7544 -52.6366 -1.7367 223338061 2.38E+11
2019/9/18 \'000001 上證指數 2985.6586 2996.4022 2982.4003 2984.0837 2978.1178 7.5408 0.2532 168046699 2.00E+11
2019/9/19 \'000001 上證指數 2999.2789 2999.2789 2975.3978 2992.9222 2985.6586 13.6203 0.4562 162690615 1.93E+11
2019/9/20 \'000001 上證指數 3006.4467 3011.3400 2996.1929 3004.8142 2999.2789 7.1678 0.239 182145302 2.18E+11

對輸入輸出進行标準化,檢視5行

#标準化資料集
outputCol = [\'收盤價\']#輸出列
inputCol = [\'收盤價\', \'最高價\',\'最低價\',\'開盤價\']#輸入列
X = data[inputCol]
Y = data[outputCol]
xScaler = StandardScaler()
yScaler = StandardScaler()
X = xScaler.fit_transform(X)
Y = yScaler.fit_transform(Y)
X[:5,:]
           
array([[0.94704786, 0.91606531, 0.98497021, 1.04253169],
       [0.21175964, 0.65151178, 0.33108448, 0.80913257],
       [0.31709816, 0.2755725 , 0.48742125, 0.30125807],
       [0.50736208, 0.31517397, 0.39488046, 0.41453503],
       [0.60749011, 0.48121048, 0.66969587, 0.5669466 ]])
           

将資料按時間步進行整理,時間步這裡設定為5天,輸入為1天

#按時間步組成輸入輸出集
timeStep = 5#輸入天數
outStep = 1#輸出天數
xAll = list()
yAll = list()
#按時間步整理資料 輸入資料尺寸是(timeStep,5) 輸出尺寸是(outSize)
for row in range(data.shape[0]-timeStep-outStep+1):
    x = X[row:row+timeStep]
    y = Y[row+timeStep:row+timeStep+outStep]
    xAll.append(x)
    yAll.append(y)
xAll = np.array(xAll).reshape(-1,timeStep,len(inputCol))
yAll = np.array(yAll).reshape(-1,outStep)
print(\'輸入集尺寸\',xAll.shape)
print(\'輸出集尺寸\',yAll.shape)
           
輸入集尺寸 (95, 5, 4)
輸出集尺寸 (95, 1)
           

資料集分割為訓練集和測試集

#分成測試集,訓練集
testRate = 0.2#測試比例
splitIndex = int(xAll.shape[0]*(1-testRate))
xTrain = xAll[:splitIndex]
xTest = xAll[splitIndex:]
yTrain = yAll[:splitIndex]
yTest = yAll[splitIndex:]
           

搭建一個簡單的LSTM網絡,結構下文會列印出來

def buildLSTM(timeStep,inputColNum,outStep,learnRate=1e-4):
    \'\'\'
    搭建LSTM網絡,激活函數為tanh
    timeStep:輸入時間步
    inputColNum:輸入列數
    outStep:輸出時間步
    learnRate:學習率    
    \'\'\'
    #輸入層
    inputLayer = Input(shape=(timeStep,inputColNum))

    #中間層
    middle = LSTM(100,activation=\'tanh\')(inputLayer)
    middle = Dense(100,activation=\'tanh\')(middle)

    #輸出層 全連接配接
    outputLayer = Dense(outStep)(middle)
    
    #模組化
    model = Model(inputs=inputLayer,outputs=outputLayer)
    optimizer = Adam(lr=learnRate)
    model.compile(optimizer=optimizer,loss=\'mse\') 
    model.summary()
    return model

#搭建LSTM
lstm = buildLSTM(timeStep=timeStep,inputColNum=len(inputCol),outStep=outStep,learnRate=1e-4)
           
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 5, 4)              0         
_________________________________________________________________
lstm (LSTM)                  (None, 100)               42000     
_________________________________________________________________
dense (Dense)                (None, 100)               10100     
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 101       
=================================================================
Total params: 52,201
Trainable params: 52,201
Non-trainable params: 0
_________________________________________________________________
           

利用訓練集對網絡進行訓練

#訓練網絡
epochs = 1000#疊代次數
batchSize = 500#批處理量
lstm.fit(xTrain,yTrain,epochs=epochs,verbose=0,batch_size=batchSize) 
           

對測試集進行預測,儲存預測結果,檢視5行

#預測 測試集對比
yPredict = lstm.predict(xTest)
yPredict = yScaler.inverse_transform(yPredict)[:,0]
yTest = yScaler.inverse_transform(yTest)[:,0]
result = {\'觀測值\':yTest,\'預測值\':yPredict}
result = pd.DataFrame(result)
result.index = data.index[timeStep+xTrain.shape[0]:result.shape[0]+timeStep+xTrain.shape[0]]
result.to_excel(resultDir+\'/預測結果.xlsx\')
result.head()
           
觀測值 預測值
日期
2020/1/15 3090.0379 3119.753662
2020/1/16 3074.0814 3103.595947
2020/1/17 3075.4955 3085.278809
2020/1/20 3095.7873 3079.762451
2020/1/21 3052.1419 3094.907471

計算測試集MAE,進行可視化

mae = MAE(result[\'觀測值\'],result[\'預測值\'])
print(\'模型測試集MAE\',mae)
#可視化
fig,ax = plt.subplots(1,1)
ax.plot(result.index,result[\'預測值\'],label=\'預測值\')
ax.plot(result.index,result[\'觀測值\'],label=\'觀測值\')
ax.set_title(\'LSTM預測效果,MAE:%2f\'%mae)
ax.legend()
ax.xaxis.set_major_locator(ticker.MultipleLocator(5))
fig.savefig(resultDir+\'/預測折線圖.png\',dpi=500)
           
模型測試集MAE 37.06394592927633
           
python調用tensorflow.keras搭建長短記憶型網絡(LSTM)——以預測股票收盤價為例
python調用tensorflow.keras搭建長短記憶型網絡(LSTM)——以預測股票收盤價為例