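Preface: Give a man a fish and you feed him for a day; teach him to fish and you feed him for a lifetime. First learn to use it, then learn the principles, then learn to create. You may never need this ability in your whole career, but you should still have it. This article covers the "Hello World" of getting started with algorithms, a handwritten-digit image recognition model, and shows how to implement it in Java, with explanations of some of the parts. Almost all articles on artificial intelligence, machine learning and neural networks are based on Python, which is not very friendly to back-end developers who are more at home in Java, so this article shows how to do it in Java and opens the door to a new world. What follows is my personal understanding; corrections are welcome.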
I. Goal: train a handwritten-digit image recognition model on the MNIST dataset
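What knowledge do we need before implementing a model:
1. Machine learning fundamentals: basic concepts such as supervised learning, unsupervised learning and reinforcement learning.
2. Data processing and analysis: data cleaning, feature engineering, data visualization, etc.
3. Programming language: e.g. Python, for implementing machine learning algorithms.
4. Mathematical foundations: linear algebra, probability and statistics, calculus, etc.
5. Machine learning algorithms: linear regression, decision trees, neural networks, support vector machines, etc.
6. Deep learning frameworks: e.g. TensorFlow and PyTorch, for building and training deep learning models.
7. Model evaluation and optimization: cross-validation, hyperparameter tuning, evaluation metrics, etc.
8. Practical experience: build up experience through real projects and competitions and keep improving your modeling skills.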
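The machine learning "Hello World" here is handwritten-digit image recognition, built with the TensorFlow framework.
What you mainly need:
1. Understand the handwritten-digit dataset: what the training data looks like (shape (60000, 28, 28), i.e. 60,000 grayscale images of 28 x 28 pixels) and what the training labels look like (one integer from 0 to 9 per image, shape (60000,))
2. Understand what activation functions do
3. Understand what forward propagation and backpropagation do and how they are implemented
4. Train and save the model
5. Load and use the saved model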
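As a concrete reference for item 2, here is a minimal plain-Java sketch (JDK only, no TensorFlow) of the two activation functions used in this article: ReLU, used by the hidden layer of the Python model, and softmax, used by the output layer of both the Python and the Java models. The class name `Activations` is hypothetical, and subtracting the maximum before exponentiating is a standard numerical-stability trick rather than something taken from the article's code.

```java
import java.util.Arrays;

/** Hypothetical helper illustrating the ReLU and softmax activation functions. */
final class Activations {

    /** ReLU: keeps positive values and replaces negative values with 0. */
    static double[] relu(double[] x) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = Math.max(0.0, x[i]);
        }
        return out;
    }

    /** Softmax: exponentiates each element and divides by the sum, giving probabilities that sum to 1. */
    static double[] softmax(double[] logits) {
        double max = Arrays.stream(logits).max().orElse(0.0);
        double[] out = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max); // shifting by max avoids overflow and does not change the result
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(relu(new double[] {-2.0, 0.5, 3.0})));
        System.out.println(Arrays.toString(softmax(new double[] {1.0, 2.0, 3.0})));
    }
}
```

ReLU passes positive values through and zeroes out negative ones, which is what lets stacked layers model non-linear relationships; softmax maps any vector of scores to probabilities that sum to 1, so its largest entry can be read as the predicted digit.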
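II. Comparing the Java and Python code
Explanations of the Python code are already widely available online, so I won't repeat them here.
1. Loading the dataset
In Python: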
import gzip
import os.path
import numpy as np


def load_data(data_folder):
    files = ["train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
             "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"]
    paths = []
    for fname in files:
        paths.append(os.path.join(data_folder, fname))
    # Label files start with an 8-byte header, image files with a 16-byte header
    with gzip.open(paths[0], 'rb') as lbpath:
        train_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[1], 'rb') as imgpath:
        train_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(train_y), 28, 28)
    with gzip.open(paths[2], 'rb') as lbpath:
        test_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[3], 'rb') as imgpath:
        test_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(test_y), 28, 28)
    return (train_x, train_y), (test_x, test_y)


(train_x, train_y), (test_x, test_y) = load_data("mnistDataSet/")
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s' % (train_x.shape, train_y.shape, test_x.shape, test_y.shape))
print(train_x.ndim)  # number of dimensions of the dataset
print(train_x.shape)  # shape of the dataset
print(len(train_x))  # number of samples
print(train_x)  # the data itself
print("--- Inspect a single sample")
print(train_x[0])
print(len(train_x[0]))
print(len(train_x[0][1]))
print(train_x[0][6])
print("--- Inspect a single label")
print(train_y[3])
In Java:
SimpleMnist.java
private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz";
private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz";
private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz";
private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz";
// Load the data; index 3 selects one training image to serve as the single "validation" sample
MnistDataset validationDataset = MnistDataset.getOneValidationImage(3, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,
        TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
MnistDataset.java
/**
 * Loads the full MNIST dataset and uses one training image as a single-sample "validation" set.
 *
 * @param index                 index of the training image to use as the validation sample
 * @param trainingImagesArchive path of the training images
 * @param trainingLabelsArchive path of the training labels
 * @param testImagesArchive     path of the test images
 * @param testLabelsArchive     path of the test labels
 */
public static MnistDataset getOneValidationImage(int index, String trainingImagesArchive, String trainingLabelsArchive,
                                                 String testImagesArchive, String testLabelsArchive) {
    if (index < 0) {
        return null; // same behaviour as before: no dataset for a negative index, checked before any slicing
    }
    try {
        ByteNdArray trainingImages = readArchive(trainingImagesArchive);
        ByteNdArray trainingLabels = readArchive(trainingLabelsArchive);
        ByteNdArray testImages = readArchive(testImagesArchive);
        ByteNdArray testLabels = readArchive(testLabelsArchive);
        // Slice [index, index + 1): a range keeps the first dimension, so this is a batch of one image/label
        Index range = Indices.range(index, index + 1);
        ByteNdArray validationImage = trainingImages.slice(range);
        ByteNdArray validationLabel = trainingLabels.slice(range);
        return new MnistDataset(trainingImages, trainingLabels, validationImage, validationLabel, testImages, testLabels);
    } catch (IOException e) {
        throw new AssertionError(e);
    }
}
private static ByteNdArray readArchive(String archiveName) throws IOException {
    System.out.println("archiveName = " + archiveName);
    // try-with-resources closes the stream once the archive has been read
    try (DataInputStream archiveStream = new DataInputStream(
            new GZIPInputStream(MnistDataset.class.getClassLoader().getResourceAsStream(archiveName)))) {
        archiveStream.readShort(); // first two bytes are always 0
        byte magic = archiveStream.readByte(); // data type: 0x08 = unsigned byte
        if (magic != TYPE_UBYTE) {
            throw new IllegalArgumentException("\"" + archiveName + "\" is not a valid archive");
        }
        int numDims = archiveStream.readByte(); // 3 for images (count, rows, columns), 1 for labels
        long[] dimSizes = new long[numDims];
        int size = 1; // for simplicity, we assume that the total size does not exceed Integer.MAX_VALUE
        for (int i = 0; i < dimSizes.length; ++i) {
            dimSizes[i] = archiveStream.readInt(); // big-endian int32 per dimension
            size *= dimSizes[i];
        }
        byte[] bytes = new byte[size];
        archiveStream.readFully(bytes);
        return NdArrays.wrap(Shape.of(dimSizes), DataBuffers.of(bytes, false, false));
    }
}
/**
 * MNIST dataset constructor
 */
private MnistDataset(ByteNdArray trainingImages, ByteNdArray trainingLabels, ByteNdArray validationImages,
                     ByteNdArray validationLabels, ByteNdArray testImages, ByteNdArray testLabels) {
    this.trainingImages = trainingImages;
    this.trainingLabels = trainingLabels;
    this.validationImages = validationImages;
    this.validationLabels = validationLabels;
    this.testImages = testImages;
    this.testLabels = testLabels;
    // Number of pixels in one image: 28 x 28 = 784
    this.imageSize = trainingImages.get(0).shape().size();
    System.out.println(String.format("train_x:%s, train_y:%s, test_x:%s, test_y:%s", trainingImages.shape(), trainingLabels.shape(), testImages.shape(), testLabels.shape()));
    System.out.println("Dataset rank: " + trainingImages.rank());
    System.out.println("Dataset shape = " + trainingImages.shape());
    System.out.println("Dataset size = " + trainingImages.shape().get(0));
    System.out.println("First sample = " + trainingImages.get(0));
}
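The header that `readArchive` parses, and that the Python `offset=8` / `offset=16` values skip, is the IDX format: two zero bytes, a type byte (0x08 = unsigned byte), one byte giving the number of dimensions, then one big-endian 32-bit size per dimension, followed by the raw data. A label file (1 dimension) therefore has a 4 + 4 = 8-byte header and an image file (3 dimensions) a 4 + 3 x 4 = 16-byte header. The sketch below builds a tiny IDX file in memory and parses it with the same steps; the class and method names are hypothetical, and it uses only the JDK (no gzip, no TensorFlow).

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

/** Hypothetical sketch of the IDX layout used by the (decompressed) MNIST files. */
final class IdxSketch {
    static final int TYPE_UBYTE = 0x08;

    /** Parsed IDX content: dimension sizes plus the raw unsigned-byte payload. */
    static final class Idx {
        final int[] dims;
        final byte[] data;

        Idx(int[] dims, byte[] data) {
            this.dims = dims;
            this.data = data;
        }
    }

    /** Header length in bytes: 4 fixed bytes plus one int32 per dimension (8 for labels, 16 for images). */
    static int headerLength(int numDims) {
        return 4 + 4 * numDims;
    }

    /** Builds an unsigned-byte IDX file in memory. */
    static byte[] write(int[] dims, byte[] data) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeShort(0);          // two zero bytes
            out.writeByte(TYPE_UBYTE);  // data type
            out.writeByte(dims.length); // number of dimensions
            for (int d : dims) {
                out.writeInt(d);        // big-endian int32 per dimension
            }
            out.write(data);            // raw payload
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Reads the header with the same steps as readArchive, then the payload. */
    static Idx read(byte[] file) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(file))) {
            in.readShort();
            if (in.readByte() != TYPE_UBYTE) {
                throw new IllegalArgumentException("not an unsigned-byte IDX file");
            }
            int[] dims = new int[in.readByte()];
            int size = 1;
            for (int i = 0; i < dims.length; i++) {
                dims[i] = in.readInt();
                size *= dims[i];
            }
            byte[] data = new byte[size];
            in.readFully(data);
            return new Idx(dims, data);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

Applied to the decompressed `train-images-idx3-ubyte` file, the same parsing yields the dimensions 60000, 28 and 28, and the first pixel sits at byte offset 16, which is exactly where `np.frombuffer(..., offset=16)` starts reading.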
2. Building the model
In Python:
import tensorflow

model = tensorflow.keras.Sequential()
model.add(tensorflow.keras.layers.Flatten(input_shape=(28, 28)))  # Flatten layer: declares the input shape and flattens each 28x28 image into 784 values
model.add(tensorflow.keras.layers.Dense(128, activation='relu'))  # hidden layer: fully connected, 128 units, ReLU activation
model.add(tensorflow.keras.layers.Dense(10, activation='softmax'))  # output layer: fully connected, 10 units, softmax activation
print("Model structure")
# summary() prints the network structure and parameter counts by itself (it returns None, so don't wrap it in print)
model.summary()
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
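`model.summary()` reports the trainable parameters of each Dense layer as inputs x units + units (weights plus biases). The plain-Java arithmetic sketch below reproduces that count for the Python model (Flatten 28 x 28 = 784, then Dense 128, then Dense 10) and, for comparison, for the Java graph shown next. Note that the Java graph is not the same network: it has no hidden layer and is a single softmax layer (multinomial logistic regression) with a 784 x 10 weight matrix and 10 biases. The class name `ParamCount` is hypothetical.

```java
/** Hypothetical helper: counts the trainable parameters of a stack of fully connected layers. */
final class ParamCount {

    /** A Dense layer with `inputs` inputs and `units` outputs has inputs * units weights and units biases. */
    static long dense(long inputs, long units) {
        return inputs * units + units;
    }

    /** Sums dense() over consecutive layer sizes, e.g. {784, 128, 10}. */
    static long stack(long... sizes) {
        long total = 0;
        for (int i = 0; i + 1 < sizes.length; i++) {
            total += dense(sizes[i], sizes[i + 1]);
        }
        return total;
    }

    public static void main(String[] args) {
        long flattened = 28 * 28; // Flatten turns a 28x28 image into 784 values and has no parameters
        System.out.println("Python MLP (784-128-10): " + stack(flattened, 128, 10));
        System.out.println("Java softmax regression (784-10): " + stack(flattened, 10));
    }
}
```

This prints 101770 for the Python model (100480 in the hidden layer plus 1290 in the output layer) and 7850 for the Java graph, which is one reason the two versions should not be expected to reach the same accuracy.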
In Java:
SimpleMnist.java
Ops tf = Ops.create(graph);
// Create placeholders for batches of an unknown number of images:
// images are fed as (batch, 784) floats, labels as (batch, 10) one-hot floats
Placeholder<TFloat32> images = tf.placeholder(TFloat32.class);
Placeholder<TFloat32> labels = tf.placeholder(TFloat32.class);
// Create weights of shape (784, 10) with an initial value of 0
Shape weightShape = Shape.of(dataset.imageSize(), MnistDataset.NUM_CLASSES);
Variable<TFloat32> weights = tf.variable(tf.zeros(tf.constant(weightShape), TFloat32.class));
// Create biases of shape (10) with an initial value of 0
Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES);
Variable<TFloat32> biases = tf.variable(tf.zeros(tf.constant(biasShape), TFloat32.class));
// Predict the class of each image in the batch and compute the loss.
// tf.linalg.matMul multiplies the image matrix by the weight matrix; adding the biases gives logits = xW + b
MatMul<TFloat32> matMul = tf.linalg.matMul(images, weights);
Add<TFloat32> add = tf.math.add(matMul, biases);
// Softmax is a common activation function that turns its input into a probability distribution:
// it exponentiates each element and divides by the sum of all exponentials, so each row sums to 1.
// It is the usual output for multi-class classification, giving one probability per class.
Softmax<TFloat32> softmax = tf.nn.softmax(add);
// Cross-entropy loss: the batch mean of -sum(labels * log(softmax))
Mean<TFloat32> crossEntropy =
    tf.math.mean(                                          // average over the batch
        tf.math.neg(                                       // negate
            tf.reduceSum(                                  // sum over the classes
                tf.math.mul(labels, tf.math.log(softmax)), // label times log of the predicted probability
                tf.array(1)                                // axis 1: classes
            )
        ),
        tf.array(0)                                        // axis 0: batch
    );
// Back-propagate gradients to the variables for training: a gradient descent optimizer
// computes the gradients of crossEntropy and updates weights and biases to minimize it
Optimizer optimizer = new GradientDescent(graph, LEARNING_RATE);
Op minimize = optimizer.minimize(crossEntropy);
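What `optimizer.minimize(crossEntropy)` does for this particular graph can be written out by hand. For softmax followed by cross-entropy, the gradient of the loss with respect to each logit is (predicted probability - label) / batch size, so dL/dW = x^T (p - y) / n and dL/db is the batch mean of (p - y); each gradient descent step subtracts learning rate x gradient. The plain-Java sketch below performs that forward pass, loss and update on a tiny made-up batch (2 features and 3 classes instead of 784 pixels and 10 digits). Class name and data are illustrative assumptions; TensorFlow derives the same gradients automatically rather than with code like this.

```java
/** Hypothetical sketch: forward pass, cross-entropy loss and gradient descent step for softmax regression. */
final class SoftmaxRegressionStep {

    /** logits = xW + b, then softmax per row. x: [n][d], w: [d][k], b: [k]. */
    static double[][] forward(double[][] x, double[][] w, double[] b) {
        int n = x.length, d = w.length, k = b.length;
        double[][] p = new double[n][k];
        for (int i = 0; i < n; i++) {
            double max = Double.NEGATIVE_INFINITY;
            for (int c = 0; c < k; c++) {
                double z = b[c];
                for (int j = 0; j < d; j++) {
                    z += x[i][j] * w[j][c];
                }
                p[i][c] = z;
                max = Math.max(max, z);
            }
            double sum = 0;
            for (int c = 0; c < k; c++) {
                p[i][c] = Math.exp(p[i][c] - max);
                sum += p[i][c];
            }
            for (int c = 0; c < k; c++) {
                p[i][c] /= sum;
            }
        }
        return p;
    }

    /** Batch mean of -sum(y * log(p)), the same formula as crossEntropy in the graph. */
    static double crossEntropy(double[][] p, double[][] y) {
        double total = 0;
        for (int i = 0; i < p.length; i++) {
            for (int c = 0; c < p[i].length; c++) {
                total -= y[i][c] * Math.log(p[i][c]);
            }
        }
        return total / p.length;
    }

    /** One step: W -= lr * x^T (p - y) / n and b -= lr * mean(p - y), with p computed before the update. */
    static void step(double[][] x, double[][] y, double[][] w, double[] b, double lr) {
        double[][] p = forward(x, w, b);
        int n = x.length;
        for (int c = 0; c < b.length; c++) {
            double gradB = 0;
            for (int i = 0; i < n; i++) {
                gradB += (p[i][c] - y[i][c]) / n;
            }
            for (int j = 0; j < w.length; j++) {
                double gradW = 0;
                for (int i = 0; i < n; i++) {
                    gradW += x[i][j] * (p[i][c] - y[i][c]) / n;
                }
                w[j][c] -= lr * gradW;
            }
            b[c] -= lr * gradB;
        }
    }
}
```

Starting from zero weights and biases, as the graph above does, every class gets the same probability, so the initial loss is ln(number of classes): ln 3 in this toy example and ln 10 (about 2.303) for MNIST.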
3. Training the model
In Python:
history = model.fit(train_x, train_y, batch_size=64, epochs=5, validation_split=0.2)
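The arithmetic behind these settings, as a plain-Java sketch: Keras `validation_split=0.2` holds out the last 20% of the 60,000 training images (12,000) for validation and trains on the remaining 48,000, so with `batch_size=64` one epoch is 750 steps and `epochs=5` gives 3,750 steps. The Java loop shown next, assuming `trainingBatches` walks the training set exactly once (its `ImageBatchIterator` is not listed in this article), runs a single pass over all 60,000 images in batches of `TRAINING_BATCH_SIZE = 100`, i.e. 600 steps, with no validation split. The class name `TrainingSchedule` is hypothetical; a partial last batch is counted as one step.

```java
/** Hypothetical helper reproducing the batch/epoch arithmetic of the two training setups. */
final class TrainingSchedule {

    /** Samples kept for training when the last `validationSplit` fraction is held out. */
    static long trainingSamples(long total, double validationSplit) {
        return (long) (total * (1.0 - validationSplit));
    }

    /** Steps per epoch: every full batch plus one partial batch if the samples do not divide evenly. */
    static long stepsPerEpoch(long samples, long batchSize) {
        return (samples + batchSize - 1) / batchSize;
    }

    public static void main(String[] args) {
        long pythonTrain = trainingSamples(60_000, 0.2);
        long pythonSteps = stepsPerEpoch(pythonTrain, 64);
        System.out.println("Python: train=" + pythonTrain + ", validation=" + (60_000 - pythonTrain)
                + ", steps/epoch=" + pythonSteps + ", total steps=" + 5 * pythonSteps);
        System.out.println("Java: steps for one pass=" + stepsPerEpoch(60_000, 100));
    }
}
```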
In Java:
SimpleMnist.java
// Train the model
for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
    try (TFloat32 batchImages = preprocessImages(trainingBatch.images());
         TFloat32 batchLabels = preprocessLabels(trainingBatch.labels())) {
        session.runner()                              // create a session runner
                .addTarget(minimize)                  // target to run: one optimization step
                .feed(images.asOutput(), batchImages) // feed the image batch into the model
                .feed(labels.asOutput(), batchLabels) // feed the label batch into the model
                .run();                               // run the session
    }
}
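`preprocessImages` and `preprocessLabels` are not listed in this part of the article. A common choice for this graph (an assumption here, not a quote of the author's code) is to flatten each image to 784 floats scaled to [0, 1] and to turn each label into a one-hot vector of length 10, since the `labels` placeholder is multiplied element-wise with `log(softmax)` and passed to `argMax`. The plain-Java sketch below shows both transformations on raw MNIST bytes, with each image given as its 784 bytes in row-major order. Java `byte` is signed, so a pixel value of 255 reads back as -1 unless it is masked with `& 0xFF`. The class and method names are hypothetical.

```java
/** Hypothetical plain-Java equivalent of image/label preprocessing for the softmax graph. */
final class Preprocess {

    /** Converts raw unsigned-byte images ([n][784] bytes, row-major) to floats in [0, 1]. */
    static float[][] images(byte[][] raw) {
        float[][] out = new float[raw.length][];
        for (int i = 0; i < raw.length; i++) {
            out[i] = new float[raw[i].length];
            for (int j = 0; j < raw[i].length; j++) {
                out[i][j] = (raw[i][j] & 0xFF) / 255.0f; // mask first: Java bytes are signed
            }
        }
        return out;
    }

    /** Turns digit labels into one-hot rows, e.g. label 3 -> [0,0,0,1,0,0,0,0,0,0]. */
    static float[][] oneHot(byte[] labels, int numClasses) {
        float[][] out = new float[labels.length][numClasses];
        for (int i = 0; i < labels.length; i++) {
            out[i][labels[i] & 0xFF] = 1.0f;
        }
        return out;
    }
}
```

The Python training script trains on raw 0 to 255 values because its normalization lines are commented out, so whichever convention is chosen has to be applied identically at training and at prediction time.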
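4. Evaluating the model
In Python:
test_loss, test_acc = model.evaluate(test_x, test_y)
model.evaluate(test_x, test_y, verbose=2)  # same evaluation on the unseen test set, to judge generalization; verbose=2 prints one summary line instead of a progress bar
print('Test loss: %.3f' % test_loss)
print('Test accuracy: %.3f' % test_acc)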
In Java:
SimpleMnist.java
// Test the model
ImageBatch testBatch = dataset.testBatch();
try (TFloat32 testImages = preprocessImages(testBatch.images());
     TFloat32 testLabels = preprocessLabels(testBatch.labels());
     // Run the session on the whole test set; closing the Result also closes all fetched tensors
     Result result = session.runner()
             .fetch(accuracy)                     // output 0: accuracy
             .fetch(predicted)                    // output 1: predicted class of each image
             .fetch(expected)                     // output 2: true class of each image
             .feed(images.asOutput(), testImages) // feed testImages into the images placeholder
             .feed(labels.asOutput(), testLabels) // feed testLabels into the labels placeholder
             .run()) {
    // The first fetched output holds the accuracy value
    TFloat32 accuracyValue = (TFloat32) result.get(0);
    System.out.println("Accuracy: " + accuracyValue.getFloat());
}
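The `accuracy` op fetched above (defined in the complete Java code later in the article) takes the argMax of each softmax row, compares it with the argMax of each one-hot label, casts the result to float and averages it. The same computation in plain Java, as an illustrative sketch with hypothetical names:

```java
/** Hypothetical sketch of the accuracy computation: mean(argMax(predictions) == argMax(labels)). */
final class Accuracy {

    /** Index of the largest value in a row (this sketch returns the first index on ties). */
    static int argMax(float[] row) {
        int best = 0;
        for (int i = 1; i < row.length; i++) {
            if (row[i] > row[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Fraction of rows whose predicted class equals the expected class. */
    static float accuracy(float[][] probabilities, float[][] oneHotLabels) {
        int correct = 0;
        for (int i = 0; i < probabilities.length; i++) {
            if (argMax(probabilities[i]) == argMax(oneHotLabels[i])) {
                correct++;
            }
        }
        return (float) correct / probabilities.length;
    }
}
```

For example, predictions {0.1, 0.7, 0.2} and {0.6, 0.3, 0.1} against true classes 1 and 2 give one hit out of two, an accuracy of 0.5.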
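5. Saving the model
In Python:
# Save the complete model with save_model in the SavedModel format ('tf'; saved_model.pb is written inside the target directory)
# save_model(model, '/media/cfs/<username>/ea/saved_model', save_format='tf')
save_model(model, 'D:\\pythonProject\\mnistDemo\\number_model', save_format='tf')
In Java:
SimpleMnist.java
// Save the model as a SavedModel: inputs images/labels, outputs accuracy/expected/predicted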
SavedModelBundle.Exporter exporter = SavedModelBundle.exporter("D:\\ai\\ai-demo").withSession(session);
Signature.Builder builder = Signature.builder();
builder.input("images", images);
builder.input("labels", labels);
builder.output("accuracy", accuracy);
builder.output("expected", expected);
builder.output("predicted", predicted);
Signature signature = builder.build();
SessionFunction sessionFunction = SessionFunction.create(signature, session);
exporter.withFunction(sessionFunction);
exporter.export();
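6. Loading the model
In Python:
def predict(test_x, test_y):
    # Load the SavedModel directory (it contains the .pb file).
    # Use a new name: rebinding load_model would shadow the imported function and break a second call.
    model = load_model('D:\\pythonProject\\mnistDemo\\number_model')
    model.summary()
    demo = tensorflow.reshape(test_x, (1, 28, 28))  # test_x is one 28x28 image; add a batch dimension
    input_data = np.array(demo)  # prepare the input data
    input_data = tensorflow.convert_to_tensor(input_data, dtype=tensorflow.float32)
    predictValue = model.predict(input_data)  # shape (1, 10): one probability per digit
    print("predictValue")
    print(predictValue)
    y_pred = np.argmax(predictValue)  # index of the largest probability = predicted digit
    print('Label: ' + str(test_y) + '\nPrediction: ' + str(y_pred))
    return y_pred, test_y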
In Java:
SimpleMnist.java
// Load the exported model and predict
public void loadModel(String exportDir) {
    // Load the SavedModel with the "serve" tag
    SavedModelBundle model = SavedModelBundle.load(exportDir, "serve");
    try {
        printSignature(model);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
    // The "validation" set holds the single training image chosen in getOneValidationImage
    ByteNdArray validationImages = dataset.getValidationImages();
    ByteNdArray validationLabels = dataset.getValidationLabels();
    TFloat32 testImages = preprocessImages(validationImages);
    System.out.println("testImages = " + testImages.shape());
    TFloat32 testLabels = preprocessLabels(validationLabels);
    System.out.println("testLabels = " + testLabels.shape());
    // Feed and fetch by graph tensor name "<op name>:<output index>", as listed by printSignature; here
    // Placeholder = images, Placeholder_1 = labels, ArgMax = predicted, ArgMax_1 = expected, Mean_1 = accuracy
    Result run = model.session().runner()
            .feed("Placeholder:0", testImages)
            .feed("Placeholder_1:0", testLabels)
            .fetch("ArgMax:0")
            .fetch("ArgMax_1:0")
            .fetch("Mean_1:0")
            .run();
    // Read the outputs
    Optional<Tensor> tensor1 = run.get("ArgMax:0");
    Optional<Tensor> tensor2 = run.get("ArgMax_1:0");
    Optional<Tensor> tensor3 = run.get("Mean_1:0");
    TInt64 predicted = (TInt64) tensor1.get();
    Long predictedValue = predicted.getObject(0);
    System.out.println("predictedValue = " + predictedValue);
    TInt64 expected = (TInt64) tensor2.get();
    Long expectedValue = expected.getObject(0);
    System.out.println("expectedValue = " + expectedValue);
    TFloat32 accuracy = (TFloat32) tensor3.get();
    System.out.println("accuracy = " + accuracy.getFloat());
}
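The strings passed to `feed` and `fetch` above follow TensorFlow's tensor naming convention `<op name>:<output index>`, and ops created without an explicit name take their op type as the name, with `_1`, `_2`, ... appended when that name is already used in the graph. That is why the second placeholder (`labels`) appears as `Placeholder_1:0`, the second `argMax` (`expected`) as `ArgMax_1:0`, and the accuracy `mean`, created after the cross-entropy `mean`, as `Mean_1:0`. The sketch below imitates that convention in plain Java to make the mapping explicit; it is a hypothetical illustration, not TensorFlow's implementation. Because such names depend on the order in which the graph was built, checking them against the `printSignature` output is the safer habit.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical imitation of default op naming ("Type", "Type_1", ...) and "name:index" tensor names. */
final class TensorNames {
    private final Map<String, Integer> counts = new HashMap<>();

    /** Returns the default name for the next op of the given type. */
    String nextName(String opType) {
        int n = counts.merge(opType, 1, Integer::sum) - 1; // 0 for the first op of this type
        return n == 0 ? opType : opType + "_" + n;
    }

    /** "ArgMax_1:0" -> "ArgMax_1" */
    static String opName(String tensorName) {
        return tensorName.substring(0, tensorName.lastIndexOf(':'));
    }

    /** "ArgMax_1:0" -> 0 */
    static int outputIndex(String tensorName) {
        return Integer.parseInt(tensorName.substring(tensorName.lastIndexOf(':') + 1));
    }
}
```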
// Print the model signature: input/output keys with their graph node names and data types
private static void printSignature(SavedModelBundle model) throws Exception {
    MetaGraphDef m = model.metaGraphDef();
    SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
    int numInputs = sig.getInputsCount();
    int i = 1;
    System.out.println("MODEL SIGNATURE");
    System.out.println("Inputs:");
    for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
        TensorInfo t = entry.getValue();
        System.out.printf(
                "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
    }
    int numOutputs = sig.getOutputsCount();
    i = 1;
    System.out.println("Outputs:");
    for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
        TensorInfo t = entry.getValue();
        System.out.printf(
                "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
    }
}
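III. Complete Python code
Environment used for this project:
Python: 3.7.9
https://www.python.org/downloads/windows/
Anaconda: Python 3.11 Anaconda3-2023.09-0-Windows-x86_64
https://www.anaconda.com/download#downloads
tensorflow: 2.0.0
Installed directly through Anaconda (TensorFlow 2.0.0 supports Python 3.7 at most, so install it into a Python 3.7 environment rather than Anaconda's default Python 3.11)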
mnistTrainDemo.py
import gzip
import os.path
import tensorflow as tensorflow
from tensorflow import keras
# For visualizing images
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.models import save_model

# Load the data
# mnist = keras.datasets.mnist
# mnistData = mnist.load_data() #Exception: URL fetch failure on https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz: None -- unknown url type: https
"""
The data could be loaded directly with
mnist = keras.datasets.mnist
mnistData = mnist.load_data()
but the download sometimes fails, so the data is loaded from local files instead.
"""
def load_data(data_folder):
    files = ["train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
             "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"]
    paths = []
    for fname in files:
        paths.append(os.path.join(data_folder, fname))
    with gzip.open(paths[0], 'rb') as lbpath:
        train_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[1], 'rb') as imgpath:
        train_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(train_y), 28, 28)
    with gzip.open(paths[2], 'rb') as lbpath:
        test_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[3], 'rb') as imgpath:
        test_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(test_y), 28, 28)
    return (train_x, train_y), (test_x, test_y)


(train_x, train_y), (test_x, test_y) = load_data("mnistDataSet/")
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s' % (train_x.shape, train_y.shape, test_x.shape, test_y.shape))
print(train_x.ndim)  # number of dimensions of the dataset
print(train_x.shape)  # shape of the dataset
print(len(train_x))  # number of samples
print(train_x)  # the data itself
print("--- Inspect a single sample")
print(train_x[0])
print(len(train_x[0]))
print(len(train_x[0][1]))
print(train_x[0][6])
# Visualize one image (the data of a single image)
# plt.imshow(train_x[0].reshape(28, 28), cmap="binary")
# plt.show()
print("--- Inspect a single label")
print(train_y[0])
# Data preprocessing
# Normalize and convert to float32 tensors (disabled here: normalization may also lead to a lower recognition rate)
# train_x, test_x = tensorflow.cast(train_x / 255.0, tensorflow.float32), tensorflow.cast(test_x / 255.0,
#                                                                                       tensorflow.float32)
# train_y, test_y = tensorflow.cast(train_y, tensorflow.int16), tensorflow.cast(test_y, tensorflow.int16)
# print("--- A single sample after normalization")
# print(train_x[0][6])  # 30/255 = 0.11764706: normalization divides every value by 255
# print(train_y[0])
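# Step 2: configure the network and build the model
'''
The code below defines a simple multilayer perceptron with three layers counting the input layer:
a Flatten input layer, one hidden layer of 128 units and an output layer of 10 units. MNIST consists of
grayscale images of the handwritten digits 0 to 9, so there are 10 classes and the output size is 10.
The output layer uses softmax, so it acts as a classifier.
The structure of the multilayer perceptron is: input layer -->> hidden layer -->> output layer.
Activation functions: https://zhuanlan.zhihu.com/p/337902763
'''
# Build the model
# model = keras.Sequential([
#     # In the first layer the input shape is 28*28, i.e. the height and width of the image.
#     keras.layers.Flatten(input_shape=(28, 28)),
#     # A neural network is a bit like a filter: feed in a 28*28-pixel image and it outputs a decision among 10 classes. So what is the number 128 for?
#     # Imagine the network contains 128 functions, each with its own parameters.
#     # Number them f0, f1 ... f127. Once the image pixels are fed through these 128 functions, we want their combination to output a label value, here a digit from 0 to 9.
#     # To get there, the computer has to work out the parameters of these 128 functions before it can compute the label of each image. Once it has, it can recognize the 10 different classes.
#     keras.layers.Dense(100, activation=tensorflow.nn.relu),
#     # The last layer has 10 units, one per class: the dataset has 10 classes in total.
#     keras.layers.Dense(10, activation=tensorflow.nn.softmax)
# ])
model = tensorflow.keras.Sequential()
model.add(tensorflow.keras.layers.Flatten(input_shape=(28, 28)))  # Flatten layer declares the input shape
model.add(tensorflow.keras.layers.Dense(128, activation='relu'))  # hidden layer: fully connected, 128 units, ReLU activation
model.add(tensorflow.keras.layers.Dense(10, activation='softmax'))  # output layer: fully connected, 10 units, softmax activation
print("Model structure")
# summary() prints the model structure and parameter counts by itself (it returns None, so don't wrap it in print)
model.summary()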
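'''
Next, configure the model: specify the optimization algorithm and the loss function used for training.
Accuracy-related metrics can also be defined here.
Optimizers: https://zhuanlan.zhihu.com/p/27449596
'''
# Configure how the model is trained
# Set the optimizer and loss function: Adam for optimization, cross-entropy for the loss, accuracy as the metric
# model.compile(optimizer=tensorflow.optimizers.Adam(), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# Adam with Keras's default parameters; sparse categorical cross-entropy as the loss (labels are integers, not one-hot); sparse categorical accuracy as the metric
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])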
# Step 3: train the model
# Start training
# model.fit(x_train,  # training data
#           y_train,
#           epochs=5,  # number of training epochs
#           batch_size=64,  # batch size
#           verbose=1)  # logging format
# Batch size 64, 5 epochs, 20% held out for validation (48,000 training samples, 12,000 validation samples)
history = model.fit(train_x, train_y, batch_size=64, epochs=5, validation_split=0.2)
# Step 4: evaluate the model
# Evaluate on the test set without printing predictions: returns the loss (test_loss) and the accuracy (test_acc)
test_loss, test_acc = model.evaluate(test_x, test_y)
model.evaluate(test_x, test_y, verbose=2)  # same evaluation with a one-line log, to judge how well the model generalizes
# model.evaluate(test_dataset, verbose=1)
print('Test loss: %.3f' % test_loss)
print('Test accuracy: %.3f' % test_acc)
# Visualize the results
print(history.history)
loss = history.history['loss']  # training loss
val_loss = history.history['val_loss']  # validation loss
acc = history.history['sparse_categorical_accuracy']  # training accuracy
val_acc = history.history['val_sparse_categorical_accuracy']  # validation accuracy
plt.figure(figsize=(10, 3))
plt.subplot(121)
plt.plot(loss, color='b', label='train')
plt.plot(val_loss, color='r', label='validation')
plt.ylabel('loss')
plt.legend()
plt.subplot(122)
plt.plot(acc, color='b', label='train')
plt.plot(val_acc, color='r', label='validation')
plt.ylabel('Accuracy')
plt.legend()
# Show the figure for 5 seconds and then close it; a blocking plt.show() keeps the process,
# and the GPU memory it holds, alive until the window is closed
# plt.ion()  # turn on interactive mode
# plt.show()
# plt.pause(5)
# plt.close()
# plt.show()
# Step 5: prediction: feed test data and output the predicted result
for i in range(1):
    num = np.random.randint(0, len(test_x))  # random test index in [0, 10000)
    plt.subplot(2, 5, i + 1)
    plt.axis('off')
    plt.imshow(test_x[num], cmap='gray')
    demo = tensorflow.reshape(test_x[num], (1, 28, 28))
    y_pred = np.argmax(model.predict(demo))
    plt.title('Label: ' + str(test_y[num]) + '\nPrediction: ' + str(y_pred))
# plt.show()
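'''
Save the model
A trained model can later be loaded to predict on new input data, so save it first
'''
# Save the complete model with save_model (SavedModel format, which writes saved_model.pb into the directory)
save_model(model, 'D:\\pythonProject\\mnistDemo\\number_model', save_format='tf')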
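The script above uses `plt.imshow` to look at a test image next to its predicted label. The Java side has no matplotlib; a quick way to eyeball a sample there, such as the image at index 3 that `loadModel` predicts, is to print its 28 x 28 unsigned bytes as characters. A minimal plain-Java sketch, where the class name `AsciiImage` and the 128 threshold are arbitrary choices:

```java
/** Hypothetical helper: renders a grayscale image (row-major unsigned bytes) as ASCII art. */
final class AsciiImage {

    static String render(byte[] pixels, int width) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < pixels.length; i++) {
            int value = pixels[i] & 0xFF;                  // unsigned 0..255, 0 = background in MNIST
            sb.append(value >= 128 ? '#' : value > 0 ? '+' : '.');
            if ((i + 1) % width == 0) {
                sb.append('\n');                           // end of one row
            }
        }
        return sb.toString();
    }
}
```

Passing the 784 bytes of one MNIST image with `width = 28` prints the digit as a 28-line block of characters.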
mnistPredictDemo.py
import numpy as np
import tensorflow as tensorflow
import gzip
import os.path
from tensorflow.keras.models import load_model


# Prediction
def predict(test_x, test_y):
    '''
    Model evaluation
    First load the trained model, then use it to predict new data and compute the evaluation metrics
    '''
    # Load the model (SavedModel directory containing the .pb file)
    # A new name keeps the imported load_model function usable for later calls
    # model = load_model('./saved_model')
    model = load_model('D:\\pythonProject\\mnistDemo\\number_model')
    model.summary()
    # make a prediction
    print("test_x")
    print(test_x)
    print(test_x.ndim)
    print(test_x.shape)
    demo = tensorflow.reshape(test_x, (1, 28, 28))  # add a batch dimension: (28, 28) -> (1, 28, 28)
    input_data = np.array(demo)  # prepare the input data
    input_data = tensorflow.convert_to_tensor(input_data, dtype=tensorflow.float32)
    # test_x = tensorflow.cast(test_x / 255.0, tensorflow.float32)
    # test_y = tensorflow.cast(test_y, tensorflow.int16)
    predictValue = model.predict(input_data)
    print("predictValue")
    print(predictValue)
    y_pred = np.argmax(predictValue)
    print('Label: ' + str(test_y) + '\nPrediction: ' + str(y_pred))
    return y_pred, test_y


def load_data(data_folder):
    files = ["train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
             "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"]
    paths = []
    for fname in files:
        paths.append(os.path.join(data_folder, fname))
    with gzip.open(paths[0], 'rb') as lbpath:
        train_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[1], 'rb') as imgpath:
        train_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(train_y), 28, 28)
    with gzip.open(paths[2], 'rb') as lbpath:
        test_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[3], 'rb') as imgpath:
        test_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(test_y), 28, 28)
    return (train_x, train_y), (test_x, test_y)


(train_x, train_y), (test_x, test_y) = load_data("mnistDataSet/")
print(train_x[0])
predict(train_x[0], train_y[0])  # predict the first training image and compare with its own label
IV. Complete Java code
TensorFlow Java version support table: https://github.com/tensorflow/java/#tensorflow-version-support
Environment used for this project:
JDK version: openjdk-21
pom.xml dependencies:
<dependencies>
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow-core-platform</artifactId>
        <version>0.6.0-SNAPSHOT</version>
    </dependency>
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow-framework</artifactId>
        <version>0.6.0-SNAPSHOT</version>
    </dependency>
</dependencies>
<repositories>
    <repository>
        <id>tensorflow-snapshots</id>
        <url>https://oss.sonatype.org/content/repositories/snapshots/</url>
        <snapshots>
            <enabled>true</enabled>
        </snapshots>
    </repository>
</repositories>
Dataset creation and parsing class
MnistDataset.java
package org.example.tensorDemo.datasets.mnist;
import org.example.tensorDemo.datasets.ImageBatch;
import org.example.tensorDemo.datasets.ImageBatchIterator;
import org.tensorflow.ndarray.*;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.index.Index;
import org.tensorflow.ndarray.index.Indices;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.zip.GZIPInputStream;
import static org.tensorflow.ndarray.index.Indices.sliceFrom;
import static org.tensorflow.ndarray.index.Indices.sliceTo;
public class MnistDataset {
    public static final int NUM_CLASSES = 10;
    private static final int TYPE_UBYTE = 0x08;
    /**
     * Training images as a multidimensional byte array
     */
    private final ByteNdArray trainingImages;
    /**
     * Training labels as a multidimensional byte array
     */
    private final ByteNdArray trainingLabels;
    /**
     * Validation images as a multidimensional byte array
     */
    public final ByteNdArray validationImages;
    /**
     * Validation labels as a multidimensional byte array
     */
    public final ByteNdArray validationLabels;
    /**
     * Test images as a multidimensional byte array
     */
    private final ByteNdArray testImages;
    /**
     * Test labels as a multidimensional byte array
     */
    private final ByteNdArray testLabels;
    /**
     * Number of pixels in one image
     */
    private final long imageSize;
    /**
     * MNIST dataset constructor
     */
    private MnistDataset(
            ByteNdArray trainingImages,
            ByteNdArray trainingLabels,
            ByteNdArray validationImages,
            ByteNdArray validationLabels,
            ByteNdArray testImages,
            ByteNdArray testLabels
    ) {
        this.trainingImages = trainingImages;
        this.trainingLabels = trainingLabels;
        this.validationImages = validationImages;
        this.validationLabels = validationLabels;
        this.testImages = testImages;
        this.testLabels = testLabels;
        // Size of the first image's shape: each image has 28 x 28 pixels, so this should be 784
        this.imageSize = trainingImages.get(0).shape().size();
        // Debug output, uncomment as needed:
        // System.out.println("imageSize = " + imageSize);
        // System.out.println(String.format("train_x:%s,train_y:%s, test_x:%s, test_y:%s", trainingImages.shape(), trainingLabels.shape(), testImages.shape(), testLabels.shape()));
        // System.out.println("Dataset rank: " + trainingImages.rank());
        // System.out.println("Dataset shape = " + trainingImages.shape());
        // System.out.println("Dataset size = " + trainingImages.shape().get(0));
        // System.out.println("Dataset = ");
        // for (int i = 0; i < trainingImages.shape().get(0); i++) {
        //     for (int j = 0; j < trainingImages.shape().get(1); j++) {
        //         for (int k = 0; k < trainingImages.shape().get(2); k++) {
        //             System.out.print(trainingImages.getObject(i, j, k) + " ");
        //         }
        //         System.out.println();
        //     }
        //     System.out.println();
        // }
        // System.out.println("First sample = " + trainingImages.get(0));
        // for (int j = 0; j < trainingImages.shape().get(1); j++) {
        //     for (int k = 0; k < trainingImages.shape().get(2); k++) {
        //         System.out.print(trainingImages.getObject(0, j, k) + " ");
        //     }
        //     System.out.println();
        // }
        // System.out.println("Size of the first sample = " + trainingImages.get(0).size());
        // System.out.println("Size of row 2 of the first image = " + trainingImages.get(0).get(1).size());
        // System.out.println("Pixel [0][6][8] (row 7, column 9 of the first image) = " + trainingImages.getObject(0, 6, 8));
        // System.out.println("trainingLabels = " + trainingLabels.getObject(1));
    }
    /**
     * Creates the dataset and holds out the first validationSize training examples for validation.
     *
     * @param validationSize        number of training examples to use for validation
     * @param trainingImagesArchive path of the training images
     * @param trainingLabelsArchive path of the training labels
     * @param testImagesArchive     path of the test images
     * @param testLabelsArchive     path of the test labels
     */
    public static MnistDataset create(int validationSize, String trainingImagesArchive, String trainingLabelsArchive,
                                      String testImagesArchive, String testLabelsArchive) {
        try {
            ByteNdArray trainingImages = readArchive(trainingImagesArchive);
            ByteNdArray trainingLabels = readArchive(trainingLabelsArchive);
            ByteNdArray testImages = readArchive(testImagesArchive);
            ByteNdArray testLabels = readArchive(testLabelsArchive);
            if (validationSize > 0) {
                return new MnistDataset(
                        trainingImages.slice(sliceFrom(validationSize)), // training set: everything after the validation examples
                        trainingLabels.slice(sliceFrom(validationSize)),
                        trainingImages.slice(sliceTo(validationSize)),   // validation set: the first validationSize examples
                        trainingLabels.slice(sliceTo(validationSize)),
                        testImages,
                        testLabels
                );
            }
            return new MnistDataset(trainingImages, trainingLabels, null, null, testImages, testLabels);
        } catch (IOException e) {
            throw new AssertionError(e);
        }
    }
    /**
     * Loads the full MNIST dataset and uses one training image as a single-sample "validation" set.
     *
     * @param index                 index of the training image to use as the validation sample
     * @param trainingImagesArchive path of the training images
     * @param trainingLabelsArchive path of the training labels
     * @param testImagesArchive     path of the test images
     * @param testLabelsArchive     path of the test labels
     */
    public static MnistDataset getOneValidationImage(int index, String trainingImagesArchive, String trainingLabelsArchive,
                                                     String testImagesArchive, String testLabelsArchive) {
        if (index < 0) {
            return null; // same behaviour as before: no dataset for a negative index, checked before any slicing
        }
        try {
            ByteNdArray trainingImages = readArchive(trainingImagesArchive);
            ByteNdArray trainingLabels = readArchive(trainingLabelsArchive);
            ByteNdArray testImages = readArchive(testImagesArchive);
            ByteNdArray testLabels = readArchive(testLabelsArchive);
            // Slice [index, index + 1): a range keeps the first dimension, so this is a batch of one image/label
            Index range = Indices.range(index, index + 1);
            ByteNdArray validationImage = trainingImages.slice(range);
            ByteNdArray validationLabel = trainingLabels.slice(range);
            return new MnistDataset(
                    trainingImages,
                    trainingLabels,
                    validationImage,
                    validationLabel,
                    testImages,
                    testLabels
            );
        } catch (IOException e) {
            throw new AssertionError(e);
        }
    }
    private static ByteNdArray readArchive(String archiveName) throws IOException {
        System.out.println("archiveName = " + archiveName);
        // Reads the gzip archive from the classpath (e.g. src/main/resources/mnist/...).
        // To read from the file system instead: new GZIPInputStream(new java.io.FileInputStream("src/main/resources/" + archiveName))
        try (DataInputStream archiveStream = new DataInputStream(
                new GZIPInputStream(MnistDataset.class.getClassLoader().getResourceAsStream(archiveName)))) {
            // IDX header: 2 zero bytes, 1 data-type byte, 1 byte with the number of dimensions,
            // then one big-endian int32 per dimension; the raw data follows
            archiveStream.readShort(); // first two bytes are always 0
            byte magic = archiveStream.readByte(); // data type: 0x08 = unsigned byte
            if (magic != TYPE_UBYTE) {
                throw new IllegalArgumentException("\"" + archiveName + "\" is not a valid archive");
            }
            int numDims = archiveStream.readByte(); // 3 for images (count, rows, columns), 1 for labels
            long[] dimSizes = new long[numDims];
            int size = 1; // for simplicity, we assume that the total size does not exceed Integer.MAX_VALUE
            for (int i = 0; i < dimSizes.length; ++i) {
                dimSizes[i] = archiveStream.readInt();
                size *= dimSizes[i];
            }
            byte[] bytes = new byte[size];
            archiveStream.readFully(bytes);
            return NdArrays.wrap(Shape.of(dimSizes), DataBuffers.of(bytes, false, false));
        }
    }
    public Iterable<ImageBatch> trainingBatches(int batchSize) {
        return () -> new ImageBatchIterator(batchSize, trainingImages, trainingLabels);
    }

    public Iterable<ImageBatch> validationBatches(int batchSize) {
        return () -> new ImageBatchIterator(batchSize, validationImages, validationLabels);
    }

    public Iterable<ImageBatch> testBatches(int batchSize) {
        return () -> new ImageBatchIterator(batchSize, testImages, testLabels);
    }

    public ImageBatch testBatch() {
        return new ImageBatch(testImages, testLabels);
    }

    public long imageSize() {
        return imageSize;
    }

    public long numTrainingExamples() {
        return trainingLabels.shape().size(0);
    }

    public long numTestingExamples() {
        return testLabels.shape().size(0);
    }

    public long numValidationExamples() {
        return validationLabels.shape().size(0);
    }

    public ByteNdArray getValidationImages() {
        return validationImages;
    }

    public ByteNdArray getValidationLabels() {
        return validationLabels;
    }
}
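`MnistDataset` hands out batches through `ImageBatch` and `ImageBatchIterator`, which are not listed in the article. Conceptually, such an iterator only has to walk the first dimension in steps of `batchSize` and slice `[from, to)` ranges, with a shorter last batch when the count does not divide evenly. The plain-Java sketch below shows that logic on index ranges only; the class is hypothetical and is not the author's `ImageBatchIterator`.

```java
import java.util.Iterator;
import java.util.NoSuchElementException;

/** Hypothetical sketch of batch iteration: yields [from, to) index ranges over numExamples items. */
final class BatchRanges implements Iterator<long[]> {
    private final long numExamples;
    private final long batchSize;
    private long next = 0;

    BatchRanges(long numExamples, long batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        this.numExamples = numExamples;
        this.batchSize = batchSize;
    }

    @Override
    public boolean hasNext() {
        return next < numExamples;
    }

    /** Returns {from, to}; a real iterator would slice images and labels with this range. */
    @Override
    public long[] next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        long from = next;
        long to = Math.min(from + batchSize, numExamples); // the last batch may be shorter
        next = to;
        return new long[] {from, to};
    }
}
```

Each `{from, to}` pair corresponds to the kind of half-open range that `Indices.range(from, to)` expresses in `getOneValidationImage`.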
SimpleMnist.java
package org.example.tensorDemo.dense;
import org.example.tensorDemo.datasets.ImageBatch;
import org.example.tensorDemo.datasets.mnist.MnistDataset;
import org.tensorflow.*;
import org.tensorflow.framework.optimizers.GradientDescent;
import org.tensorflow.framework.optimizers.Optimizer;
import org.tensorflow.ndarray.ByteNdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.linalg.MatMul;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Mean;
import org.tensorflow.op.nn.Softmax;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;
import java.io.IOException;
import java.util.Map;
import java.util.Optional;
public class SimpleMnist implements Runnable {
    private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz";
    private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz";
    private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz";
    private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz";

    public static void main(String[] args) {
        // Load the dataset
        // MnistDataset dataset = MnistDataset.create(VALIDATION_SIZE, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,
        //         TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
        // Use training image 3 as the single validation sample that loadModel predicts
        MnistDataset validationDataset = MnistDataset.getOneValidationImage(3, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,
                TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
        // Create the graph object that holds the model
        try (Graph graph = new Graph()) {
            SimpleMnist mnist = new SimpleMnist(graph, validationDataset);
            mnist.run(); // build, train, evaluate and export the model
            mnist.loadModel("D:\\ai\\ai-demo"); // reload the exported model and predict
        }
    }
    @Override
    public void run() {
        Ops tf = Ops.create(graph);
        // Create placeholders for batches of an unknown number of images:
        // images are fed as (batch, 784) floats, labels as (batch, 10) one-hot floats
        Placeholder<TFloat32> images = tf.placeholder(TFloat32.class);
        Placeholder<TFloat32> labels = tf.placeholder(TFloat32.class);
        // Create weights of shape (784, 10) with an initial value of 0
        Shape weightShape = Shape.of(dataset.imageSize(), MnistDataset.NUM_CLASSES);
        Variable<TFloat32> weights = tf.variable(tf.zeros(tf.constant(weightShape), TFloat32.class));
        // Create biases of shape (10) with an initial value of 0
        Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES);
        Variable<TFloat32> biases = tf.variable(tf.zeros(tf.constant(biasShape), TFloat32.class));
        // Predict the class of each image in the batch and compute the loss.
        // tf.linalg.matMul multiplies the image matrix by the weight matrix; adding the biases gives logits = xW + b
        MatMul<TFloat32> matMul = tf.linalg.matMul(images, weights);
        Add<TFloat32> add = tf.math.add(matMul, biases);
        // Activation: softmax exponentiates each element and divides by the sum of all exponentials,
        // turning each row into a probability distribution over the classes (usual for multi-class problems)
        Softmax<TFloat32> softmax = tf.nn.softmax(add);
        // Loss function: cross-entropy, the batch mean of -sum(labels * log(softmax))
        Mean<TFloat32> crossEntropy =
            tf.math.mean(                                          // average over the batch
                tf.math.neg(                                       // negate
                    tf.reduceSum(                                  // sum over the classes
                        tf.math.mul(labels, tf.math.log(softmax)), // label times log of the predicted probability
                        tf.array(1)                                // axis 1: classes
                    )
                ),
                tf.array(0)                                        // axis 0: batch
            );
        // Back-propagate gradients to the variables for training: the gradient descent optimizer
        // computes the gradients of crossEntropy and updates weights and biases to minimize it
        // Gradient descent: https://www.cnblogs.com/guoyaohua/p/8542554.html
        Optimizer optimizer = new GradientDescent(graph, LEARNING_RATE);
        Op minimize = optimizer.minimize(crossEntropy);
        // Compute the accuracy of the model
        // argMax returns the index of the largest value along the given axis (axis 1 = per image)
        Operand<TInt64> predicted = tf.math.argMax(softmax, tf.constant(1));
        Operand<TInt64> expected = tf.math.argMax(labels, tf.constant(1));
        // equal compares predicted and true classes, cast turns the booleans into floats, and mean gives the accuracy
        Operand<TFloat32> accuracy = tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), TFloat32.class), tf.array(0));
// Run the graph
try (Session session = new Session(graph)) {
// Train the model
for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
try (TFloat32 batchImages = preprocessImages(trainingBatch.images());
TFloat32 batchLabels = preprocessLabels(trainingBatch.labels())) {
System.out.println("batchImages = " + batchImages.shape());
System.out.println("batchLabels = " + batchLabels.shape());
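// Create a session runner
session.runner()
// Target to run: one minimization (gradient-descent) step
.addTarget(minimize)
// Feed the image batch into the images placeholder
.feed(images.asOutput(), batchImages)
// Feed the one-hot label batch into the labels placeholder
.feed(labels.asOutput(), batchLabels)
// Run the session
.run();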
}
}
// Test the model
ImageBatch testBatch = dataset.testBatch();
try (TFloat32 testImages = preprocessImages(testBatch.images());
TFloat32 testLabels = preprocessLabels(testBatch.labels());
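// accuracyValue holds the computed accuracy and is closed by the try-with-resources block
TFloat32 accuracyValue = (TFloat32) session.runner()
// Fetch the accuracy (result index 0); predicted and expected are fetched as well (indices 1 and 2) but not used here
.fetch(accuracy)
.fetch(predicted)
.fetch(expected)
// Feed testImages into the images placeholder
.feed(images.asOutput(), testImages)
// Feed testLabels into the labels placeholder
.feed(labels.asOutput(), testLabels)
// Run the session and collect the results
.run()
// Take the first result (the accuracy)
.get(0)) {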
System.out.println("Accuracy: " + accuracyValue.getFloat());
}
// Save the model in SavedModel format to a hard-coded directory (adjust the path for your machine)
SavedModelBundle.Exporter exporter = SavedModelBundle.exporter("D:\\ai\\ai-demo").withSession(session);
Signature.Builder builder = Signature.builder();
builder.input("images", images);
builder.input("labels", labels);
builder.output("accuracy", accuracy);
builder.output("expected", expected);
builder.output("predicted", predicted);
Signature signature = builder.build();
SessionFunction sessionFunction = SessionFunction.create(signature, session);
exporter.withFunction(sessionFunction);
exporter.export();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
private static final int VALIDATION_SIZE = 5;
private static final int TRAINING_BATCH_SIZE = 100;
private static final float LEARNING_RATE = 0.2f;
private static TFloat32 preprocessImages(ByteNdArray rawImages) {
Ops tf = Ops.create();
// Flatten images in a single dimension and normalize their pixels as floats.
long imageSize = rawImages.get(0).shape().size();
return tf.math.div(
tf.reshape(
tf.dtypes.cast(tf.constant(rawImages), TFloat32.class),
tf.array(-1L, imageSize)
),
tf.constant(255.0f)
).asTensor();
}
private static TFloat32 preprocessLabels(ByteNdArray rawLabels) {
Ops tf = Ops.create();
// Map labels to one-hot vectors in which only the expected class has a value of 1.0
return tf.oneHot(
tf.constant(rawLabels),
tf.constant(MnistDataset.NUM_CLASSES),
tf.constant(1.0f),
tf.constant(0.0f)
).asTensor();
}
private final Graph graph;
private final MnistDataset dataset;
private SimpleMnist(Graph graph, MnistDataset dataset) {
this.graph = graph;
this.dataset = dataset;
}
public void loadModel(String exportDir) {
// load saved model
SavedModelBundle model = SavedModelBundle.load(exportDir, "serve");
try {
printSignature(model);
} catch (Exception e) {
throw new RuntimeException(e);
}
ByteNdArray validationImages = dataset.getValidationImages();
ByteNdArray validationLabels = dataset.getValidationLabels();
TFloat32 testImages = preprocessImages(validationImages);
System.out.println("testImages = " + testImages.shape());
TFloat32 testLabels = preprocessLabels(validationLabels);
System.out.println("testLabels = " + testLabels.shape());
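// The node names below are the default names TensorFlow generated when the graph was built, in creation order:
// Placeholder = images, Placeholder_1 = labels, ArgMax = predicted, ArgMax_1 = expected,
// Mean = crossEntropy, Mean_1 = accuracy (printSignature shows the mapping)
Result run = model.session().runner()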
.feed("Placeholder:0", testImages)
.feed("Placeholder_1:0", testLabels)
.fetch("ArgMax:0")
.fetch("ArgMax_1:0")
.fetch("Mean_1:0")
.run();
// Process the outputs (only the first sample's prediction and label are printed)
Optional<Tensor> tensor1 = run.get("ArgMax:0");
Optional<Tensor> tensor2 = run.get("ArgMax_1:0");
Optional<Tensor> tensor3 = run.get("Mean_1:0");
TInt64 predicted = (TInt64) tensor1.get();
Long predictedValue = predicted.getObject(0);
System.out.println("predictedValue = " + predictedValue);
TInt64 expected = (TInt64) tensor2.get();
Long expectedValue = expected.getObject(0);
System.out.println("expectedValue = " + expectedValue);
TFloat32 accuracy = (TFloat32) tensor3.get();
System.out.println("accuracy = " + accuracy.getFloat());
}
private static void printSignature(SavedModelBundle model) throws Exception {
MetaGraphDef m = model.metaGraphDef();
SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
int numInputs = sig.getInputsCount();
int i = 1;
System.out.println("MODEL SIGNATURE");
System.out.println("Inputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
TensorInfo t = entry.getValue();
System.out.printf(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
}
int numOutputs = sig.getOutputsCount();
i = 1;
System.out.println("Outputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
TensorInfo t = entry.getValue();
System.out.printf(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
}
System.out.println("-----------------------------------------------");
}
}
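To see what the TensorFlow graph above actually computes, the following plain-Java sketch (no TensorFlow dependency) reimplements the same single-layer model on toy data: logits = x·W + b, softmax, the batch-averaged cross-entropy, one gradient-descent update, and argMax-based accuracy. The class name, method names and toy data are hypothetical. The gradient is written out by hand using the standard result that, for softmax followed by cross-entropy, the derivative of the batch-mean loss with respect to each logit is (p − y) / batchSize; in the real program `optimizer.minimize(crossEntropy)` derives these gradients automatically.

```java
import java.util.Arrays;

// Hypothetical illustration class, not part of TensorFlow Java or the original project.
class LinearSoftmaxSketch {

    // Softmax over one row of logits; subtracting the max keeps Math.exp from overflowing.
    static double[] softmax(double[] logits) {
        double max = Arrays.stream(logits).max().getAsDouble();
        double[] out = new double[logits.length];
        double sum = 0;
        for (int k = 0; k < logits.length; k++) {
            out[k] = Math.exp(logits[k] - max);
            sum += out[k];
        }
        for (int k = 0; k < out.length; k++) out[k] /= sum;
        return out;
    }

    // Forward pass for one sample: softmax(x·W + b).
    static double[] predict(double[] x, double[][] w, double[] b) {
        double[] logits = b.clone();
        for (int i = 0; i < x.length; i++)
            for (int k = 0; k < b.length; k++)
                logits[k] += x[i] * w[i][k];
        return softmax(logits);
    }

    // Mean over the batch of -sum_k y_k * log(p_k), the same quantity as crossEntropy in the graph.
    static double crossEntropy(double[][] xs, double[][] ys, double[][] w, double[] b) {
        double total = 0;
        for (int n = 0; n < xs.length; n++) {
            double[] p = predict(xs[n], w, b);
            for (int k = 0; k < p.length; k++) total -= ys[n][k] * Math.log(p[k]);
        }
        return total / xs.length;
    }

    // One gradient-descent step: dLoss/dlogit_k = (p_k - y_k) / batchSize, then chain rule to W and b.
    static void step(double[][] xs, double[][] ys, double[][] w, double[] b, double lr) {
        int classes = b.length;
        double[][] gw = new double[w.length][classes];
        double[] gb = new double[classes];
        for (int n = 0; n < xs.length; n++) {
            double[] p = predict(xs[n], w, b);
            for (int k = 0; k < classes; k++) {
                double d = (p[k] - ys[n][k]) / xs.length;
                gb[k] += d;
                for (int i = 0; i < w.length; i++) gw[i][k] += xs[n][i] * d;
            }
        }
        for (int k = 0; k < classes; k++) {
            b[k] -= lr * gb[k];
            for (int i = 0; i < w.length; i++) w[i][k] -= lr * gw[i][k];
        }
    }

    // Index of the largest value, like tf.math.argMax along axis 1.
    static int argMax(double[] v) {
        int best = 0;
        for (int k = 1; k < v.length; k++) if (v[k] > v[best]) best = k;
        return best;
    }

    // Fraction of samples whose predicted class equals the class of the one-hot label.
    static double accuracy(double[][] xs, double[][] ys, double[][] w, double[] b) {
        int correct = 0;
        for (int n = 0; n < xs.length; n++)
            if (argMax(predict(xs[n], w, b)) == argMax(ys[n])) correct++;
        return (double) correct / xs.length;
    }

    public static void main(String[] args) {
        // Toy data: 2 "pixels" per image, 2 classes, one-hot labels.
        double[][] xs = {{1, 0}, {0, 1}, {0.9, 0.1}, {0.2, 0.8}};
        double[][] ys = {{1, 0}, {0, 1}, {1, 0}, {0, 1}};
        double[][] w = new double[2][2]; // weights start at 0, as in the graph
        double[] b = new double[2];      // biases start at 0
        System.out.println("initial loss = " + crossEntropy(xs, ys, w, b));
        for (int epoch = 0; epoch < 100; epoch++) step(xs, ys, w, b, 0.2);
        System.out.println("final loss = " + crossEntropy(xs, ys, w, b));
        System.out.println("accuracy = " + accuracy(xs, ys, w, b));
    }
}
```

Running `main` should print an initial loss of about 0.693 (ln 2, because all-zero weights give both classes probability 0.5), a smaller final loss, and an accuracy of 1.0 on this separable toy data. The MNIST model is the same computation with 784 inputs, 10 classes and batches of 100.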
V. Results of running the two versions of the code
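VI. Room for improvement
1. There is no web service for submitting images, and input images are not preprocessed (for example, binarized). Interested readers can try this themselves.
2. No convolutional neural network is used: the model is only wx + b followed by the softmax activation, trained with gradient descent on the cross-entropy loss.
3. There is no multi-layer design; the model has a single layer.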
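For item 1, a possible starting point is sketched below in plain Java: it turns a raw 28×28 grayscale image (0 to 255 per pixel) into the flat 784-element vector that the `images` placeholder expects, scaling by 255 like `preprocessImages`, with optional inversion and binarization, plus a one-hot helper mirroring `preprocessLabels`. The class name, method names and the threshold of 128 are assumptions for illustration. Inversion matters because MNIST digits are light strokes on a dark background, whereas a photo or scan of handwriting is usually dark ink on white paper.

```java
import java.util.Arrays;

// Hypothetical helper, not part of the original project or of TensorFlow Java.
class InputPreprocessSketch {

    // Flattens a grayscale image (values 0..255) row by row into one float vector.
    // invert: use 255 - v first (dark ink on white paper -> MNIST-style light strokes on black).
    // binarize: map each pixel to 0 or 1 by comparing with threshold; otherwise scale to [0, 1] by dividing by 255.
    static float[] toModelInput(int[][] pixels, boolean invert, boolean binarize, int threshold) {
        int rows = pixels.length, cols = pixels[0].length;
        float[] out = new float[rows * cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                int v = invert ? 255 - pixels[r][c] : pixels[r][c];
                out[r * cols + c] = binarize ? (v >= threshold ? 1f : 0f) : v / 255f;
            }
        }
        return out;
    }

    // One-hot encodes a digit label: 1.0 at the label index, 0.0 elsewhere (like preprocessLabels).
    static float[] oneHot(int label, int numClasses) {
        float[] v = new float[numClasses];
        v[label] = 1f;
        return v;
    }

    public static void main(String[] args) {
        int[][] image = new int[28][28];
        for (int[] row : image) Arrays.fill(row, 255); // an all-white "page"
        image[14][14] = 0;                             // one black ink pixel in the middle
        float[] input = toModelInput(image, true, true, 128);
        System.out.println("length = " + input.length);           // 784
        System.out.println("ink pixel = " + input[14 * 28 + 14]); // 1.0 after inverting
        System.out.println(Arrays.toString(oneHot(3, 10)));
    }
}
```

Before being fed to the loaded model, such a vector would still need to be wrapped in a `TFloat32` tensor of shape (1, 784). The model above was trained on grayscale values divided by 255, so it may behave differently on binarized input; whether binarizing helps is worth testing.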