天天看點

用java實作神經網絡的組成單元:感覺器

感覺器

最近看到一篇部落格,由淺入深地講解了機器學習中的神經網絡,但是由于它其中的實作代碼是python,我也想更加深入的了解一波兒,就用java也實作了一套,實作功能是一樣的。

感覺器模型結構圖:

用java實作神經網絡的組成單元:感覺器
Notice:通過設定一個值恒為一的節點,可以把誤差和權重的計算統一起來,其中值恒唯一的節點對應的權重表示誤差(bias)。

這裡的代碼是實作對于一個感覺器的神經網絡,輸入是二維向量,輸出一個數(0,1),通過簡單的學習實作對于and 和 or 運算的拟合

import java.util.Arrays;
public class Perceptron implements IPerceptron{
	//輸入節點
    private double[] inputArray;
	//權重,最後一個輸入節點(1.d)對應的是誤差
    private double[] weights;
	//輸出節點
    private double output;
	//輸入節點個數
    private int inputLength;
	//學習系數
    private double rate;

    public Perceptron(int inputLength, double rate) {
        this.inputLength = inputLength;
        this.inputArray = new double[inputLength+1];
        //設定誤差項
        this.inputArray[inputLength] = 1.d;
        this.weights = new double[inputLength+1];
        this.output = 0.d;
        this.rate = rate;
    }
	//激活函數:表示為階躍函數
    public double activotor(double x) {
        return Double.compare(x,0.d) > 0? 1.d:0.d;
    }
	//前向運算
    public void predict() {
        double sum = 0;
        for (int i = 0; i < inputArray.length; i++) {
            sum += inputArray[i]*weights[i];
        }
        this.output = activotor(sum);
    }
	//前向運算重載方法
    public double predict(double[] inputs) {
        for (int i = 0; i < inputLength; i++) {
            inputArray[i] = inputs[i];
        }
        predict();
        return this.output;
    }
	//後向運算
    public void backwards(double delta) {
        for (int i = 0; i < inputArray.length; i++) {
            weights[i] += rate*delta*inputArray[i];
        }
        System.out.println(Arrays.toString(weights));
    }
	//訓練方法
    public void train(double[][] trainData,double[] labels, int batch){
        for (int i = 0; i < batch; i++) {
            for (int j = 0; j < trainData.length; j++) {
                for (int k = 0; k < trainData[j].length; k++) {
                   inputArray[k] = trainData[j][k];
                }
                predict();
                backwards(labels[j]-output);
            }
        }
    }

    public static void main(String[] args) {
        double[][] trainData = {{0,0},{0,1},{1,0},{1,1}};
        double[] labels = {0,0,0,1};
        Perceptron perceptron = new Perceptron(2,0.1);
        perceptron.train(trainData,labels,10);
        for (int i = 0; i < trainData.length; i++) {
            double predict = perceptron.predict(trainData[i]);
            System.out.println(Arrays.toString(trainData[i]) +": "+predict);
        }
    }
}

           

運作結果:

用java實作神經網絡的組成單元:感覺器

感悟:對于神經網絡的代碼debug比較有難度,純手打的代碼也不好維護,在項目中還是用成熟的架構比較好,手打隻是為了更加深入的了解。