天天看點

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

生存回歸(加速失效時間模型)

算法介紹:

        在spark.ml中,我們實施加速失效時間模型(Acceleratedfailure time),對于截尾資料它是一個參數化生存回歸的模型。它描述了一個有對數生存時間的模型,是以它也常被稱為生存分析的對數線性模型。與比例危險模型不同,因AFT模型中每個執行個體對目标函數的貢獻是獨立的,其更容易并行化。

         給定協變量的值

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

 ,對于

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

 可能的右截尾的随機生存時間

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

 ,AFT模型下的似然函數如下:

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

其中

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

 是訓示器表明事件i發生了,即有無檢測到。使

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

 ,則對數似然函數為以下形式:

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

其中

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

 是基線生存函數,

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

 是對應的密度函數。

最常用的AFT模型基于韋伯分布的生存時間,生存時間的韋伯分布對應于生存時間對數的極值分布,

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

 函數以及

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

 函數如下:

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)
生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

生存時間服從韋伯分布的AFT模型的對數似然函數如下:

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

由于最小化對數似然函數的負數等于最大化後驗機率,是以我們要優化的損失函數為

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

 ,分别對

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

 以及

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

 求導:

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)
生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

可以證明AFT模型是一個凸優化問題,即是說找到凸函數

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

的最小值取決于系數向量

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

以及尺度參數的對數

生存回歸(加速失效時間模型)算法原理及Spark MLlib調用執行個體(Scala/Java/python)

。在工具中實施的優化算法為L-BFGS。

*當使用無攔截的連續非零列訓練AFTSurvivalRegressionModel時,Spark MLlib為連續非零列輸出零系數。這種處理與R中的生存函數survreg不同。

參數:

censorCol:

類型:字元串型。

含義:檢查器列名。

featuresCol:

類型:字元串型。

含義:特征列名。

fitIntercept:

類型:布爾型。

含義:是否訓練攔截對象。

labelCol:

類型:字元串型。

含義:标簽列名。

maxIter:

類型:整數型。

含義:疊代次數(>=0)。

quantileProbabilities:

類型:雙精度數組型。

含義:分位數機率數組。

quantilesCol:

類型:字元串型。

含義:分位數列名。

stepSize:

類型:雙精度型。

含義:每次疊代優化步長。

tol:

類型:雙精度型。

含義:疊代算法的收斂性。

示例:

Scala:

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.AFTSurvivalRegression

val training = spark.createDataFrame(Seq(
  (1.218, 1.0, Vectors.dense(1.560, -0.605)),
  (2.949, 0.0, Vectors.dense(0.346, 2.158)),
  (3.627, 0.0, Vectors.dense(1.380, 0.231)),
  (0.273, 1.0, Vectors.dense(0.520, 1.151)),
  (4.199, 0.0, Vectors.dense(0.795, -0.226))
)).toDF("label", "censor", "features")
val quantileProbabilities = Array(0.3, 0.6)
val aft = new AFTSurvivalRegression()
  .setQuantileProbabilities(quantileProbabilities)
  .setQuantilesCol("quantiles")

val model = aft.fit(training)

// Print the coefficients, intercept and scale parameter for AFT survival regression
println(s"Coefficients: ${model.coefficients} Intercept: " +
  s"${model.intercept} Scale: ${model.scale}")
model.transform(training).show(false)
           

Java:

import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.regression.AFTSurvivalRegression;
import org.apache.spark.ml.regression.AFTSurvivalRegressionModel;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

List<Row> data = Arrays.asList(
  RowFactory.create(1.218, 1.0, Vectors.dense(1.560, -0.605)),
  RowFactory.create(2.949, 0.0, Vectors.dense(0.346, 2.158)),
  RowFactory.create(3.627, 0.0, Vectors.dense(1.380, 0.231)),
  RowFactory.create(0.273, 1.0, Vectors.dense(0.520, 1.151)),
  RowFactory.create(4.199, 0.0, Vectors.dense(0.795, -0.226))
);
StructType schema = new StructType(new StructField[]{
  new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
  new StructField("censor", DataTypes.DoubleType, false, Metadata.empty()),
  new StructField("features", new VectorUDT(), false, Metadata.empty())
});
Dataset<Row> training = spark.createDataFrame(data, schema);
double[] quantileProbabilities = new double[]{0.3, 0.6};
AFTSurvivalRegression aft = new AFTSurvivalRegression()
  .setQuantileProbabilities(quantileProbabilities)
  .setQuantilesCol("quantiles");

AFTSurvivalRegressionModel model = aft.fit(training);

// Print the coefficients, intercept and scale parameter for AFT survival regression
System.out.println("Coefficients: " + model.coefficients() + " Intercept: "
  + model.intercept() + " Scale: " + model.scale());
model.transform(training).show(false);
           

Python:

from pyspark.ml.regression import AFTSurvivalRegression
from pyspark.ml.linalg import Vectors

training = spark.createDataFrame([
    (1.218, 1.0, Vectors.dense(1.560, -0.605)),
    (2.949, 0.0, Vectors.dense(0.346, 2.158)),
    (3.627, 0.0, Vectors.dense(1.380, 0.231)),
    (0.273, 1.0, Vectors.dense(0.520, 1.151)),
    (4.199, 0.0, Vectors.dense(0.795, -0.226))], ["label", "censor", "features"])
quantileProbabilities = [0.3, 0.6]
aft = AFTSurvivalRegression(quantileProbabilities=quantileProbabilities,
                            quantilesCol="quantiles")

model = aft.fit(training)

# Print the coefficients, intercept and scale parameter for AFT survival regression
print("Coefficients: " + str(model.coefficients))
print("Intercept: " + str(model.intercept))
print("Scale: " + str(model.scale))
model.transform(training).show(truncate=False)
           

繼續閱讀