
BP Neural Network and Stacking

from keras.models import Sequential  # model built as a linear stack of layers
from keras.layers import Dense  # Dense: fully connected layer
import numpy as np 
import matplotlib.pyplot as plt 
%matplotlib inline
import tensorflow as tf
           
Using TensorFlow backend.
           
import pandas as pd 
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import KFold
           
train = pd.read_csv('train_1.csv',index_col = '任务id')
test = pd.read_csv('test_1.csv',index_col = '任务id')
           

Neural Network

train
           
[Output condensed: head and tail rows of `train`, indexed by 任务id. The 43 columns comprise the already min-max-scaled features 总里程, 车辆吨位, 车辆长度, diff_发车_靠车, diff_到达_发车, diff_卸货_到达, time_cost, is_weekend, is_fest, 业务类型_速运, ..., 运输等级_二级运输, plus the targets and metadata 线路总成本, 线路价格(不含税), 线路指导价(不含税), 调价审核完成时间, 调价类型, 调价ECP创建时间, 交易开始时间, 交易成功时间, B端议价反馈数量.]

1703 rows × 43 columns

X = train[train.columns[:34]].values
y = train[['线路总成本', '线路价格(不含税)']].values  # two targets: line cost and untaxed price
scaler = MinMaxScaler()
y = scaler.fit_transform(y)
X_train,X_test,y_train,y_test= train_test_split(X,y,test_size=0.1,random_state=2020)
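The targets are min-max scaled before training so the MSE loss treats cost and price on comparable ranges; predictions are later mapped back with `inverse_transform`. A minimal round-trip sketch (toy values, not rows from the dataset):

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Hypothetical 2-column target matrix (cost, price)
y_demo = np.array([[281.82, 238.0],
                   [672.47, 1010.0],
                   [12224.22, 15000.0]])

sc = MinMaxScaler()
y_scaled = sc.fit_transform(y_demo)      # each column mapped into [0, 1]
y_back = sc.inverse_transform(y_scaled)  # recover the original scale

assert np.allclose(y_back, y_demo)
```

The same fitted scaler must be reused for `inverse_transform`; fitting a fresh one on predictions would recover the wrong scale.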
           
           
print(X_train.shape)
print(y_train.shape)
           
(1532, 34)
(1532, 2)
           
def build_model():
    model = Sequential([
        Dense(64, activation=tf.nn.relu, input_shape=(34,)),
        Dense(64, activation=tf.nn.relu),
        Dense(32, activation=tf.nn.relu),
        Dense(2)  # two outputs: line cost and untaxed price
    ])

    # tf.train.RMSPropOptimizer is the TF1 API; under TF2 / tf.keras
    # use tf.keras.optimizers.RMSprop(0.001) instead
    optimizer = tf.train.RMSPropOptimizer(0.001)
    model.compile(loss='mse',
                  optimizer=optimizer,
                  metrics=['mae'])
    return model
model = build_model()

model.fit(X_train, y_train, epochs=250 ) 

pre = model.predict(X_test)
y_pre = scaler.inverse_transform(pre)
y_true = scaler.inverse_transform(y_test)
def metric(y_pre, y_true):
    # mean absolute percentage error, as a fraction (not a percent)
    return (abs(y_pre - y_true) / y_true).mean()
metric(y_pre[:,1], y_true[:,1]), metric(y_pre[:,0], y_true[:,0])
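The evaluation used throughout is mean absolute percentage error. On plain numbers it behaves as follows:

```python
import numpy as np

def metric(y_pre, y_true):
    # mean absolute percentage error, as a fraction
    return (np.abs(y_pre - y_true) / y_true).mean()

y_true = np.array([100.0, 200.0, 400.0])
y_pre = np.array([110.0, 180.0, 400.0])
# per-sample errors: 0.10, 0.10, 0.0 -> mean = 0.2 / 3
print(metric(y_pre, y_true))
```

Note the division by `y_true` means the metric is undefined at zero targets and penalizes under- and over-prediction asymmetrically in relative terms.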
           
y_pre
           
array([[ 1305.0636  ,  1507.7888  ],
       [ 1119.41    ,  1254.1056  ],
       [ 1259.6602  ,  1849.2323  ],
       [  390.25055 ,   544.7551  ],
       [  988.70953 ,  1039.439   ],
       [11550.204   , 14306.317   ],
       [  295.04327 ,   219.45985 ],
       [ 1016.6951  ,  1023.22375 ],
       [ 1171.6947  ,  1749.6289  ],
       [  406.2312  ,   562.3748  ],
       [  818.64935 ,   935.31177 ],
       [ 1114.8047  ,  1662.5568  ],
       [10845.84    , 12200.3545  ],
       [  205.70506 ,    46.931435],
       [ 1132.2059  ,  1426.7567  ],
       [ 1335.3221  ,  1701.8123  ],
       [10720.519   , 12028.937   ],
       [11124.664   , 13747.793   ],
       [10722.539   , 12024.162   ],
       [  661.68823 ,   829.11066 ],
       [ 1322.5798  ,   964.4997  ],
       [ 1282.2351  ,  1453.9495  ],
       [  973.0308  ,  1761.358   ],
       [ 1260.7831  ,  1856.5709  ],
       [  961.248   ,   954.48816 ],
       [ 1199.2781  ,  1275.2903  ],
       [ 1132.9553  ,  1721.3738  ],
       [10077.604   , 15202.591   ],
       [ 1263.4518  ,  1228.7592  ],
       [11841.785   , 13768.114   ],
       [ 1348.6885  ,  1729.8435  ],
       [  287.81235 ,   675.29407 ],
       [  955.52765 ,   858.11395 ],
       [11585.42    , 14450.994   ],
       [ 1304.4938  ,  1284.9801  ],
       [  692.91644 ,   900.112   ],
       [  754.9513  ,   925.387   ],
       [  219.10901 ,   354.815   ],
       [  403.4972  ,   521.494   ],
       [  345.08737 ,   524.5736  ],
       [10315.215   , 14608.666   ],
       [  759.0576  ,   905.8869  ],
       [  934.2337  ,  1464.7003  ],
       [  391.29416 ,   544.47797 ],
       [10646.697   , 14169.346   ],
       [  810.6588  ,  1220.3186  ],
       [10760.828   , 11381.781   ],
       [  715.417   ,   855.57294 ],
       [10687.356   , 11897.006   ],
       [11726.676   , 13376.778   ],
       [  381.6232  ,   528.40497 ],
       [  306.56622 ,   197.14157 ],
       [  661.25415 ,   915.4901  ],
       [  392.4997  ,   534.3512  ],
       [  389.0867  ,   529.63855 ],
       [ 1016.6951  ,  1023.22375 ],
       [ 1321.645   ,  1815.649   ],
       [10848.822   , 12329.249   ],
       [  395.7551  ,   554.63855 ],
       [  296.96558 ,   652.3022  ],
       [  378.1468  ,   482.56137 ],
       [ 1076.8687  ,  1470.7391  ],
       [  859.632   ,   995.7925  ],
       [ 1070.0586  ,  1818.0604  ],
       [  784.28235 ,   900.36554 ],
       [ 1287.6227  ,   803.91223 ],
       [  167.56787 ,    50.185482],
       [  852.94135 ,   934.8866  ],
       [ 1113.5623  ,  1746.2261  ],
       [  771.55426 ,  1060.8016  ],
       [  688.74603 ,   802.35236 ],
       [10536.23    , 13771.425   ],
       [ 1016.6951  ,  1023.22375 ],
       [  387.1082  ,   536.6619  ],
       [  317.06055 ,   313.65598 ],
       [  195.64595 ,   136.30536 ],
       [  407.58157 ,   498.853   ],
       [  193.31119 ,   131.62497 ],
       [  704.5961  ,   916.2591  ],
       [ 1039.0177  ,  1694.2977  ],
       [10010.148   , 15782.555   ],
       [   84.45573 ,   -22.312496],
       [10720.519   , 12028.937   ],
       [  600.10913 ,   780.37946 ],
       [11735.283   , 13838.804   ],
       [  370.94952 ,   533.9112  ],
       [ 1104.217   ,  1681.901   ],
       [  410.5102  ,   323.52554 ],
       [  395.55917 ,   555.08057 ],
       [11051.969   , 11019.764   ],
       [ 1304.4296  ,  2089.2825  ],
       [10687.356   , 11897.006   ],
       [11204.106   , 14356.169   ],
       [ 1090.8058  ,  1623.682   ],
       [  779.8621  ,   912.7374  ],
       [  865.35443 ,   871.8132  ],
       [10983.814   , 11827.601   ],
       [  941.58026 ,  1039.4353  ],
       [ 1341.8043  ,  1757.6267  ],
       [  384.6042  ,   487.86346 ],
       [ 1270.9099  ,  1317.3862  ],
       [11344.095   , 15409.088   ],
       [ 1081.2527  ,  1775.4283  ],
       [ 1043.2506  ,  1458.5288  ],
       [  906.9569  ,  1032.8734  ],
       [  941.58026 ,  1039.4353  ],
       [  168.72546 ,    88.58223 ],
       [  677.1466  ,   801.54736 ],
       [  824.58514 ,   973.99835 ],
       [  341.29715 ,   759.86884 ],
       [  295.1775  ,   192.15799 ],
       [ 1070.5411  ,  1723.2421  ],
       [ 1123.2665  ,  1324.7588  ],
       [  404.6457  ,   767.98004 ],
       [ 1008.0169  ,  1668.6987  ],
       [11424.964   , 14311.51    ],
       [ 1688.7523  ,  1482.9012  ],
       [10760.828   , 11381.781   ],
       [10760.828   , 11381.781   ],
       [10845.84    , 12200.3545  ],
       [10823.565   , 11785.428   ],
       [10551.548   , 15356.034   ],
       [10533.989   ,  8814.531   ],
       [11358.319   , 13913.23    ],
       [  384.53366 ,   256.5992  ],
       [  662.0324  ,   922.28284 ],
       [  406.2312  ,   562.3748  ],
       [11466.327   , 13189.79    ],
       [  361.49445 ,   772.25684 ],
       [  674.89465 ,   827.57526 ],
       [  994.6179  ,  1005.362   ],
       [ 1326.381   ,  1720.1023  ],
       [  994.6179  ,  1005.362   ],
       [ 1132.2418  ,  1721.4011  ],
       [11402.407   , 14015.907   ],
       [ 1177.5433  ,  1367.53    ],
       [  922.4901  ,  1678.8469  ],
       [  380.63635 ,   567.5858  ],
       [  801.8367  ,   940.9808  ],
       [  681.05835 ,   845.0781  ],
       [  374.6362  ,   488.06863 ],
       [  906.9569  ,  1032.8734  ],
       [  152.37051 ,   114.99935 ],
       [  217.90707 ,   117.12175 ],
       [11158.72    , 11000.613   ],
       [  817.6034  ,   945.20374 ],
       [ 1151.7108  ,  1174.148   ],
       [ 1016.6951  ,  1023.22375 ],
       [  449.79416 ,   520.836   ],
       [  387.07602 ,   527.4897  ],
       [  916.1636  ,  1066.8954  ],
       [ 1113.9105  ,  1305.6373  ],
       [  750.6172  ,   908.88306 ],
       [11524.594   , 15545.861   ],
       [ 1234.288   ,  1659.1686  ],
       [ 1186.4081  ,  2229.7869  ],
       [  895.10815 ,   889.861   ],
       [ 1041.1367  ,  1690.8862  ],
       [ 1073.9315  ,  1716.0935  ],
       [10551.548   , 15356.034   ],
       [11544.789   , 14051.969   ],
       [  289.68622 ,   187.19872 ],
       [  749.97955 ,   893.7389  ],
       [  720.2503  ,  1015.39825 ],
       [  749.97955 ,   893.7389  ],
       [10085.853   , 14431.91    ],
       [  273.8957  ,   173.83041 ],
       [ 1197.1199  ,  1716.9813  ],
       [ 1199.0336  ,  1806.2025  ],
       [  516.913   ,   447.1036  ],
       [10797.725   , 11302.929   ]], dtype=float32)
           
y_true
           
array([[ 1517.47,   999.  ],
       [ 1165.07,  1712.  ],
       [ 1073.76,  1810.  ],
       [  403.28,   594.  ],
       [ 1018.61,  1019.  ],
       [12154.16, 15000.  ],
       [  272.71,   232.  ],
       [ 1022.51,  1023.  ],
       [ 1213.12,  1733.  ],
       [  462.33,   603.  ],
       [  934.83,  1036.  ],
       [ 1051.46,   999.  ],
       [11334.36, 11553.  ],
       [  204.54,   165.  ],
       [ 1133.76,  1055.  ],
       [ 1398.31,  1635.  ],
       [12187.62, 15500.  ],
       [11234.96, 11819.  ],
       [10811.97, 13800.  ],
       [  669.11,   917.  ],
       [ 1407.54,   986.  ],
       [ 1351.53,  1635.  ],
       [  953.72,  1927.  ],
       [ 1230.71,  1682.  ],
       [ 1070.39,  1036.  ],
       [ 1347.54,  1651.  ],
       [ 1076.4 ,  1827.  ],
       [10988.58, 15000.  ],
       [ 1470.67,   999.  ],
       [12720.98, 14342.  ],
       [ 1407.54,  1575.  ],
       [  403.21,   903.  ],
       [  955.53,   917.  ],
       [11110.67, 15000.  ],
       [ 1398.31,  1635.  ],
       [  849.74,   998.  ],
       [  838.62,   998.  ],
       [  208.62,   277.  ],
       [  449.63,   590.  ],
       [  403.21,   553.  ],
       [10572.17, 16695.  ],
       [  836.58,   998.  ],
       [ 1042.  ,  1575.  ],
       [  403.23,   603.  ],
       [10989.08, 15000.  ],
       [ 1016.21,  1575.  ],
       [11121.36, 14205.  ],
       [  721.72,   992.  ],
       [11128.36,  8900.  ],
       [12187.62, 15000.  ],
       [  403.28,   594.  ],
       [  284.4 ,   257.  ],
       [  667.07,   998.  ],
       [  403.28,   594.  ],
       [  403.21,   603.  ],
       [ 1022.51,  1023.  ],
       [ 1308.  ,  1555.  ],
       [11123.36, 14502.  ],
       [  403.32,   594.  ],
       [  403.28,   794.  ],
       [  403.26,   553.  ],
       [ 1144.01,  1712.  ],
       [  843.26,  1100.  ],
       [  996.32,  1651.  ],
       [  784.81,   843.  ],
       [ 1351.53,   986.  ],
       [  192.29,   127.  ],
       [  843.26,   843.  ],
       [ 1122.23,  1622.  ],
       [  783.81,  1022.  ],
       [  697.13,   917.  ],
       [10922.08, 15000.  ],
       [ 1022.51,  1023.  ],
       [  403.22,   603.  ],
       [  286.26,   259.  ],
       [  204.54,   165.  ],
       [  404.68,   605.  ],
       [  204.54,   165.  ],
       [  643.89,   980.  ],
       [ 1060.3 ,  1911.  ],
       [12210.97, 17056.  ],
       [  192.29,   127.  ],
       [11121.36, 12000.  ],
       [  584.08,   574.  ],
       [12210.97, 14055.  ],
       [  404.68,   595.  ],
       [ 1076.79,  1627.  ],
       [  272.71,   257.  ],
       [  404.53,   595.  ],
       [11371.96,  8500.  ],
       [ 1390.46,  1911.  ],
       [11198.36, 11618.  ],
       [12117.56, 15000.  ],
       [ 1071.66,  1653.  ],
       [  934.83,   998.  ],
       [  955.53,   917.  ],
       [11806.17,  9433.  ],
       [  959.16,  1022.  ],
       [ 1559.31,  1701.  ],
       [  403.28,   544.  ],
       [ 1407.54,   986.  ],
       [12144.74, 16500.  ],
       [ 1243.46,  1733.  ],
       [ 1025.68,   999.  ],
       [  933.83,   935.  ],
       [  959.16,  1022.  ],
       [  192.29,   127.  ],
       [  728.94,   998.  ],
       [  843.26,   902.  ],
       [  403.21,   894.  ],
       [  284.4 ,   257.  ],
       [ 1019.6 ,  1841.  ],
       [ 1156.53,  1099.  ],
       [  403.21,   903.  ],
       [ 1020.94,  1700.  ],
       [10489.17,  8000.  ],
       [ 1051.46,   999.  ],
       [12187.62, 15500.  ],
       [12210.97, 17890.  ],
       [10437.17,  8300.  ],
       [10421.17, 16380.  ],
       [10639.17, 16695.  ],
       [11370.96, 10655.  ],
       [12144.74, 16500.  ],
       [  416.12,   916.  ],
       [  667.07,   998.  ],
       [  410.98,   596.  ],
       [12247.57, 14055.  ],
       [  403.41,   995.  ],
       [  662.65,   732.  ],
       [ 1018.61,  1019.  ],
       [ 1351.53,  2037.  ],
       [ 1018.61,  1019.  ],
       [ 1078.49,  2061.  ],
       [12210.97, 15000.  ],
       [ 1195.3 ,   999.  ],
       [ 1051.46,  1799.  ],
       [  403.21,   704.  ],
       [  783.81,  1022.  ],
       [  697.13,   917.  ],
       [  403.21,   603.  ],
       [  933.83,  1022.  ],
       [  204.54,   165.  ],
       [  204.54,   165.  ],
       [12195.08, 15000.  ],
       [  738.68,   998.  ],
       [ 1133.2 ,  1712.  ],
       [ 1022.51,  1023.  ],
       [  410.98,   602.  ],
       [  403.28,   594.  ],
       [  956.17,   967.  ],
       [ 1151.59,   999.  ],
       [  738.68,   998.  ],
       [12144.74, 16500.  ],
       [ 1379.  ,  1650.  ],
       [ 1192.1 ,  1741.  ],
       [ 1022.51,  1036.  ],
       [ 1035.69,  1887.  ],
       [ 1062.2 ,  1813.  ],
       [11334.36, 16695.  ],
       [10675.77, 15000.  ],
       [  272.71,   232.  ],
       [  759.48,  1255.  ],
       [  721.72,  1092.  ],
       [  758.48,  1022.  ],
       [10922.08, 15000.  ],
       [  281.82,   256.  ],
       [ 1195.3 ,  1645.  ],
       [ 1173.91,  1996.  ],
       [  565.19,   541.  ],
       [10482.17, 12720.  ]])
           
pre = model.predict(X_train)
y_pre = scaler.inverse_transform(pre)
bp_cost, bp_price = y_pre[:, 0], y_pre[:, 1]  # note: 2-D column slice [:, 1], not [:1]
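Because y was MinMaxScaler-scaled before training, the network predicts values in [0, 1] and they must be mapped back with `inverse_transform` before use. A minimal round-trip sketch with toy numbers (not the notebook's data):

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Toy two-column target: (cost, price), scaled to [0, 1] before training.
y = np.array([[100.0, 200.0], [300.0, 800.0], [500.0, 1400.0]])
scaler = MinMaxScaler()
y_scaled = scaler.fit_transform(y)

# inverse_transform undoes the scaling column by column.
y_back = scaler.inverse_transform(y_scaled)
cost, price = y_back[:, 0], y_back[:, 1]  # 2-D slice [:, 1], not [:1]
```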
           

lgb

X = train[train.columns[:34]].values
y = train[['线路总成本', '线路价格(不含税)']].values
X_train,X_test,y_train,y_test= train_test_split(X,y,test_size=0.1,random_state=2020)
           

Cost

import lightgbm as lgb
           
pa_cost = dict(
    boosting_type='gbdt', class_weight=None, colsample_bytree=1.0,
    importance_type='split', lambda_l1=0.1, lambda_l2=0.7,
    learning_rate=0.1, max_depth=8, min_child_samples=12,
    min_child_weight=0.001, min_split_gain=0.0, n_estimators=138,
    n_jobs=-1, num_leaves=23, objective='regression', random_state=None,
    reg_alpha=0.0, reg_lambda=0.0, silent=True, subsample=1.0,
    subsample_for_bin=200000, subsample_freq=0)
gbm_C = lgb.LGBMRegressor(**pa_cost)

gbm_C.fit(X_train,y_train[:,0])

lgb_cost= gbm_C.predict(X_test)
metric(lgb_cost,y_test[:,0])
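`metric()` is defined earlier in the notebook and is not shown in this section. As a hedged guess, a MAPE-style helper (hypothetical name `mape`, pure assumption) would be consistent with the fractional scores noted in the comments further down (e.g. 0.0311):

```python
import numpy as np

def mape(y_pred, y_true):
    """Mean absolute percentage error -- a hypothetical stand-in for metric()."""
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    return np.mean(np.abs(y_pred - y_true) / y_true)
```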
           
(array([ 1336.7367348 ,  1205.01886514,  1158.59693144,   398.50064806,
         1019.53135012, 11486.02914477,   295.09027858,  1010.90913263,
         1195.07007816,   437.21554971,   913.18290132,  1077.71872023,
        11134.53600982,   208.80590423,  1136.65839201,  1428.40775178,
        11360.30548702, 11325.41091263, 12293.95692894,   692.24956525,
         1357.53378681,  1328.39142045,   989.50925442,  1212.45287116,
          928.19056292,  1362.39615631,  1098.12249353, 11494.2569339 ,
         1305.60770558, 12709.38646157,  1411.86558407,   394.61542724,
          736.22554229, 12642.83000254,  1467.13462084,   804.72682991,
          817.79449777,   206.50378336,   446.8005589 ,   395.38125158,
        10989.39798235,   817.09824155,  1045.6527113 ,   407.74949128,
        10758.14127786,  1087.85104846, 11390.92737044,   732.82020867,
        11427.3250634 , 12162.6307659 ,   396.99754199,   279.6225437 ,
          703.76189297,   403.76293473,   401.6489462 ,  1010.90913263,
         1216.04819109, 11595.52956309,   398.50064806,   402.25982866,
          396.03458416,  1132.618276  ,   857.31244502,  1026.1917888 ,
          779.77921854,  1331.09665016,   200.96262324,   842.24650675,
         1127.69464839,   778.37332493,   685.69068321, 10993.58743916,
         1010.90913263,   402.43692652,   277.22049673,   207.63355143,
          396.45519348,   212.03201108,   651.23461765,  1117.24154057,
        10710.89004843,   160.81902844, 11360.30548702,   630.57622203,
        11551.59637353,   396.57693268,  1094.28706708,   259.74999714,
          398.50064806, 11429.31230831,  1353.90017076, 11427.3250634 ,
        12021.9641796 ,  1094.28706708,   934.57299766,   736.22554229,
        11540.40624398,   956.34601954,  1379.23038628,   405.14424315,
         1364.97811248, 12121.08335157,  1078.07225674,  1071.9300919 ,
          935.1603461 ,   956.34601954,   189.28216467,   697.30206539,
          843.65240036,   402.35171511,   280.04852842,  1044.61651585,
         1136.65839201,   401.28078448,  1073.57696216, 11592.36667515,
         1103.56632182, 11390.92737044, 11390.92737044, 11134.53600982,
        10919.65703971, 11168.64898932, 11265.43559821, 12145.64474858,
          403.50981664,   695.57329373,   437.21554971, 11616.30874192,
          435.29183432,   677.10241847,  1019.53135012,  1350.54663615,
         1019.53135012,  1132.37866764, 12006.03681906,  1213.4501359 ,
         1078.07225674,   393.80655956,   778.37332493,   689.09862171,
          395.11720145,   935.1603461 ,   223.20669827,   207.29441618,
        12201.57996762,   813.80536316,  1171.41890809,  1010.90913263,
          420.84313132,   403.76293473,  1018.31660039,  1156.32951312,
          727.12427074, 12411.08387213,  1372.39666027,  1163.3875832 ,
          990.17366628,  1139.80824078,  1133.51079256, 11168.64898932,
        10760.57014033,   279.21263808,   757.00874914,   732.27786016,
          757.00874914, 10885.64437008,   271.46495977,  1189.01746867,
         1180.48453291,   508.34055772, 10706.04133637]),
 array([ 1517.47,  1165.07,  1073.76,   403.28,  1018.61, 12154.16,
          272.71,  1022.51,  1213.12,   462.33,   934.83,  1051.46,
        11334.36,   204.54,  1133.76,  1398.31, 12187.62, 11234.96,
        10811.97,   669.11,  1407.54,  1351.53,   953.72,  1230.71,
         1070.39,  1347.54,  1076.4 , 10988.58,  1470.67, 12720.98,
         1407.54,   403.21,   955.53, 11110.67,  1398.31,   849.74,
          838.62,   208.62,   449.63,   403.21, 10572.17,   836.58,
         1042.  ,   403.23, 10989.08,  1016.21, 11121.36,   721.72,
        11128.36, 12187.62,   403.28,   284.4 ,   667.07,   403.28,
          403.21,  1022.51,  1308.  , 11123.36,   403.32,   403.28,
          403.26,  1144.01,   843.26,   996.32,   784.81,  1351.53,
          192.29,   843.26,  1122.23,   783.81,   697.13, 10922.08,
         1022.51,   403.22,   286.26,   204.54,   404.68,   204.54,
          643.89,  1060.3 , 12210.97,   192.29, 11121.36,   584.08,
        12210.97,   404.68,  1076.79,   272.71,   404.53, 11371.96,
         1390.46, 11198.36, 12117.56,  1071.66,   934.83,   955.53,
        11806.17,   959.16,  1559.31,   403.28,  1407.54, 12144.74,
         1243.46,  1025.68,   933.83,   959.16,   192.29,   728.94,
          843.26,   403.21,   284.4 ,  1019.6 ,  1156.53,   403.21,
         1020.94, 10489.17,  1051.46, 12187.62, 12210.97, 10437.17,
        10421.17, 10639.17, 11370.96, 12144.74,   416.12,   667.07,
          410.98, 12247.57,   403.41,   662.65,  1018.61,  1351.53,
         1018.61,  1078.49, 12210.97,  1195.3 ,  1051.46,   403.21,
          783.81,   697.13,   403.21,   933.83,   204.54,   204.54,
        12195.08,   738.68,  1133.2 ,  1022.51,   410.98,   403.28,
          956.17,  1151.59,   738.68, 12144.74,  1379.  ,  1192.1 ,
         1022.51,  1035.69,  1062.2 , 11334.36, 10675.77,   272.71,
          759.48,   721.72,   758.48, 10922.08,   281.82,  1195.3 ,
         1173.91,   565.19, 10482.17]))
           
(171,)
           

Route guide price

pa_price = dict( min_child_weight=0.0, boosting_type='gbdt', class_weight=None,
       colsample_bytree=1.0, importance_type='split', learning_rate=0.1,
       max_depth=-1, min_child_samples=10,
       min_split_gain=0.0, n_estimators=57, n_jobs=-1, num_leaves=31,
       objective='regression', random_state=None, reg_alpha=0.0,
       reg_lambda=0.0, silent=True, subsample=1.0,
       subsample_for_bin=200000, subsample_freq=0)
gbm_P = lgb.LGBMRegressor(**pa_price) 

gbm_P.fit(X_train,y_train[:,1])
lgb_price = gbm_P.predict(X_test)
metric(lgb_price,y_test[:,1])
           
(array([ 1201.70103279,  1406.64809669,  1782.63628775,   599.65940608,
         1046.38234841, 16282.55523129,   263.63530256,  1063.02504542,
         1808.695056  ,   612.6639388 ,  1013.20218562,  1552.27465217,
        13079.79043257,   194.99562723,  1103.79020382,  1446.83397412,
        13037.15062028, 11259.00892911, 14430.58653229,   953.29048158,
         1094.36799243,  1509.90938107,  1752.66111256,  1826.1026242 ,
         1009.52204282,  1601.22434475,  1742.26264587, 15306.19925105,
         1201.70103279, 14587.5950772 ,  1735.18129396,   634.86515226,
          951.74363164, 14956.90559903,  1437.11850709,  1013.88813315,
         1020.13906128,   290.40960251,   612.6639388 ,   634.86515226,
        16458.91942109,   995.86504373,  1521.50505002,   597.74244106,
        14972.93277324,  1445.57774623, 12250.63345355,  1030.07330024,
        12394.18286045, 14827.93833129,   599.65940608,   263.63530256,
         1011.24924301,   599.65940608,   599.65940608,  1063.02504542,
         1743.87444909, 14831.92669505,   599.65940608,   634.86515226,
          613.13920531,  1309.22706683,  1016.23431666,  1739.44939289,
         1015.7781514 ,  1102.22171871,   165.43177432,  1015.7781514 ,
         1810.91575861,  1014.89355494,   938.18409724, 14900.02133472,
         1063.02504542,   597.74244106,   265.43838235,   188.34706585,
          592.15814173,   188.34706585,   990.15831727,  1803.89344661,
        13634.38451893,   168.66531064, 13037.15062028,   730.27808661,
        15390.40869279,   634.86515226,  1798.90918169,   293.59074551,
          599.65940608, 11087.16940049,  1944.08660094, 12394.18286045,
        15467.73499682,  1796.19875094,  1015.91261638,   951.74363164,
        12892.07368131,  1023.23239974,  1676.40026223,   613.13920531,
         1104.77535659, 16682.62648373,  1739.44939289,  1415.49472744,
         1029.18327888,  1023.23239974,   165.43177432,   990.76150577,
         1014.89355494,   757.12311875,   263.63530256,  1747.05737697,
         1103.94177581,   680.89551886,  1739.44939289, 14174.61205782,
         1008.10783634, 12250.63345355, 12250.63345355, 13079.79043257,
        10058.70970361, 16251.24099279,  9511.06495235, 16818.3073124 ,
          672.38140205,  1004.98329366,   612.6639388 , 13788.26377648,
          711.28678539,  1034.47154707,  1024.26409367,  1840.10212138,
         1024.26409367,  1742.26264587, 15088.79553804,  1409.35852744,
         1691.60278878,   673.04942096,  1015.7781514 ,   939.06869369,
          613.13920531,  1029.18327888,   159.72709096,   188.34706585,
        14369.51048219,  1015.91261638,  1417.46847845,  1063.02504542,
          605.16267445,   599.65940608,  1030.70418486,  1214.07758821,
         1001.9629587 , 16653.21358247,  1739.53768403,  1846.71671107,
         1024.43596928,  1734.80839753,  1734.80839753, 16251.24099279,
        16472.46367379,   263.63530256,  1022.5366351 ,  1029.18870378,
         1022.5366351 , 14919.93459553,   264.03308151,  1742.26264587,
         1810.91575861,   582.13315948, 10879.87222072]),
 array([  999.,  1712.,  1810.,   594.,  1019., 15000.,   232.,  1023.,
         1733.,   603.,  1036.,   999., 11553.,   165.,  1055.,  1635.,
        15500., 11819., 13800.,   917.,   986.,  1635.,  1927.,  1682.,
         1036.,  1651.,  1827., 15000.,   999., 14342.,  1575.,   903.,
          917., 15000.,  1635.,   998.,   998.,   277.,   590.,   553.,
        16695.,   998.,  1575.,   603., 15000.,  1575., 14205.,   992.,
         8900., 15000.,   594.,   257.,   998.,   594.,   603.,  1023.,
         1555., 14502.,   594.,   794.,   553.,  1712.,  1100.,  1651.,
          843.,   986.,   127.,   843.,  1622.,  1022.,   917., 15000.,
         1023.,   603.,   259.,   165.,   605.,   165.,   980.,  1911.,
        17056.,   127., 12000.,   574., 14055.,   595.,  1627.,   257.,
          595.,  8500.,  1911., 11618., 15000.,  1653.,   998.,   917.,
         9433.,  1022.,  1701.,   544.,   986., 16500.,  1733.,   999.,
          935.,  1022.,   127.,   998.,   902.,   894.,   257.,  1841.,
         1099.,   903.,  1700.,  8000.,   999., 15500., 17890.,  8300.,
        16380., 16695., 10655., 16500.,   916.,   998.,   596., 14055.,
          995.,   732.,  1019.,  2037.,  1019.,  2061., 15000.,   999.,
         1799.,   704.,  1022.,   917.,   603.,  1022.,   165.,   165.,
        15000.,   998.,  1712.,  1023.,   602.,   594.,   967.,   999.,
          998., 16500.,  1650.,  1741.,  1036.,  1887.,  1813., 16695.,
        15000.,   232.,  1255.,  1092.,  1022., 15000.,   256.,  1645.,
         1996.,   541., 12720.]))
           
lgb_price.shape
           
(1532,)
           
y_test.shape
np.zeros((171,)) 
           
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0.])
           

stacking

Stacking validation

# Scratch cell: preview of a KFold split (assumes kf = KFold(n_splits=5) was defined earlier).
for train_index, test_index in kf.split(X_train):
    X_train_1, X_test_1 = X_train[train_index], X_train[test_index]
    y_train_1, y_test_1 = y_train[train_index], y_train[test_index]

def stacking_pre(model,X_train,y_train,X_test):
    kf =KFold(n_splits=5,shuffle=False)
    y_pre_list= []
    y_test_pre = np.zeros((X_test.shape[0],))
    for train_index,test_index in kf.split(X_train):
        X_train_1, X_test_1 = X_train[train_index], X_train[test_index]
        y_train_1, y_test_1 = y_train[train_index], y_train[test_index]
        model.fit(X_train_1,y_train_1)
        pre_y = model.predict(X_test_1)
        a = pre_y.tolist()
        y_pre_list = y_pre_list + a
        y_test_pre += model.predict(X_test)
    return np.array(y_pre_list),y_test_pre/5
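`stacking_pre` returns the out-of-fold (OOF) predictions for the training set, in index order (valid because `shuffle=False`), plus the average of the five per-fold test-set predictions. It can be exercised end to end on synthetic data; here Ridge stands in for the LightGBM base models purely for illustration:

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold

def stacking_pre(model, X_train, y_train, X_test):
    # Same scheme as in the notebook: OOF predictions become level-2
    # features; test predictions from the 5 fold-models are averaged.
    kf = KFold(n_splits=5, shuffle=False)
    oof = []
    y_test_pre = np.zeros((X_test.shape[0],))
    for tr_idx, te_idx in kf.split(X_train):
        model.fit(X_train[tr_idx], y_train[tr_idx])
        oof += model.predict(X_train[te_idx]).tolist()
        y_test_pre += model.predict(X_test)
    return np.array(oof), y_test_pre / 5

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.01, size=100)
oof, test_mean = stacking_pre(Ridge(alpha=1e-3), X, y, X[:10])
```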

def stacking_pre_bp(X_train,y_train,X_test):
    kf =KFold(n_splits=5,shuffle=False)
    c_list = []
    p_list = []
    y_test_pre = np.zeros((X_test.shape[0],2))
    for train_index, test_index in kf.split(X_train):
        X_train_1, X_test_1 = X_train[train_index], X_train[test_index]
        y_train_1, y_test_1 = y_train[train_index], y_train[test_index]
        model = build_model()
        model.fit(X_train_1,y_train_1,epochs=250)
        pre_y = model.predict(X_test_1)
        pre_y = scaler.inverse_transform(pre_y)
        a = pre_y[:,0].tolist()
        b = pre_y[:,1].tolist()
        c_list = c_list + a
        p_list = p_list + b
        pre_test_y = model.predict(X_test)
        pre_test_ye = scaler.inverse_transform(pre_test_y)
        y_test_pre += pre_test_ye
    return np.array(c_list),np.array(p_list),y_test_pre/5

X = train[train.columns[:34]].values
y = train[['线路总成本', '线路价格(不含税)']].values
scaler = MinMaxScaler()
y = scaler.fit_transform(y)
X_train,X_test,y_train,y_test= train_test_split(X,y,test_size=0.1,random_state=2020)

new_train_c_2,new_train_p_2,new_test_cp = stacking_pre_bp(X_train,y_train,X_test)

X = train[train.columns[:34]].values
y = train[['线路总成本', '线路价格(不含税)']].values
X_train,X_test,y_train,y_test= train_test_split(X,y,test_size=0.1,random_state=2020)

new_train_c_1,new_test_c_1 = stacking_pre(gbm_C,X_train,y_train[:,0],X_test)
new_train_p_1,new_test_p_1 = stacking_pre(gbm_P,X_train,y_train[:,1],X_test)

from sklearn.linear_model import LinearRegression

lr = LinearRegression()
lr.fit(np.column_stack((new_train_c_1,new_train_c_2)),y_train[:,0])

lr.predict(np.column_stack((new_test_c_1 ,new_test_cp[:,0])))

metric(lr.predict(np.column_stack((new_test_c_1 ,new_test_cp[:,0]))),y_test[:,0]) # 0.0311

lr = LinearRegression()
lr.fit(np.column_stack((new_train_p_1,new_train_p_2)),y_train[:,1])
lr.predict(np.column_stack((new_test_p_1 ,new_test_cp[:,1])))

metric(lr.predict(np.column_stack((new_test_p_1 ,new_test_cp[:,1]))),y_test[:,1]) # 0.0977
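At level 2, `LinearRegression` only has to learn how to weight the two base models' OOF predictions. A toy sketch (synthetic numbers, not the notebook's data) showing the meta-learner cancelling two opposite multiplicative biases:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical OOF predictions from two base models for the same target.
y_true = np.array([100.0, 200.0, 300.0, 400.0, 500.0])
base1 = y_true * 1.1   # base model 1: biased high
base2 = y_true * 0.9   # base model 2: biased low

# The meta-learner fits weights over the stacked base predictions.
lr = LinearRegression()
lr.fit(np.column_stack((base1, base2)), y_true)
blend = lr.predict(np.column_stack((base1, base2)))
```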
           

Stacking results

X = train[train.columns[:34]].values
y = train[['线路总成本', '线路价格(不含税)']].values
scaler = MinMaxScaler()
y = scaler.fit_transform(y)
X_train,X_test,y_train,y_test= train_test_split(X,y,test_size=0.1,random_state=2020)

new_train_c_2,new_train_p_2,new_test_cp = stacking_pre_bp(X_train,y_train,X_test=test.values)

X = train[train.columns[:34]].values
y = train[['线路总成本', '线路价格(不含税)']].values
X_train,X_test,y_train,y_test= train_test_split(X,y,test_size=0.1,random_state=2020)

new_train_c_1,new_test_c_1 = stacking_pre(gbm_C,X_train,y_train[:,0],test.values)
new_train_p_1,new_test_p_1 = stacking_pre(gbm_P,X_train,y_train[:,1],test.values)

lr = LinearRegression()
lr.fit(np.column_stack((new_train_c_1,new_train_c_2)),y_train[:,0])
pre_cost_stacking = lr.predict(np.column_stack((new_test_c_1 ,new_test_cp[:,0])))

lr = LinearRegression()
lr.fit(np.column_stack((new_train_p_1,new_train_p_2)),y_train[:,1])
pre_pre_stacking = lr.predict(np.column_stack((new_test_p_1 ,new_test_cp[:,1])))

pre_cost_stacking

pre_pre_stacking 

result_stacking = pd.DataFrame({'线路价格':pre_pre_stacking,'成本':pre_cost_stacking})
result_stacking.index = test.index

result_stacking.to_csv('result_stacking.csv',encoding='utf-8-sig')
           
Epoch 1/250
1225/1225 [==============================] - 1s 679us/step - loss: 0.1895 - mae: 0.3338
Epoch 2/250
1225/1225 [==============================] - 0s 182us/step - loss: 0.0772 - mae: 0.2119
Epoch 3/250
1225/1225 [==============================] - 0s 169us/step - loss: 0.0124 - mae: 0.0815
Epoch 4/250
1225/1225 [==============================] - 0s 229us/step - loss: 0.0046 - mae: 0.0470
Epoch 5/250
1225/1225 [==============================] - 0s 178us/step - loss: 0.0031 - mae: 0.0366
Epoch 6/250
1225/1225 [==============================] - 0s 201us/step - loss: 0.0027 - mae: 0.0322
Epoch 7/250
1225/1225 [==============================] - 0s 196us/step - loss: 0.0026 - mae: 0.0297
Epoch 8/250
1225/1225 [==============================] - 0s 196us/step - loss: 0.0022 - mae: 0.0273
Epoch 9/250
1225/1225 [==============================] - 0s 193us/step - loss: 0.0021 - mae: 0.0250
Epoch 10/250
1225/1225 [==============================] - 0s 169us/step - loss: 0.0021 - mae: 0.0254
Epoch 11/250
1225/1225 [==============================] - 0s 167us/step - loss: 0.0021 - mae: 0.0247
Epoch 12/250
1225/1225 [==============================] - 0s 204us/step - loss: 0.0020 - mae: 0.0240
Epoch 13/250
1225/1225 [==============================] - 0s 353us/step - loss: 0.0018 - mae: 0.0226
Epoch 14/250
1225/1225 [==============================] - 0s 224us/step - loss: 0.0018 - mae: 0.0215
Epoch 15/250
1225/1225 [==============================] - 0s 288us/step - loss: 0.0018 - mae: 0.0212
Epoch 16/250
1225/1225 [==============================] - 0s 282us/step - loss: 0.0018 - mae: 0.0217
Epoch 17/250
1225/1225 [==============================] - 0s 199us/step - loss: 0.0018 - mae: 0.0215
Epoch 18/250
1225/1225 [==============================] - 1s 697us/step - loss: 0.0018 - mae: 0.0210
Epoch 19/250
1225/1225 [==============================] - 0s 212us/step - loss: 0.0016 - mae: 0.0197
Epoch 20/250
1225/1225 [==============================] - 0s 209us/step - loss: 0.0016 - mae: 0.0195
Epoch 21/250
1225/1225 [==============================] - 0s 158us/step - loss: 0.0015 - mae: 0.0186
Epoch 22/250
1225/1225 [==============================] - 0s 312us/step - loss: 0.0018 - mae: 0.0205
Epoch 23/250
1225/1225 [==============================] - 0s 180us/step - loss: 0.0017 - mae: 0.0196
           

Epoch 244/250
1226/1226 [==============================] - 0s 263us/step - loss: 0.0011 - mae: 0.0131
Epoch 245/250
1226/1226 [==============================] - 0s 194us/step - loss: 0.0011 - mae: 0.0127
Epoch 246/250
1226/1226 [==============================] - 0s 194us/step - loss: 0.0012 - mae: 0.0137
Epoch 247/250
1226/1226 [==============================] - 0s 182us/step - loss: 0.0011 - mae: 0.0126
Epoch 248/250
1226/1226 [==============================] - 0s 215us/step - loss: 0.0011 - mae: 0.0129
Epoch 249/250
1226/1226 [==============================] - 0s 204us/step - loss: 0.0011 - mae: 0.0131
Epoch 250/250
1226/1226 [==============================] - 0s 221us/step - loss: 0.0011 - mae: 0.0129