天天看點

lightgbm_predict4j:LightGBM線上預測的java實作

LightGBM是微軟推出的一款開源boosting工具,現在已經成為各類機器學習競賽常用的一大利器。不過由於LightGBM是C++編寫的,並且其預測功能主要透過命令行調用來批量處理資料,因此難以直接用於線上即時預測。lightgbm_predict4j是將其預測代碼用Java重新實作的一個小工具:在用LightGBM離線生成模型之後,可以用lightgbm_predict4j載入模型,然後在Java應用中做線上即時預測。項目位址:https://github.com/lyg5623/lightgbm_predict4j

用法很簡單,比如生成的模型檔案為LightGBM_model.txt,以下為預測代碼示例:

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLDecoder;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Test;
import org.lightgbm.predict4j.SparseVector;
import org.lightgbm.predict4j.v2.Boosting;
import org.lightgbm.predict4j.v2.OverallConfig;
import org.lightgbm.predict4j.v2.Predictor;

/**
 * Usage example for lightgbm_predict4j: loads a model file produced by LightGBM from the
 * classpath and runs a prediction on one sparse feature vector.
 *
 * @author lyg5623
 */
public class UseageTest {
    /** Classpath-relative location of the LightGBM model file; replace with your own model. */
    private static final String MODEL_PATH = "LightGBM_model.txt";

    @Test
    public void test() throws FileNotFoundException, IOException {
        String path = resolveResourcePath(MODEL_PATH);

        Boosting boosting = Boosting.createBoosting(path);

        // Predict config, analogous to predict.conf in LightGBM. The map is empty here, so
        // OverallConfig presumably keeps its defaults; add entries to override them.
        Map<String, String> map = new HashMap<String, String>();
        OverallConfig config = new OverallConfig();
        config.set(map);
        Predictor predictor =
                new Predictor(boosting, config.io_config.num_iteration_predict, config.io_config.is_predict_raw_score,
                        config.io_config.is_predict_leaf_index, config.io_config.pred_early_stop,
                        config.io_config.pred_early_stop_freq, config.io_config.pred_early_stop_margin);

        // Your data to predict. Presumably values[i] is the value of feature indices[i]
        // (sparse representation); confirm against SparseVector.
        int[] indices = {2, 6, 9};
        double[] values = {0.2, 0.4, 0.7};

        SparseVector v = new SparseVector(values, indices);
        List<Double> predicts = predictor.predict(v);
        System.out.println("predict values " + predicts.toString());
    }

    /**
     * Resolves a classpath resource to a filesystem path.
     *
     * <p>Uses {@link URL#toURI()} instead of URL-decoding {@link URL#getPath()}. URLDecoder applies
     * form-encoding rules, so it would turn a literal '+' in a directory name into a space. It would
     * also leave a leading '/' before Windows drive letters.
     *
     * @param resource classpath-relative resource name
     * @return filesystem path of the resource
     * @throws FileNotFoundException if the resource is not on the classpath, or is not a plain
     *     file (for example, packed inside a jar)
     * @throws IOException if the resource URL cannot be converted to a URI
     */
    private static String resolveResourcePath(String resource) throws IOException {
        URL url = UseageTest.class.getClassLoader().getResource(resource);
        if (url == null) {
            throw new FileNotFoundException("resource not found on classpath: " + resource);
        }
        if (!"file".equals(url.getProtocol())) {
            // e.g. a "jar:" URL: its path is not something a file-path based loader can open.
            throw new FileNotFoundException("resource is not a plain file (" + url + "): " + resource);
        }
        try {
            return new File(url.toURI()).getPath();
        } catch (URISyntaxException e) {
            throw new IOException("invalid resource URL: " + url, e);
        }
    }
}