天天看點

Spark2 Linear Regression線性回歸案例(參數調優)

 回歸正則化方法(Lasso,Ridge和ElasticNet)在高維和資料集變量之間多重共線性情況下運作良好。

數學上,ElasticNet被定義為L1和L2正則化項的凸組合:

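$$\alpha\left(\lambda\lVert w\rVert_1\right) + (1-\alpha)\left(\frac{\lambda}{2}\lVert w\rVert_2^2\right),\quad \alpha\in[0,1],\ \lambda\ge 0$$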

通過适當設定α,ElasticNet包含L1和L2正則化作為特殊情況。例如,如果用參數α設定為1來訓練線性回歸模型,則其等價于Lasso模型。另一方面,如果α被設定為0,則訓練的模型簡化為ridge回歸模型。 

RegParam:lambda>=0

ElasticNetParam:alpha in [0, 1]

導入包

import org.apache.spark.rdd.RDD

import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.DataFrameReader
import org.apache.spark.sql.DataFrameStatFunctions
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.functions._

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

導入樣本資料

// Population人口,

// Income收入水準,

// Illiteracy文盲率,

// LifeExp,

// Murder謀殺率,

// HSGrad,

// Frost結霜天數(溫度在冰點以下的平均天數) ,

// Area州面積

// Obtain (or reuse) the SparkSession that every later snippet runs against.
// The config entry is a placeholder key/value pair taken from the Spark examples.
val spark = SparkSession.builder()
  .appName("Spark Linear Regression")
  .config("spark.some.config.option", "some-value")
  .getOrCreate()

// For implicit conversions like converting RDDs to DataFrames
// (this is what makes `.toDF(...)` available on a local List of tuples).
import spark.implicits._

// Sample data: one tuple per US state, 50 rows in total.
// Column order (the names are assigned later by `toDF`):
//   Population, Income, Illiteracy, LifeExp, Murder, HSGrad, Frost, Area
// Integer literals are widened to Double by the declared element type.
// NOTE(review): row 11 (868, 4963, ...) had only seven values in the extracted
// text; the missing Frost value is restored as 0, which matches the Hawaii row
// of R's state.x77 dataset — confirm against the original article.
val dataList: List[(Double, Double, Double, Double, Double, Double, Double, Double)] = List(
  (3615, 3624, 2.1, 69.05, 15.1, 41.3, 20, 50708),
  (365, 6315, 1.5, 69.31, 11.3, 66.7, 152, 566432),
  (2212, 4530, 1.8, 70.55, 7.8, 58.1, 15, 113417),
  (2110, 3378, 1.9, 70.66, 10.1, 39.9, 65, 51945),
  (21198, 5114, 1.1, 71.71, 10.3, 62.6, 20, 156361),
  (2541, 4884, 0.7, 72.06, 6.8, 63.9, 166, 103766),
  (3100, 5348, 1.1, 72.48, 3.1, 56, 139, 4862),
  (579, 4809, 0.9, 70.06, 6.2, 54.6, 103, 1982),
  (8277, 4815, 1.3, 70.66, 10.7, 52.6, 11, 54090),
  (4931, 4091, 2, 68.54, 13.9, 40.6, 60, 58073),
  (868, 4963, 1.9, 73.6, 6.2, 61.9, 0, 6425),
  (813, 4119, 0.6, 71.87, 5.3, 59.5, 126, 82677),
  (11197, 5107, 0.9, 70.14, 10.3, 52.6, 127, 55748),
  (5313, 4458, 0.7, 70.88, 7.1, 52.9, 122, 36097),
  (2861, 4628, 0.5, 72.56, 2.3, 59, 140, 55941),
  (2280, 4669, 0.6, 72.58, 4.5, 59.9, 114, 81787),
  (3387, 3712, 1.6, 70.1, 10.6, 38.5, 95, 39650),
  (3806, 3545, 2.8, 68.76, 13.2, 42.2, 12, 44930),
  (1058, 3694, 0.7, 70.39, 2.7, 54.7, 161, 30920),
  (4122, 5299, 0.9, 70.22, 8.5, 52.3, 101, 9891),
  (5814, 4755, 1.1, 71.83, 3.3, 58.5, 103, 7826),
  (9111, 4751, 0.9, 70.63, 11.1, 52.8, 125, 56817),
  (3921, 4675, 0.6, 72.96, 2.3, 57.6, 160, 79289),
  (2341, 3098, 2.4, 68.09, 12.5, 41, 50, 47296),
  (4767, 4254, 0.8, 70.69, 9.3, 48.8, 108, 68995),
  (746, 4347, 0.6, 70.56, 5, 59.2, 155, 145587),
  (1544, 4508, 0.6, 72.6, 2.9, 59.3, 139, 76483),
  (590, 5149, 0.5, 69.03, 11.5, 65.2, 188, 109889),
  (812, 4281, 0.7, 71.23, 3.3, 57.6, 174, 9027),
  (7333, 5237, 1.1, 70.93, 5.2, 52.5, 115, 7521),
  (1144, 3601, 2.2, 70.32, 9.7, 55.2, 120, 121412),
  (18076, 4903, 1.4, 70.55, 10.9, 52.7, 82, 47831),
  (5441, 3875, 1.8, 69.21, 11.1, 38.5, 80, 48798),
  (637, 5087, 0.8, 72.78, 1.4, 50.3, 186, 69273),
  (10735, 4561, 0.8, 70.82, 7.4, 53.2, 124, 40975),
  (2715, 3983, 1.1, 71.42, 6.4, 51.6, 82, 68782),
  (2284, 4660, 0.6, 72.13, 4.2, 60, 44, 96184),
  (11860, 4449, 1, 70.43, 6.1, 50.2, 126, 44966),
  (931, 4558, 1.3, 71.9, 2.4, 46.4, 127, 1049),
  (2816, 3635, 2.3, 67.96, 11.6, 37.8, 65, 30225),
  (681, 4167, 0.5, 72.08, 1.7, 53.3, 172, 75955),
  (4173, 3821, 1.7, 70.11, 11, 41.8, 70, 41328),
  (12237, 4188, 2.2, 70.9, 12.2, 47.4, 35, 262134),
  (1203, 4022, 0.6, 72.9, 4.5, 67.3, 137, 82096),
  (472, 3907, 0.6, 71.64, 5.5, 57.1, 168, 9267),
  (4981, 4701, 1.4, 70.08, 9.5, 47.8, 85, 39780),
  (3559, 4864, 0.6, 71.72, 4.3, 63.5, 32, 66570),
  (1799, 3617, 1.4, 69.48, 6.7, 41.6, 100, 24070),
  (4589, 4468, 0.7, 72.48, 3, 54.5, 149, 54464),
  (376, 4566, 0.6, 70.29, 6.9, 62.9, 173, 97203)
)

// Turn the local list into a DataFrame, naming the eight tuple fields in order.
val data = dataList.toDF(
  "Population",
  "Income",
  "Illiteracy",
  "LifeExp",
  "Murder",
  "HSGrad",
  "Frost",
  "Area"
)

建立線性回歸模型

// Feature columns: every column except "Murder", which is used as the label below.
val colArray = Array("Population", "Income", "Illiteracy", "LifeExp", "HSGrad", "Frost", "Area")

// Pack the feature columns into a single vector column named "features".
val assembler = new VectorAssembler().setInputCols(colArray).setOutputCol("features")

val vecDF: DataFrame = assembler.transform(data)

// Build the model that predicts the murder rate (label column "Murder")
// and set the linear-regression parameters.
// NOTE(review): Spark ML setters are declared to return this.type, so lr1, lr2,
// lr3 and lr presumably all refer to the same estimator instance; the separate
// names only mark the configuration steps.
val lr1 = new LinearRegression()
val lr2 = lr1.setFeaturesCol("features").setLabelCol("Murder").setFitIntercept(true)

// Regularization: regParam is lambda (>= 0), elasticNetParam is alpha in [0, 1].
val lr3 = lr2.setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
val lr = lr3

// Fit the model
val lrModel = lr.fit(vecDF)

// Show every parameter of the fitted model (the REPL echoes the returned ParamMap).
lrModel.extractParamMap()

// Print the coefficients and intercept for linear regression
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

// Score the training data and compare the label with the rounded prediction.
val predictions = lrModel.transform(vecDF)
predictions.selectExpr("Murder", "round(prediction,1) as prediction").show

// Summarize the model over the training set and print out some metrics
val trainingSummary = lrModel.summary
println(s"numIterations: ${trainingSummary.totalIterations}")
println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")
trainingSummary.residuals.show()
println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
println(s"r2: ${trainingSummary.r2}")

代碼執行結果

// 輸出模型全部參數
lrModel.extractParamMap()
res15: org.apache.spark.ml.param.ParamMap =
{
    linReg_2ba28140e39a-elasticNetParam: 0.8,
    linReg_2ba28140e39a-featuresCol: features,
    linReg_2ba28140e39a-fitIntercept: true,
    linReg_2ba28140e39a-labelCol: Murder,
    linReg_2ba28140e39a-maxIter: 10,
    linReg_2ba28140e39a-predictionCol: prediction,
    linReg_2ba28140e39a-regParam: 0.3,
    linReg_2ba28140e39a-solver: auto,
    linReg_2ba28140e39a-standardization: true,
    linReg_2ba28140e39a-tol: 1.0E-6
}

// Print the coefficients and intercept for linear regression
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
Coefficients: [1.36662199778084E-4,0.0,1.1834384307116244,-1.4580829641757522,0.0,-0.010686434270049252,4.051355050528196E-6] Intercept: 109.589659881471

val predictions = lrModel.transform(vecDF)
predictions: org.apache.spark.sql.DataFrame = [Population: double, Income: double ... 8 more fields]

predictions.selectExpr("Murder", "round(prediction,1) as prediction").show
+------+----------+
|Murder|prediction|
+------+----------+
|  15.1|      11.9|
|  11.3|      11.0|
|   7.8|       9.5|
|  10.1|       8.6|
|  10.3|       9.6|
|   6.8|       4.3|
|   3.1|       4.2|
|   6.2|       7.5|
|  10.7|       9.3|
|  13.9|      12.3|
|   6.2|       4.7|
|   5.3|       4.6|
|  10.3|       8.8|
|   7.1|       6.6|
|   2.3|       3.5|
|   4.5|       3.9|
|  10.6|       8.9|
|  13.2|      13.2|
|   2.7|       6.3|
|   8.5|       7.8|
+------+----------+
only showing top 20 rows

// Summarize the model over the training set and print out some metrics
val trainingSummary = lrModel.summary
trainingSummary: org.apache.spark.ml.regression.LinearRegressionTrainingSummary = org.apache.spark.ml.regression.LinearRegressionTrainingSummary@68a83d76

println(s"numIterations: ${trainingSummary.totalIterations}")
numIterations: 11

println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")
objectiveHistory: List(0.49000000000000016, 0.3919242806809093, 0.19908078426904946, 0.1901453492751914, 0.17981874256031405, 0.17878173084286247, 0.1787617816935607, 0.17875431854661641, 0.17874702637141196, 0.17874512271568685, 0.1787449876896829)

trainingSummary.residuals.show()
+--------------------+
|           residuals|
+--------------------+
|  3.2200068116713023|
|  0.2745518816306607|
| -1.6535887417767414|
|   1.485762696757325|
|  0.6509766532389172|
|   2.457688146554534|
| -1.0675250558261182|
| -1.2879164685248439|
|  1.3672723619868314|
|  1.6125000289597242|
|   1.532060517905248|
|  0.6931301635074645|
|  1.5163001982000175|
| 0.46227066807431605|
| -1.2044058248740273|
|  0.6032541157521649|
|     1.7201545753635|
|-0.01942937427384...|
|  -3.632947522687547|
|  0.7077675962948007|
+--------------------+
only showing top 20 rows

println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
RMSE: 1.6663615527314546

println(s"r2: ${trainingSummary.r2}")
r2: 0.7920794990832152

模型調優,用Train-Validation Split

// Feature columns for the tuning run: every column except the label "Murder".
val colArray = Array("Population", "Income", "Illiteracy", "LifeExp", "HSGrad", "Frost", "Area")

// Assemble the feature columns into a single "features" vector column.
val vecDF: DataFrame = new VectorAssembler()
  .setInputCols(colArray)
  .setOutputCol("features")
  .transform(data)

// Split with weights 0.9 / 0.1; the fixed seed keeps the split reproducible.
val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.9, 0.1), seed = 12345)

// Build the estimator that predicts the murder rate (label column "Murder").
// It is deliberately left unfitted: the extracted text called .fit(trainingDF)
// here, which turned `lr` into an already-trained model, so the pipeline stage
// was never re-trained and every grid point would have scored the same model.
val lr = new LinearRegression().setFeaturesCol("features").setLabelCol("Murder")

// Set up the pipeline (a single stage: the regression estimator).
val pipeline = new Pipeline().setStages(Array(lr))

// Build the parameter grid: fitIntercept (both boolean values) x 3 elasticNetParam
// values x 2 maxIter values = 12 combinations.
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.fitIntercept)
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .addGrid(lr.maxIter, Array(10, 100))
  .build()

// Evaluate (prediction, true label) pairs with RMSE to compute the error.
// Note RegEvaluator.isLargerBetter: it reports whether a larger or a smaller metric
// value is better, so the tuner picks the right direction automatically.
val RegEvaluator = new RegressionEvaluator()
  .setLabelCol(lr.getLabelCol)
  .setPredictionCol(lr.getPredictionCol)
  .setMetricName("rmse")

// Configure the train/validation split over the pipeline, evaluator and grid.
val trainValidationSplit = new TrainValidationSplit()
  .setEstimator(pipeline)
  .setEvaluator(RegEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.8) // data split ratio

// Run train validation split, and choose the best set of parameters.
val tvModel = trainValidationSplit.fit(trainingDF)

// Inspect all parameters of the fitted model.
tvModel.extractParamMap()

tvModel.getEstimatorParamMaps.length
tvModel.getEstimatorParamMaps.foreach { println } // the set of parameter combinations

tvModel.getEvaluator.extractParamMap() // the evaluator's parameters
tvModel.getEvaluator.isLargerBetter // whether a larger or a smaller metric value is better
tvModel.getTrainRatio

// Predict on the held-out test set with the best parameter combination.
tvModel.transform(testDF).select("features", "Murder", "prediction").show()

調優代碼執行結果

// 檢視模型全部參數
tvModel.extractParamMap()
res45: org.apache.spark.ml.param.ParamMap =
{
    tvs_5de7d3dd1977-estimator: pipeline_062a1dffe557,
    tvs_5de7d3dd1977-estimatorParamMaps: [Lorg.apache.spark.ml.param.ParamMap;@60298de1,
    tvs_5de7d3dd1977-evaluator: regEval_05204824acb9,
    tvs_5de7d3dd1977-seed: -1772833110,
    tvs_5de7d3dd1977-trainRatio: 0.8
}

tvModel.getEstimatorParamMaps.length
res46: Int = 12

tvModel.getEstimatorParamMaps.foreach { println } // 參數組合的集合
{ linReg_75628a5554b4-elasticNetParam: 0.0, linReg_75628a5554b4-fitIntercept: true, linReg_75628a5554b4-maxIter: 10 }
{ linReg_75628a5554b4-elasticNetParam: 0.0, linReg_75628a5554b4-fitIntercept: true, linReg_75628a5554b4-maxIter: 100 }
{ linReg_75628a5554b4-elasticNetParam: 0.0, linReg_75628a5554b4-fitIntercept: false, linReg_75628a5554b4-maxIter: 10 }
{ linReg_75628a5554b4-elasticNetParam: 0.0, linReg_75628a5554b4-fitIntercept: false, linReg_75628a5554b4-maxIter: 100 }
{ linReg_75628a5554b4-elasticNetParam: 0.5, linReg_75628a5554b4-fitIntercept: true, linReg_75628a5554b4-maxIter: 10 }
{ linReg_75628a5554b4-elasticNetParam: 0.5, linReg_75628a5554b4-fitIntercept: true, linReg_75628a5554b4-maxIter: 100 }
{ linReg_75628a5554b4-elasticNetParam: 0.5, linReg_75628a5554b4-fitIntercept: false, linReg_75628a5554b4-maxIter: 10 }
{ linReg_75628a5554b4-elasticNetParam: 0.5, linReg_75628a5554b4-fitIntercept: false, linReg_75628a5554b4-maxIter: 100 }
{ linReg_75628a5554b4-elasticNetParam: 1.0, linReg_75628a5554b4-fitIntercept: true, linReg_75628a5554b4-maxIter: 10 }
{ linReg_75628a5554b4-elasticNetParam: 1.0, linReg_75628a5554b4-fitIntercept: true, linReg_75628a5554b4-maxIter: 100 }
{ linReg_75628a5554b4-elasticNetParam: 1.0, linReg_75628a5554b4-fitIntercept: false, linReg_75628a5554b4-maxIter: 10 }
{ linReg_75628a5554b4-elasticNetParam: 1.0, linReg_75628a5554b4-fitIntercept: false, linReg_75628a5554b4-maxIter: 100 }

tvModel.getEvaluator.extractParamMap() // 評估的參數
res48: org.apache.spark.ml.param.ParamMap =
{
    regEval_05204824acb9-labelCol: Murder,
    regEval_05204824acb9-metricName: rmse,
    regEval_05204824acb9-predictionCol: prediction
}

tvModel.getEvaluator.isLargerBetter // 評估的路徑成本是大的好,還是小的好
res49: Boolean = false

tvModel.getTrainRatio
res50: Double = 0.8

tvModel.transform(testDF).select("features", "Murder", "prediction").show()
+--------------------+------+------------------+
|            features|Murder|        prediction|
+--------------------+------+------------------+
|[1058.0,3694.0,0....|   2.7| 6.917232043935343|
|[2341.0,3098.0,2....|  12.5|14.760329005533478|
|[472.0,3907.0,0.6...|   5.5| 4.182074651181182|
|[812.0,4281.0,0.7...|   3.3| 4.915905572667441|
|[2816.0,3635.0,2....|  11.6|14.219231061596304|
|[4589.0,4468.0,0....|   3.0| 3.483554528704758|
+--------------------+------+------------------+

ML

繼續閱讀