天天看點

TensorFlow 多标簽轉化為one-hot

import tensorflow as tf

def main():
    NUM_CLASSES = 2  # 2分類
    labels = [0, 1, 1, 0, 1, 0, 1, 0]  # sample label
    batch_size = tf.size(labels)  # get size of labels : 8
    labels = tf.expand_dims(labels, 1)  # 增加一個次元
    indices = tf.expand_dims(tf.range(0, batch_size, 1), 1)  # 生成索引
    concated = tf.concat([indices, labels], 1)  # 作為拼接
    bs_nc = tf.stack([batch_size, NUM_CLASSES]) # 拼接矩陣 [8 2]
    onehot_labels = tf.sparse_to_dense(concated, bs_nc, 1.0, 0.0)  # 生成one-hot編碼的标簽

    with tf.Session() as sess:
        onehot_labels = sess.run(onehot_labels)
        print(onehot_labels)



    '''
    tf.sparse_to_dense的解釋 4個參數
    第一個參數 矩陣中元素對應的索引和值  例子中對應concated [[0 0],[1 1],[2 1],[3 0],[4 1],[5 0],[6 1],[7 0]]
    第二個參數 輸出矩陣的shape  對應bs_nc [8 2]
    第三個參數 指定元素值為1.0
    第四個參數 其餘元素值為0.0
    '''


if __name__ == '__main__':
    main()
           

運作結果

[[1. 0.]
 [0. 1.]
 [0. 1.]
 [1. 0.]
 [0. 1.]
 [1. 0.]
 [0. 1.]
 [1. 0.]]

           

繼續閱讀