tf.argmax(input,axis)根據axis取值的不同傳回每行或者每列最大值的索引。
舉個例子:
import tensorflow as tf
import numpy as np
A = [[1,3,4,5,6]]
B = [[1,3,4], [2,4,1]]
with tf.Session() as sess:
print(sess.run(tf.argmax(A, 1)))
print(sess.run(tf.argmax(B, 1)))
輸出:
[4]
[2 1]
輸出[4]因為在A中6最大,6的下标是4。
同理在B[0]中4最大,B[1]中也是4最大,其下标分别為2和1。
這裡有一個參數axis可以設定:
axis=0時比較每一列的元素,将每一列最大元素所在的索引記錄下來,最後輸出每一列最大元素所在的索引數組。
axis=1的時候,将每一行最大元素所在的索引記錄下來,最後傳回每一行最大元素所在的索引數組。
這樣說比較亂,看個例子:
test = np.array([
[1, 2, 3],
[2, 3, 4],
[5, 4, 3],
[8, 7, 2]])
np.argmax(test, 0) #輸出:array([3, 3, 1]
np.argmax(test, 1) #輸出:array([2, 2, 0, 0]