天天看點

tensorflow——tf.argmax()和axis詳解

tf.argmax(input,axis)根據axis取值的不同傳回每行或者每列最大值的索引。

舉個例子:

import tensorflow as tf
import numpy as np
 
A = [[1,3,4,5,6]]
B = [[1,3,4], [2,4,1]]
 
with tf.Session() as sess:
    print(sess.run(tf.argmax(A, 1)))
    print(sess.run(tf.argmax(B, 1)))
           

輸出:

[4]

[2 1]

輸出[4]因為在A中6最大,6的下标是4。

同理在B[0]中4最大,B[1]中也是4最大,其下标分别為2和1。

這裡有一個參數axis可以設定:

axis=0時比較每一列的元素,将每一列最大元素所在的索引記錄下來,最後輸出每一列最大元素所在的索引數組。

axis=1的時候,将每一行最大元素所在的索引記錄下來,最後傳回每一行最大元素所在的索引數組。

這樣說比較亂,看個例子:

test = np.array([
[1, 2, 3],
 [2, 3, 4], 
 [5, 4, 3], 
 [8, 7, 2]])
np.argmax(test, 0)   #輸出:array([3, 3, 1]
np.argmax(test, 1)   #輸出:array([2, 2, 0, 0]
           

繼續閱讀