天天看點

AI入門之手寫數字識别模型java方式Dense全連接配接神經網絡實作

作者:京東雲開發者

前言:授人以魚不如授人以漁.先學會用,在學原理,在學創造,可能一輩子用不到這種能力,但是不能不具備這種能力。這篇文章主要是介紹算法入門Helloword之手寫圖檔識别模型java中如何實作以及部分解釋。目前大家對于人工智能-機器學習-神經網絡的文章都是基于python語言的,對于擅長java的後端小夥伴想要去了解就不是特别友好,是以這裡給大家介紹一下如何在java中實作,打開新世界的大門。以下為本人個人了解如有錯誤歡迎指正

一、目标:使用MNIST資料集訓練手寫數字圖檔識别模型

在實作一個模型的時候我們要準備哪些知識體系:

1.機器學習基礎:包括監督學習、無監督學習、強化學習等基本概念。

2.資料處理與分析:資料清洗、特征工程、資料可視化等。

3.程式設計語言:如Python,用于實作機器學習算法。

4.數學基礎:線性代數、機率統計、微積分等數學知識。

5.機器學習算法:線性回歸、決策樹、神經網絡、支援向量機等算法。

6.深度學習架構:如TensorFlow、PyTorch等,用于建構和訓練深度學習模型。

7.模型評估與優化:交叉驗證、超參數調優、模型評估名額等。

8.實踐經驗:通過實際項目和競賽積累經驗,不斷提升模型學習能力。

這裡的機器學習HelloWorld是手寫圖檔識别用的是TensorFlow架構

主要需要:

1.了解手寫圖檔的資料集,訓練集是什麼樣的資料(60000,28,28) 、訓練集的标簽是什麼樣的(1)

2.了解激活函數的作用

3.正向傳遞和反向傳播的作用以及實作

4.訓練模型和儲存模型

5.加載儲存的模型使用

二、java代碼與python代碼對比分析

因為python代碼解釋網上已經有很多了,這裡不在重複解釋

1.資料集的加載

python中

def load_data(dpata_folder):
    files = ["train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
             "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"]
    paths = []
    for fname in files:
        paths.append(os.path.join(data_folder, fname))
    with gzip.open(paths[0], 'rb') as lbpath:
        train_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[1], 'rb') as imgpath:
        train_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(train_y), 28, 28)
    with gzip.open(paths[2], 'rb') as lbpath:
        test_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[3], 'rb') as imgpath:
        test_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(test_y), 28, 28)
    return (train_x, train_y), (test_x, test_y)
(train_x, train_y), (test_x, test_y) = load_data("mnistDataSet/")
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s' % (train_x.shape, train_y.shape, test_x.shape, test_y.shape))
print(train_x.ndim)  # 資料集的次元
print(train_x.shape)  # 資料集的形狀
print(len(train_x))  # 資料集的大小
print(train_x)  # 資料集
print("---檢視單個資料")
print(train_x[0])
print(len(train_x[0]))
print(len(train_x[0][1]))
print(train_x[0][6])
print("---檢視單個資料")
print(train_y[3])           
AI入門之手寫數字識别模型java方式Dense全連接配接神經網絡實作

java中

SimpleMnist.class

private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz";
    private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz";
    private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz";
    private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz";
//加載資料
MnistDataset validationDataset = MnistDataset.getOneValidationImage(3, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);           

MnistDataset.class

/**
     * @param trainingImagesArchive 訓練圖檔路徑
     * @param trainingLabelsArchive 訓練标簽路徑
     * @param testImagesArchive     測試圖檔路徑
     * @param testLabelsArchive     測試标簽路徑
     */
    public static MnistDataset getOneValidationImage(int index, String trainingImagesArchive, String trainingLabelsArchive,String testImagesArchive, String testLabelsArchive) {
        try {
            ByteNdArray trainingImages = readArchive(trainingImagesArchive);
            ByteNdArray trainingLabels = readArchive(trainingLabelsArchive);
            ByteNdArray testImages = readArchive(testImagesArchive);
            ByteNdArray testLabels = readArchive(testLabelsArchive);
            trainingImages.slice(sliceFrom(0));
            trainingLabels.slice(sliceTo(0));
            // 切片操作
            Index range = Indices.range(index, index + 1);// 切片的起始和結束索引
            ByteNdArray validationImage = trainingImages.slice(range); // 執行切片操作
            ByteNdArray validationLable = trainingLabels.slice(range); // 執行切片操作
            if (index >= 0) {
                return new MnistDataset(trainingImages,trainingLabels,validationImage,validationLable,testImages,testLabels);
            } else {
                return null;
            }
        } catch (IOException e) {
            throw new AssertionError(e);
        }
    }  
    private static ByteNdArray readArchive(String archiveName) throws IOException {
        System.out.println("archiveName = " + archiveName);
        DataInputStream archiveStream = new DataInputStream(new GZIPInputStream(MnistDataset.class.getClassLoader().getResourceAsStream(archiveName))
        );
        archiveStream.readShort(); // first two bytes are always 0
        byte magic = archiveStream.readByte();
        if (magic != TYPE_UBYTE) {
            throw new IllegalArgumentException("\"" + archiveName + "\" is not a valid archive");
        }
        int numDims = archiveStream.readByte();
        long[] dimSizes = new long[numDims];
        int size = 1;  // for simplicity, we assume that total size does not exceeds Integer.MAX_VALUE
        for (int i = 0; i < dimSizes.length; ++i) {
            dimSizes[i] = archiveStream.readInt();
            size *= dimSizes[i];
        }
        byte[] bytes = new byte[size];
        archiveStream.readFully(bytes);
        return NdArrays.wrap(Shape.of(dimSizes), DataBuffers.of(bytes, false, false));
    }
    /**
     * Mnist 資料集構造器
     */
    private MnistDataset(ByteNdArray trainingImages, ByteNdArray trainingLabels,ByteNdArray validationImages,ByteNdArray validationLabels,ByteNdArray testImages,ByteNdArray testLabels
    ) {
        this.trainingImages = trainingImages;
        this.trainingLabels = trainingLabels;
        this.validationImages = validationImages;
        this.validationLabels = validationLabels;
        this.testImages = testImages;
        this.testLabels = testLabels;
        this.imageSize = trainingImages.get(0).shape().size();
        System.out.println(String.format("train_x:%s,train_y:%s, test_x:%s, test_y:%s", trainingImages.shape(), trainingLabels.shape(), testImages.shape(), testLabels.shape()));
        System.out.println("資料集的次元:" + trainingImages.rank());
        System.out.println("資料集的形狀 = " + trainingImages.shape());
        System.out.println("資料集的大小 = " + trainingImages.shape().get(0));
        System.out.println("檢視單個資料 = " + trainingImages.get(0));
    }           
AI入門之手寫數字識别模型java方式Dense全連接配接神經網絡實作

2.模型建構

python中

model = tensorflow.keras.Sequential()
model.add(tensorflow.keras.layers.Flatten(input_shape=(28, 28)))  # 添加Flatten層說明輸入資料的形狀
model.add(tensorflow.keras.layers.Dense(128, activation='relu'))  # 添加隐含層,為全連接配接層,128個節點,relu激活函數
model.add(tensorflow.keras.layers.Dense(10, activation='softmax'))  # 添加輸出層,為全連接配接層,10個節點,softmax激活函數
print("列印模型結構")
# 使用 summary 列印模型結構
print('\n', model.summary())  # 檢視網絡結構和參數資訊
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])           

java中

SimpleMnist.class

Ops tf = Ops.create(graph);
        // Create placeholders and variables, which should fit batches of an unknown number of images
        //建立占位符和變量,這些占位符和變量應适合未知數量的圖像批次
        Placeholder<TFloat32> images = tf.placeholder(TFloat32.class);
        Placeholder<TFloat32> labels = tf.placeholder(TFloat32.class);

        // Create weights with an initial value of 0
        // 建立初始值為 0 的權重
        Shape weightShape = Shape.of(dataset.imageSize(), MnistDataset.NUM_CLASSES);
        Variable<TFloat32> weights = tf.variable(tf.zeros(tf.constant(weightShape), TFloat32.class));
        
        // Create biases with an initial value of 0
        //建立初始值為 0 的偏置
        Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES);
        Variable<TFloat32> biases = tf.variable(tf.zeros(tf.constant(biasShape), TFloat32.class));

        // Predict the class of each image in the batch and compute the loss
        //使用 TensorFlow 的 tf.linalg.matMul 函數計算圖像矩陣 images 和權重矩陣 weights 的矩陣乘法,并加上偏置項 biases。
        //wx+b
        MatMul<TFloat32> matMul = tf.linalg.matMul(images, weights);
        Add<TFloat32> add = tf.math.add(matMul, biases);
        //Softmax 是一個常用的激活函數,它将輸入轉換為表示機率分布的輸出。對于輸入向量中的每個元素,Softmax 函數會計算指數,
        //并對所有元素求和,然後将每個元素的指數除以總和,最終得到一個機率分布。這通常用于多分類問題,以輸出每個類别的機率
        Softmax<TFloat32> softmax = tf.nn.softmax(add);

        // 建立一個計算交叉熵的Mean對象
        Mean<TFloat32> crossEntropy =
                tf.math.mean(  // 計算張量的平均值
                        tf.math.neg(  // 計算張量的負值
                                tf.reduceSum(  // 計算張量的和
                                        tf.math.mul(labels, tf.math.log(softmax)),  //計算标簽和softmax預測的對數乘積
                                        tf.array(1)  // 在指定軸上求和
                                )
                        ),
                        tf.array(0)  // 在指定軸上求平均值
                );

        // Back-propagate gradients to variables for training
        //使用梯度下降優化器來最小化交叉熵損失函數。首先,建立了一個梯度下降優化器 optimizer,然後使用該優化器來最小化交叉熵損失函數 crossEntropy。
        Optimizer optimizer = new GradientDescent(graph, LEARNING_RATE);
        Op minimize = optimizer.minimize(crossEntropy);           

3.訓練模型

python中

history = model.fit(train_x, train_y, batch_size=64, epochs=5, validation_split=0.2)           

java中

SimpleMnist.class

// Train the model
            for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
                try (TFloat32 batchImages = preprocessImages(trainingBatch.images());
                     TFloat32 batchLabels = preprocessLabels(trainingBatch.labels())) {
                    // 建立會話運作器
                    session.runner()
                            // 添加要最小化的目标
                            .addTarget(minimize)
                            // 通過feed方法将圖像資料輸入到模型中
                            .feed(images.asOutput(), batchImages)
                            // 通過feed方法将标簽資料輸入到模型中
                            .feed(labels.asOutput(), batchLabels)
                            // 運作會話
                            .run();
                }
            }           

4.模型評估

python中

test_loss, test_acc = model.evaluate(test_x, test_y)
model.evaluate(test_x, test_y, verbose=2)  # 每次疊代輸出一條記錄,來評價該模型是否有比較好的泛化能力
print('Test 損失: %.3f' % test_loss)
print('Test 精确度: %.3f' % test_acc)           

java中

SimpleMnist.class

// Test the model
            ImageBatch testBatch = dataset.testBatch();
            try (TFloat32 testImages = preprocessImages(testBatch.images());
                 TFloat32 testLabels = preprocessLabels(testBatch.labels());
                 // 定義一個TFloat32類型的變量accuracyValue,用于存儲計算得到的準确率值
                 TFloat32 accuracyValue = (TFloat32) session.runner()
                         // 從會話中擷取準确率值
                         .fetch(accuracy)
                         .fetch(predicted)
                         .fetch(expected)
                         // 将images作為輸入,testImages作為資料進行喂養
                         .feed(images.asOutput(), testImages)
                         // 将labels作為輸入,testLabels作為資料進行喂養
                         .feed(labels.asOutput(), testLabels)
                         // 運作會話并擷取結果
                         .run()
                         // 擷取第一個結果并存儲在accuracyValue中
                         .get(0)) {
                System.out.println("Accuracy: " + accuracyValue.getFloat());
            }           

5.儲存模型

python中

# 使用save_model儲存完整模型
# save_model(model, '/media/cfs/使用者ERP名稱/ea/saved_model', save_format='pb')
save_model(model, 'D:\\pythonProject\\mnistDemo\\number_model', save_format='pb')           

java中

SimpleMnist.class

// 儲存模型
            SavedModelBundle.Exporter exporter = SavedModelBundle.exporter("D:\\ai\\ai-demo").withSession(session);
            Signature.Builder builder = Signature.builder();
            builder.input("images", images);
            builder.input("labels", labels);
            builder.output("accuracy", accuracy);
            builder.output("expected", expected);
            builder.output("predicted", predicted);
            Signature signature = builder.build();
            SessionFunction sessionFunction = SessionFunction.create(signature, session);
            exporter.withFunction(sessionFunction);
            exporter.export();           

6.加載模型

python中

# 加載.pb模型檔案
    global load_model
    load_model = load_model('D:\\pythonProject\\mnistDemo\\number_model')
    load_model.summary()
    demo = tensorflow.reshape(test_x, (1, 28, 28))
    input_data = np.array(demo)  # 準備你的輸入資料
    input_data = tensorflow.convert_to_tensor(input_data, dtype=tensorflow.float32)
    predictValue = load_model.predict(input_data)
    print("predictValue")
    print(predictValue)
    y_pred = np.argmax(predictValue)
    print('标簽值:' + str(test_y) + '\n預測值:' + str(y_pred))
    return y_pred, test_y,           

java中

SimpleMnist.class

//加載模型并預測
    public void loadModel(String exportDir) {
        // load saved model
        SavedModelBundle model = SavedModelBundle.load(exportDir, "serve");
        try {
            printSignature(model);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        ByteNdArray validationImages = dataset.getValidationImages();
        ByteNdArray validationLabels = dataset.getValidationLabels();
        TFloat32 testImages = preprocessImages(validationImages);
        System.out.println("testImages = " + testImages.shape());
        TFloat32 testLabels = preprocessLabels(validationLabels);
        System.out.println("testLabels = " + testLabels.shape());
        Result run = model.session().runner()
                .feed("Placeholder:0", testImages)
                .feed("Placeholder_1:0", testLabels)
                .fetch("ArgMax:0")
                .fetch("ArgMax_1:0")
                .fetch("Mean_1:0")
                .run();
        // 處理輸出
        Optional<Tensor> tensor1 = run.get("ArgMax:0");
        Optional<Tensor> tensor2 = run.get("ArgMax_1:0");
        Optional<Tensor> tensor3 = run.get("Mean_1:0");
        TInt64 predicted = (TInt64) tensor1.get();
        Long predictedValue = predicted.getObject(0);
        System.out.println("predictedValue = " + predictedValue);
        TInt64 expected = (TInt64) tensor2.get();
        Long expectedValue = expected.getObject(0);
        System.out.println("expectedValue = " + expectedValue);
        TFloat32 accuracy = (TFloat32) tensor3.get();
        System.out.println("accuracy = " + accuracy.getFloat());
    }
    //列印模型資訊
    private static void printSignature(SavedModelBundle model) throws Exception {
        MetaGraphDef m = model.metaGraphDef();
        SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
        int numInputs = sig.getInputsCount();
        int i = 1;
        System.out.println("MODEL SIGNATURE");
        System.out.println("Inputs:");
        for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
            TensorInfo t = entry.getValue();
            System.out.printf(
                    "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                    i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
        }
        int numOutputs = sig.getOutputsCount();
        i = 1;
        System.out.println("Outputs:");
        for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
            TensorInfo t = entry.getValue();
            System.out.printf(
                    "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                    i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
        }
    }           

三、完整的python代碼

本工程使用環境為

Python: 3.7.9

https://www.python.org/downloads/windows/

Anaconda: Python 3.11 Anaconda3-2023.09-0-Windows-x86_64

https://www.anaconda.com/download#downloads

tensorflow:2.0.0

直接從anaconda下安裝

mnistTrainDemo.py

import gzip
import os.path
import tensorflow as tensorflow
from tensorflow import keras
# 可視化 image
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.models import save_model

# 加載資料
# mnist = keras.datasets.mnist
# mnistData = mnist.load_data() #Exception: URL fetch failure on https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz: None -- unknown url type: https
"""
這裡可以直接使用
mnist = keras.datasets.mnist
mnistData = mnist.load_data() 加載資料,但是有的時候不成功,是以使用本地加載資料
"""
def load_data(data_folder):
    files = ["train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
             "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"]
    paths = []
    for fname in files:
        paths.append(os.path.join(data_folder, fname))

    with gzip.open(paths[0], 'rb') as lbpath:
        train_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)

    with gzip.open(paths[1], 'rb') as imgpath:
        train_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(train_y), 28, 28)

    with gzip.open(paths[2], 'rb') as lbpath:
        test_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)

    with gzip.open(paths[3], 'rb') as imgpath:
        test_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(test_y), 28, 28)

    return (train_x, train_y), (test_x, test_y)

(train_x, train_y), (test_x, test_y) = load_data("mnistDataSet/")
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s' % (train_x.shape, train_y.shape, test_x.shape, test_y.shape))
print(train_x.ndim)  # 資料集的次元
print(train_x.shape)  # 資料集的形狀
print(len(train_x))  # 資料集的大小
print(train_x)  # 資料集
print("---檢視單個資料")
print(train_x[0])
print(len(train_x[0]))
print(len(train_x[0][1]))
print(train_x[0][6])
# 可視化image圖檔、一副image的資料
# plt.imshow(train_x[0].reshape(28, 28), cmap="binary")
# plt.show()
print("---檢視單個資料")
print(train_y[0])

# 資料預處理
# 歸一化、并轉換為tensor張量,資料類型為float32.  ---歸一化也可能造成識别率低
# train_x, test_x = tensorflow.cast(train_x / 255.0, tensorflow.float32), tensorflow.cast(test_x / 255.0,
#                                                                                         tensorflow.float32),
# train_y, test_y = tensorflow.cast(train_y, tensorflow.int16), tensorflow.cast(test_y, tensorflow.int16)
# print("---檢視單個資料歸一後的資料")
# print(train_x[0][6])  # 30/255=0.11764706  ---歸一化每個值除以255
# print(train_y[0])

# Step2: 配置網絡 建立模型
'''
以下的代碼判斷就是定義一個簡單的多層感覺器,一共有三層,
兩個大小為100的隐層和一個大小為10的輸出層,因為MNIST資料集是手寫0到9的灰階圖像,
類别有10個,是以最後的輸出大小是10。最後輸出層的激活函數是Softmax,
是以最後的輸出層相當于一個分類器。加上一個輸入層的話,
多層感覺器的結構是:輸入層-->>隐層-->>隐層-->>輸出層。
激活函數 https://zhuanlan.zhihu.com/p/337902763
'''
# 構造模型
# model = keras.Sequential([
#     # 在第一層的網絡中,我們的輸入形狀是28*28,這裡的形狀就是圖檔的長度和寬度。
#     keras.layers.Flatten(input_shape=(28, 28)),
#     # 是以神經網絡有點像濾波器(過濾裝置),輸入一組28*28像素的圖檔後,輸出10個類别的判斷結果。那這個128的數字是做什麼用的呢?
#     # 我們可以這樣想象,神經網絡中有128個函數,每個函數都有自己的參數。
#     # 我們給這些函數進行一個編号,f0,f1…f127 ,我們想的是當圖檔的像素一一帶入這128個函數後,這些函數的組合最終輸出一個标簽值,在這個樣例中,我們希望它輸出09 。
#     # 為了得到這個結果,計算機必須要搞清楚這128個函數的具體參數,之後才能計算各個圖檔的标簽。這裡的邏輯是,一旦計算機搞清楚了這些參數,那它就能夠認出不同的10個類别的事物了。
#     keras.layers.Dense(100, activation=tensorflow.nn.relu),
#     # 最後一層是10,是資料集中各種類别的代号,資料集總共有10類,這裡就是10 。
#     keras.layers.Dense(10, activation=tensorflow.nn.softmax)
# ])

model = tensorflow.keras.Sequential()
model.add(tensorflow.keras.layers.Flatten(input_shape=(28, 28)))  # 添加Flatten層說明輸入資料的形狀
model.add(tensorflow.keras.layers.Dense(128, activation='relu'))  # 添加隐含層,為全連接配接層,128個節點,relu激活函數
model.add(tensorflow.keras.layers.Dense(10, activation='softmax'))  # 添加輸出層,為全連接配接層,10個節點,softmax激活函數
print("列印模型結構")
# 使用 summary 列印模型結構
# print(model.summary())
print('\n', model.summary())  # 檢視網絡結構和參數資訊

'''
接着是配置模型,在這一步,我們需要指定模型訓練時所使用的優化算法與損失函數,
此外,這裡我們也可以定義計算精度相關的API。
優化器https://zhuanlan.zhihu.com/p/27449596
'''
# 配置模型  配置模型訓練方法
# 設定神經網絡的優化器和損失函數。# 使用Adam算法進行優化   # 使用CrossEntropyLoss 計算損失 # 使用Accuracy 計算精度
# model.compile(optimizer=tensorflow.optimizers.Adam(), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# adam算法參數采用keras預設的公開參數,損失函數采用稀疏交叉熵損失函數,準确率采用稀疏分類準确率函數
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])

# Step3:模型訓練
# 開始模型訓練
# model.fit(x_train,  # 設定訓練資料集
#           y_train,
#           epochs=5,  # 設定訓練輪數
#           batch_size=64,  # 設定 batch_size
#           verbose=1)  # 設定日志列印格式
# 批量訓練大小為64,疊代5次,測試集比例0.2(48000條訓練集資料,12000條測試集資料)
history = model.fit(train_x, train_y, batch_size=64, epochs=5, validation_split=0.2)

# STEP4: 模型評估
# 評估模型,不輸出預測結果輸出損失和精确度. test_loss損失,test_acc精确度
test_loss, test_acc = model.evaluate(test_x, test_y)
model.evaluate(test_x, test_y, verbose=2)  # 每次疊代輸出一條記錄,來評價該模型是否有比較好的泛化能力
# model.evaluate(test_dataset, verbose=1)
print('Test 損失: %.3f' % test_loss)
print('Test 精确度: %.3f' % test_acc)
# 結果可視化
print(history.history)
loss = history.history['loss']  # 訓練集損失
val_loss = history.history['val_loss']  # 測試集損失
acc = history.history['sparse_categorical_accuracy']  # 訓練集準确率
val_acc = history.history['val_sparse_categorical_accuracy']  # 測試集準确率

plt.figure(figsize=(10, 3))
plt.subplot(121)
plt.plot(loss, color='b', label='train')
plt.plot(val_loss, color='r', label='test')
plt.ylabel('loss')
plt.legend()

plt.subplot(122)
plt.plot(acc, color='b', label='train')
plt.plot(val_acc, color='r', label='test')
plt.ylabel('Accuracy')
plt.legend()

# 暫停5秒關閉畫布,否則畫布一直打開的同時,會持續占用GPU記憶體
# plt.ion()  # 打開互動式操作模式
# plt.show()
# plt.pause(5)
# plt.close()
# plt.show()

# Step5:模型預測 輸入測試資料,輸出預測結果
for i in range(1):
    num = np.random.randint(1, 10000)  # 在1~10000之間生成随機整數
    plt.subplot(2, 5, i + 1)
    plt.axis('off')
    plt.imshow(test_x[num], cmap='gray')
    demo = tensorflow.reshape(test_x[num], (1, 28, 28))
    y_pred = np.argmax(model.predict(demo))
    plt.title('标簽值:' + str(test_y[num]) + '\n預測值:' + str(y_pred))
# plt.show()

'''
儲存模型
訓練好的模型可以用于加載後對新輸入資料進行預測,是以需要先進行儲存已訓練模型
'''
#使用save_model儲存完整模型
save_model(model, 'D:\\pythonProject\\mnistDemo\\number_model', save_format='pb')           

mnistPredictDemo.py

import numpy as np
import tensorflow as tensorflow
import gzip
import os.path
from tensorflow.keras.models import load_model
# 預測
def predict(test_x, test_y):
    test_x, test_y = test_x, test_y
    '''
    五、模型評估
    需要先加載已訓練模型,然後用其預測新的資料,計算評估名額
    '''
    # 模型加載
    # 加載.pb模型檔案
    global load_model
    # load_model = load_model('./saved_model')
    load_model = load_model('D:\\pythonProject\\mnistDemo\\number_model')
    load_model.summary()
    # make a prediction
    print("test_x")
    print(test_x)
    print(test_x.ndim)
    print(test_x.shape)

    demo = tensorflow.reshape(test_x, (1, 28, 28))
    input_data = np.array(demo)  # 準備你的輸入資料
    input_data = tensorflow.convert_to_tensor(input_data, dtype=tensorflow.float32)
    # test_x = tensorflow.cast(test_x / 255.0, tensorflow.float32)
    # test_y = tensorflow.cast(test_y, tensorflow.int16)
    predictValue = load_model.predict(input_data)
    print("predictValue")
    print(predictValue)
    y_pred = np.argmax(predictValue)
    print('标簽值:' + str(test_y) + '\n預測值:' + str(y_pred))
    return y_pred, test_y,
  
def load_data(data_folder):
    files = ["train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
             "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"]
    paths = []
    for fname in files:
        paths.append(os.path.join(data_folder, fname))
    with gzip.open(paths[0], 'rb') as lbpath:
        train_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[1], 'rb') as imgpath:
        train_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(train_y), 28, 28)
    with gzip.open(paths[2], 'rb') as lbpath:
        test_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[3], 'rb') as imgpath:
        test_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(test_y), 28, 28)
    return (train_x, train_y), (test_x, test_y)

(train_x, train_y), (test_x, test_y) = load_data("mnistDataSet/")
print(train_x[0])
predict(train_x[0], train_y)

           

四、完整的java代碼

tensorflow 需要的java 版本對應表: https://github.com/tensorflow/java/#tensorflow-version-support

本工程使用環境為

jdk版本:openjdk-21

pom依賴如下:

<dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-platform</artifactId>
            <version>0.6.0-SNAPSHOT</version>
        </dependency>

        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-framework</artifactId>
            <version>0.6.0-SNAPSHOT</version>
        </dependency>
    </dependencies>

    <repositories>
        <repository>
            <id>tensorflow-snapshots</id>
            <url>https://oss.sonatype.org/content/repositories/snapshots/</url>
            <snapshots>
                <enabled>true</enabled>
            </snapshots>
        </repository>
    </repositories>           

資料集建立和解析類

MnistDataset.class

package org.example.tensorDemo.datasets.mnist;

import org.example.tensorDemo.datasets.ImageBatch;
import org.example.tensorDemo.datasets.ImageBatchIterator;
import org.tensorflow.ndarray.*;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.index.Index;
import org.tensorflow.ndarray.index.Indices;

import java.io.DataInputStream;
import java.io.IOException;
import java.util.zip.GZIPInputStream;

import static org.tensorflow.ndarray.index.Indices.sliceFrom;
import static org.tensorflow.ndarray.index.Indices.sliceTo;



public class MnistDataset {
    public static final int NUM_CLASSES = 10;

    private static final int TYPE_UBYTE = 0x08;

    /**
     * 訓練圖檔位元組類型的多元數組
     */
    private final ByteNdArray trainingImages;

    /**
     * 訓練标簽位元組類型的多元數組
     */
    private final ByteNdArray trainingLabels;

    /**
     * 驗證圖檔位元組類型的多元數組
     */
    public final ByteNdArray validationImages;

    /**
     * 驗證标簽位元組類型的多元數組
     */
    public final ByteNdArray validationLabels;

    /**
     * 測試圖檔位元組類型的多元數組
     */
    private final ByteNdArray testImages;

    /**
     * 測試标簽位元組類型的多元數組
     */
    private final ByteNdArray testLabels;

    /**
     * 圖檔的大小
     */
    private final long imageSize;


    /**
     * Mnist 資料集構造器
     */
    private MnistDataset(
            ByteNdArray trainingImages,
            ByteNdArray trainingLabels,
            ByteNdArray validationImages,
            ByteNdArray validationLabels,
            ByteNdArray testImages,
            ByteNdArray testLabels
    ) {
        this.trainingImages = trainingImages;
        this.trainingLabels = trainingLabels;
        this.validationImages = validationImages;
        this.validationLabels = validationLabels;
        this.testImages = testImages;
        this.testLabels = testLabels;
        //第一個圖像的形狀,并傳回其尺寸大小。每一張圖檔包含28X28個像素點 是以應該為784
        this.imageSize = trainingImages.get(0).shape().size();
//        System.out.println("imageSize = " + imageSize);


//        System.out.println(String.format("train_x:%s,train_y:%s, test_x:%s, test_y:%s", trainingImages.shape(), trainingLabels.shape(), testImages.shape(), testLabels.shape()));
//        System.out.println("資料集的次元:" + trainingImages.rank());
//        System.out.println("資料集的形狀 = " + trainingImages.shape());
//        System.out.println("資料集的大小 = " + trainingImages.shape().get(0));
//        System.out.println("資料集 = ");
//        for (int i = 0; i < trainingImages.shape().get(0); i++) {
//            for (int j = 0; j < trainingImages.shape().get(1); j++) {
//                for (int k = 0; k < trainingImages.shape().get(2); k++) {
//                    System.out.print(trainingImages.getObject(i, j, k) + " ");
//                }
//                System.out.println();
//            }
//            System.out.println();
//        }
//        System.out.println("檢視單個資料 = " + trainingImages.get(0));
//        for (int j = 0; j < trainingImages.shape().get(1); j++) {
//            for (int k = 0; k < trainingImages.shape().get(2); k++) {
//                System.out.print(trainingImages.getObject(0, j, k) + " ");
//            }
//            System.out.println();
//        }
//        System.out.println("檢視單個資料大小 = " + trainingImages.get(0).size());
//        System.out.println("檢視trainingImages三維數組下的第一個元素的第二個二維數組大小 = " + trainingImages.get(0).get(1).size());
//        System.out.println("檢視trainingImages三維數組下的第一個元素的第7個二維數組的第8個元素 = " + trainingImages.getObject(0, 6, 8));
//        System.out.println("trainingLabels = " + trainingLabels.getObject(1));
    }

    /**
     * @param validationSize        驗證的資料
     * @param trainingImagesArchive 訓練圖檔路徑
     * @param trainingLabelsArchive 訓練标簽路徑
     * @param testImagesArchive     測試圖檔路徑
     * @param testLabelsArchive     測試标簽路徑
     */
    public static MnistDataset create(int validationSize, String trainingImagesArchive, String trainingLabelsArchive,
                                      String testImagesArchive, String testLabelsArchive) {
        try {
            ByteNdArray trainingImages = readArchive(trainingImagesArchive);
            ByteNdArray trainingLabels = readArchive(trainingLabelsArchive);
            ByteNdArray testImages = readArchive(testImagesArchive);
            ByteNdArray testLabels = readArchive(testLabelsArchive);

            if (validationSize > 0) {
                return new MnistDataset(
                        trainingImages.slice(sliceFrom(validationSize)),
                        trainingLabels.slice(sliceFrom(validationSize)),
                        trainingImages.slice(sliceTo(validationSize)),
                        trainingLabels.slice(sliceTo(validationSize)),
                        testImages,
                        testLabels
                );
            }
            return new MnistDataset(trainingImages, trainingLabels, null, null, testImages, testLabels);

        } catch (IOException e) {
            throw new AssertionError(e);
        }
    }

    /**
     * @param trainingImagesArchive 訓練圖檔路徑
     * @param trainingLabelsArchive 訓練标簽路徑
     * @param testImagesArchive     測試圖檔路徑
     * @param testLabelsArchive     測試标簽路徑
     */
    public static MnistDataset getOneValidationImage(int index, String trainingImagesArchive, String trainingLabelsArchive,
                                                     String testImagesArchive, String testLabelsArchive) {
        try {
            ByteNdArray trainingImages = readArchive(trainingImagesArchive);
            ByteNdArray trainingLabels = readArchive(trainingLabelsArchive);
            ByteNdArray testImages = readArchive(testImagesArchive);
            ByteNdArray testLabels = readArchive(testLabelsArchive);
            trainingImages.slice(sliceFrom(0));
            trainingLabels.slice(sliceTo(0));
            // 切片操作
            Index range = Indices.range(index, index + 1);// 切片的起始和結束索引
            ByteNdArray validationImage = trainingImages.slice(range); // 執行切片操作
            ByteNdArray validationLable = trainingLabels.slice(range); // 執行切片操作


            if (index >= 0) {
                return new MnistDataset(
                        trainingImages,
                        trainingLabels,
                        validationImage,
                        validationLable,
                        testImages,
                        testLabels
                );
            } else {
                return null;
            }
        } catch (IOException e) {
            throw new AssertionError(e);
        }
    }

    private static ByteNdArray readArchive(String archiveName) throws IOException {
        System.out.println("archiveName = " + archiveName);
        DataInputStream archiveStream = new DataInputStream(
                //new GZIPInputStream(new java.io.FileInputStream("src/main/resources/"+archiveName))
                new GZIPInputStream(MnistDataset.class.getClassLoader().getResourceAsStream(archiveName))
        );
        //todo 不知道怎麼讀取和實際的内部結構
        archiveStream.readShort(); // first two bytes are always 0
        byte magic = archiveStream.readByte();
        if (magic != TYPE_UBYTE) {
            throw new IllegalArgumentException("\"" + archiveName + "\" is not a valid archive");
        }
        int numDims = archiveStream.readByte();
        long[] dimSizes = new long[numDims];
        int size = 1;  // for simplicity, we assume that total size does not exceeds Integer.MAX_VALUE
        for (int i = 0; i < dimSizes.length; ++i) {
            dimSizes[i] = archiveStream.readInt();
            size *= dimSizes[i];
        }
        byte[] bytes = new byte[size];
        archiveStream.readFully(bytes);
        return NdArrays.wrap(Shape.of(dimSizes), DataBuffers.of(bytes, false, false));
    }

    public Iterable<ImageBatch> trainingBatches(int batchSize) {
        return () -> new ImageBatchIterator(batchSize, trainingImages, trainingLabels);
    }

    public Iterable<ImageBatch> validationBatches(int batchSize) {
        return () -> new ImageBatchIterator(batchSize, validationImages, validationLabels);
    }

    public Iterable<ImageBatch> testBatches(int batchSize) {
        return () -> new ImageBatchIterator(batchSize, testImages, testLabels);
    }

    public ImageBatch testBatch() {
        return new ImageBatch(testImages, testLabels);
    }

    public long imageSize() {
        return imageSize;
    }

    public long numTrainingExamples() {
        return trainingLabels.shape().size(0);
    }

    public long numTestingExamples() {
        return testLabels.shape().size(0);
    }

    public long numValidationExamples() {
        return validationLabels.shape().size(0);
    }

    public ByteNdArray getValidationImages() {
        return validationImages;
    }

    public ByteNdArray getValidationLabels() {
        return validationLabels;
    }
}           

SimpleMnist.class

package org.example.tensorDemo.dense;
import org.example.tensorDemo.datasets.ImageBatch;
import org.example.tensorDemo.datasets.mnist.MnistDataset;
import org.tensorflow.*;
import org.tensorflow.framework.optimizers.GradientDescent;
import org.tensorflow.framework.optimizers.Optimizer;
import org.tensorflow.ndarray.ByteNdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.linalg.MatMul;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Mean;
import org.tensorflow.op.nn.Softmax;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;
import java.io.IOException;
import java.util.Map;
import java.util.Optional;

public class SimpleMnist implements Runnable {
    private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz";
    private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz";
    private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz";
    private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz";

    public static void main(String[] args) {
        //加載資料集
//        MnistDataset dataset = MnistDataset.create(VALIDATION_SIZE, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,
//                TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
        MnistDataset validationDataset = MnistDataset.getOneValidationImage(3, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,
                TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
        //建立了一個名為graph的圖形對象。
        try (Graph graph = new Graph()) {
            SimpleMnist mnist = new SimpleMnist(graph, validationDataset);
            mnist.run();//建構和訓練模型
            mnist.loadModel("D:\\ai\\ai-demo");
        }
    }

    @Override
    public void run() {
        Ops tf = Ops.create(graph);
        // Create placeholders and variables, which should fit batches of an unknown number of images
        //建立占位符和變量,這些占位符和變量應适合未知數量的圖像批次
        Placeholder<TFloat32> images = tf.placeholder(TFloat32.class);
        Placeholder<TFloat32> labels = tf.placeholder(TFloat32.class);

        // Create weights with an initial value of 0
        // 建立初始值為 0 的權重
        Shape weightShape = Shape.of(dataset.imageSize(), MnistDataset.NUM_CLASSES);
        Variable<TFloat32> weights = tf.variable(tf.zeros(tf.constant(weightShape), TFloat32.class));

        // Create biases with an initial value of 0
        //建立初始值為 0 的偏置
        Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES);
        Variable<TFloat32> biases = tf.variable(tf.zeros(tf.constant(biasShape), TFloat32.class));

        // Predict the class of each image in the batch and compute the loss
        //使用 TensorFlow 的 tf.linalg.matMul 函數計算圖像矩陣 images 和權重矩陣 weights 的矩陣乘法,并加上偏置項 biases。
        //wx+b
        MatMul<TFloat32> matMul = tf.linalg.matMul(images, weights);
        Add<TFloat32> add = tf.math.add(matMul, biases);

        //Softmax 是一個常用的激活函數,它将輸入轉換為表示機率分布的輸出。對于輸入向量中的每個元素,Softmax 函數會計算指數,
        //并對所有元素求和,然後将每個元素的指數除以總和,最終得到一個機率分布。這通常用于多分類問題,以輸出每個類别的機率
        //激活函數 
        Softmax<TFloat32> softmax = tf.nn.softmax(add);

        // 建立一個計算交叉熵的Mean對象
        //損失函數
        Mean<TFloat32> crossEntropy =
                tf.math.mean(  // 計算張量的平均值
                        tf.math.neg(  // 計算張量的負值
                                tf.reduceSum(  // 計算張量的和
                                        tf.math.mul(labels, tf.math.log(softmax)),  //計算标簽和softmax預測的對數乘積
                                        tf.array(1)  // 在指定軸上求和
                                )
                        ),
                        tf.array(0)  // 在指定軸上求平均值
                );

        // Back-propagate gradients to variables for training
        //使用梯度下降優化器來最小化交叉熵損失函數。首先,建立了一個梯度下降優化器 optimizer,然後使用該優化器來最小化交叉熵損失函數 crossEntropy。
        //梯度下降 https://www.cnblogs.com/guoyaohua/p/8542554.html
        Optimizer optimizer = new GradientDescent(graph, LEARNING_RATE);
        Op minimize = optimizer.minimize(crossEntropy);

        // Compute the accuracy of the model
        //使用 argMax 函數找出在給定軸上張量中最大值的索引,
        Operand<TInt64> predicted = tf.math.argMax(softmax, tf.constant(1));
        Operand<TInt64> expected = tf.math.argMax(labels, tf.constant(1));
        //使用 equal 函數比較模型預測的标簽和實際标簽是否相等,再用 cast 函數将布爾值轉換為浮點數,最後使用 mean 函數計算準确率。
        Operand<TFloat32> accuracy = tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), TFloat32.class), tf.array(0));

        // Run the graph
        try (Session session = new Session(graph)) {
            // Train the model
            for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
                try (TFloat32 batchImages = preprocessImages(trainingBatch.images());
                     TFloat32 batchLabels = preprocessLabels(trainingBatch.labels())) {
                    System.out.println("batchImages = " + batchImages.shape());
                    System.out.println("batchLabels = " + batchLabels.shape());
                    // 建立會話運作器
                    session.runner()
                            // 添加要最小化的目标
                            .addTarget(minimize)
                            // 通過feed方法将圖像資料輸入到模型中
                            .feed(images.asOutput(), batchImages)
                            // 通過feed方法将标簽資料輸入到模型中
                            .feed(labels.asOutput(), batchLabels)
                            // 運作會話
                            .run();
                }
            }

            // Test the model
            ImageBatch testBatch = dataset.testBatch();
            try (TFloat32 testImages = preprocessImages(testBatch.images());
                 TFloat32 testLabels = preprocessLabels(testBatch.labels());
                 // 定義一個TFloat32類型的變量accuracyValue,用于存儲計算得到的準确率值
                 TFloat32 accuracyValue = (TFloat32) session.runner()
                         // 從會話中擷取準确率值
                         .fetch(accuracy)
                         .fetch(predicted)
                         .fetch(expected)
                         // 将images作為輸入,testImages作為資料進行喂養
                         .feed(images.asOutput(), testImages)
                         // 将labels作為輸入,testLabels作為資料進行喂養
                         .feed(labels.asOutput(), testLabels)
                         // 運作會話并擷取結果
                         .run()
                         // 擷取第一個結果并存儲在accuracyValue中
                         .get(0)) {
                System.out.println("Accuracy: " + accuracyValue.getFloat());
            }
            // 儲存模型
            SavedModelBundle.Exporter exporter = SavedModelBundle.exporter("D:\\ai\\ai-demo").withSession(session);
            Signature.Builder builder = Signature.builder();
            builder.input("images", images);
            builder.input("labels", labels);
            builder.output("accuracy", accuracy);
            builder.output("expected", expected);
            builder.output("predicted", predicted);
            Signature signature = builder.build();
            SessionFunction sessionFunction = SessionFunction.create(signature, session);
            exporter.withFunction(sessionFunction);
            exporter.export();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

    }

    private static final int VALIDATION_SIZE = 5;
    private static final int TRAINING_BATCH_SIZE = 100;
    private static final float LEARNING_RATE = 0.2f;

    private static TFloat32 preprocessImages(ByteNdArray rawImages) {
        Ops tf = Ops.create();
        // Flatten images in a single dimension and normalize their pixels as floats.
        long imageSize = rawImages.get(0).shape().size();
        return tf.math.div(
                tf.reshape(
                        tf.dtypes.cast(tf.constant(rawImages), TFloat32.class),
                        tf.array(-1L, imageSize)
                ),
                tf.constant(255.0f)
        ).asTensor();
    }

    private static TFloat32 preprocessLabels(ByteNdArray rawLabels) {
        Ops tf = Ops.create();
        // Map labels to one hot vectors where only the expected predictions as a value of 1.0
        return tf.oneHot(
                tf.constant(rawLabels),
                tf.constant(MnistDataset.NUM_CLASSES),
                tf.constant(1.0f),
                tf.constant(0.0f)
        ).asTensor();
    }

    private final Graph graph;
    private final MnistDataset dataset;

    private SimpleMnist(Graph graph, MnistDataset dataset) {
        this.graph = graph;
        this.dataset = dataset;
    }

    public void loadModel(String exportDir) {
        // load saved model
        SavedModelBundle model = SavedModelBundle.load(exportDir, "serve");
        try {
            printSignature(model);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        ByteNdArray validationImages = dataset.getValidationImages();
        ByteNdArray validationLabels = dataset.getValidationLabels();
        TFloat32 testImages = preprocessImages(validationImages);
        System.out.println("testImages = " + testImages.shape());
        TFloat32 testLabels = preprocessLabels(validationLabels);
        System.out.println("testLabels = " + testLabels.shape());
        Result run = model.session().runner()
                .feed("Placeholder:0", testImages)
                .feed("Placeholder_1:0", testLabels)
                .fetch("ArgMax:0")
                .fetch("ArgMax_1:0")
                .fetch("Mean_1:0")
                .run();
        // 處理輸出
        Optional<Tensor> tensor1 = run.get("ArgMax:0");
        Optional<Tensor> tensor2 = run.get("ArgMax_1:0");
        Optional<Tensor> tensor3 = run.get("Mean_1:0");
        TInt64 predicted = (TInt64) tensor1.get();
        Long predictedValue = predicted.getObject(0);
        System.out.println("predictedValue = " + predictedValue);
        TInt64 expected = (TInt64) tensor2.get();
        Long expectedValue = expected.getObject(0);
        System.out.println("expectedValue = " + expectedValue);
        TFloat32 accuracy = (TFloat32) tensor3.get();
        System.out.println("accuracy = " + accuracy.getFloat());
    }

    private static void printSignature(SavedModelBundle model) throws Exception {
        MetaGraphDef m = model.metaGraphDef();
        SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
        int numInputs = sig.getInputsCount();
        int i = 1;
        System.out.println("MODEL SIGNATURE");
        System.out.println("Inputs:");
        for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
            TensorInfo t = entry.getValue();
            System.out.printf(
                    "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                    i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
        }
        int numOutputs = sig.getOutputsCount();
        i = 1;
        System.out.println("Outputs:");
        for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
            TensorInfo t = entry.getValue();
            System.out.printf(
                    "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                    i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
        }
        System.out.println("-----------------------------------------------");
    }
}           

五、最後兩套代碼運作結果

AI入門之手寫數字識别模型java方式Dense全連接配接神經網絡實作
AI入門之手寫數字識别模型java方式Dense全連接配接神經網絡實作

六、待完善點

1、這裡并沒有對提供web服務輸入圖檔以及圖檔資料二值話等進行處理。有興趣的小夥伴可以自己進行嘗試

2、并沒有使用卷積神經網絡等,隻是用了wx+b和激活函數進行跳躍,以及階梯下降算法和交叉熵

3、沒有進行更多層級的設計等

繼續閱讀