import tensorflow as tf
def main():
NUM_CLASSES = 2 # 2分類
labels = [0, 1, 1, 0, 1, 0, 1, 0] # sample label
batch_size = tf.size(labels) # get size of labels : 8
labels = tf.expand_dims(labels, 1) # 增加一個次元
indices = tf.expand_dims(tf.range(0, batch_size, 1), 1) # 生成索引
concated = tf.concat([indices, labels], 1) # 作為拼接
bs_nc = tf.stack([batch_size, NUM_CLASSES]) # 拼接矩陣 [8 2]
onehot_labels = tf.sparse_to_dense(concated, bs_nc, 1.0, 0.0) # 生成one-hot編碼的标簽
with tf.Session() as sess:
onehot_labels = sess.run(onehot_labels)
print(onehot_labels)
'''
tf.sparse_to_dense的解釋 4個參數
第一個參數 矩陣中元素對應的索引和值 例子中對應concated [[0 0],[1 1],[2 1],[3 0],[4 1],[5 0],[6 1],[7 0]]
第二個參數 輸出矩陣的shape 對應bs_nc [8 2]
第三個參數 指定元素值為1.0
第四個參數 其餘元素值為0.0
'''
if __name__ == '__main__':
main()
運作結果
[[1. 0.]
[0. 1.]
[0. 1.]
[1. 0.]
[0. 1.]
[1. 0.]
[0. 1.]
[1. 0.]]