生存回歸(加速失效時間模型)
算法介紹:
在spark.ml中,我們實施加速失效時間模型(Acceleratedfailure time),對于截尾資料它是一個參數化生存回歸的模型。它描述了一個有對數生存時間的模型,是以它也常被稱為生存分析的對數線性模型。與比例危險模型不同,因AFT模型中每個執行個體對目标函數的貢獻是獨立的,其更容易并行化。
給定協變量的值
,對于
可能的右截尾的随機生存時間
,AFT模型下的似然函數如下:
其中
是訓示器表明事件i發生了,即有無檢測到。使
,則對數似然函數為以下形式:
其中
是基線生存函數,
是對應的密度函數。
最常用的AFT模型基于韋伯分布的生存時間,生存時間的韋伯分布對應于生存時間對數的極值分布,
函數以及
函數如下:
生存時間服從韋伯分布的AFT模型的對數似然函數如下:
由于最小化對數似然函數的負數等于最大化後驗機率,是以我們要優化的損失函數為
,分别對
以及
求導:
可以證明AFT模型是一個凸優化問題,即是說找到凸函數
的最小值取決于系數向量
以及尺度參數的對數
。在工具中實施的優化算法為L-BFGS。
*當使用無攔截的連續非零列訓練AFTSurvivalRegressionModel時,Spark MLlib為連續非零列輸出零系數。這種處理與R中的生存函數survreg不同。
參數:
censorCol:
類型:字元串型。
含義:檢查器列名。
featuresCol:
類型:字元串型。
含義:特征列名。
fitIntercept:
類型:布爾型。
含義:是否訓練攔截對象。
labelCol:
類型:字元串型。
含義:标簽列名。
maxIter:
類型:整數型。
含義:疊代次數(>=0)。
quantileProbabilities:
類型:雙精度數組型。
含義:分位數機率數組。
quantilesCol:
類型:字元串型。
含義:分位數列名。
stepSize:
類型:雙精度型。
含義:每次疊代優化步長。
tol:
類型:雙精度型。
含義:疊代算法的收斂性。
示例:
Scala:
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.AFTSurvivalRegression
val training = spark.createDataFrame(Seq(
(1.218, 1.0, Vectors.dense(1.560, -0.605)),
(2.949, 0.0, Vectors.dense(0.346, 2.158)),
(3.627, 0.0, Vectors.dense(1.380, 0.231)),
(0.273, 1.0, Vectors.dense(0.520, 1.151)),
(4.199, 0.0, Vectors.dense(0.795, -0.226))
)).toDF("label", "censor", "features")
val quantileProbabilities = Array(0.3, 0.6)
val aft = new AFTSurvivalRegression()
.setQuantileProbabilities(quantileProbabilities)
.setQuantilesCol("quantiles")
val model = aft.fit(training)
// Print the coefficients, intercept and scale parameter for AFT survival regression
println(s"Coefficients: ${model.coefficients} Intercept: " +
s"${model.intercept} Scale: ${model.scale}")
model.transform(training).show(false)
Java:
import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.regression.AFTSurvivalRegression;
import org.apache.spark.ml.regression.AFTSurvivalRegressionModel;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
List<Row> data = Arrays.asList(
RowFactory.create(1.218, 1.0, Vectors.dense(1.560, -0.605)),
RowFactory.create(2.949, 0.0, Vectors.dense(0.346, 2.158)),
RowFactory.create(3.627, 0.0, Vectors.dense(1.380, 0.231)),
RowFactory.create(0.273, 1.0, Vectors.dense(0.520, 1.151)),
RowFactory.create(4.199, 0.0, Vectors.dense(0.795, -0.226))
);
StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("censor", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("features", new VectorUDT(), false, Metadata.empty())
});
Dataset<Row> training = spark.createDataFrame(data, schema);
double[] quantileProbabilities = new double[]{0.3, 0.6};
AFTSurvivalRegression aft = new AFTSurvivalRegression()
.setQuantileProbabilities(quantileProbabilities)
.setQuantilesCol("quantiles");
AFTSurvivalRegressionModel model = aft.fit(training);
// Print the coefficients, intercept and scale parameter for AFT survival regression
System.out.println("Coefficients: " + model.coefficients() + " Intercept: "
+ model.intercept() + " Scale: " + model.scale());
model.transform(training).show(false);
Python:
from pyspark.ml.regression import AFTSurvivalRegression
from pyspark.ml.linalg import Vectors
training = spark.createDataFrame([
(1.218, 1.0, Vectors.dense(1.560, -0.605)),
(2.949, 0.0, Vectors.dense(0.346, 2.158)),
(3.627, 0.0, Vectors.dense(1.380, 0.231)),
(0.273, 1.0, Vectors.dense(0.520, 1.151)),
(4.199, 0.0, Vectors.dense(0.795, -0.226))], ["label", "censor", "features"])
quantileProbabilities = [0.3, 0.6]
aft = AFTSurvivalRegression(quantileProbabilities=quantileProbabilities,
quantilesCol="quantiles")
model = aft.fit(training)
# Print the coefficients, intercept and scale parameter for AFT survival regression
print("Coefficients: " + str(model.coefficients))
print("Intercept: " + str(model.intercept))
print("Scale: " + str(model.scale))
model.transform(training).show(truncate=False)