天天看點

多标簽分類,跑出合理結果的loss設計,以及每個label都有權重

首先發現

​​

​tf.losses.sigmoid_cross_entropy​

​​

​tf.losses.softmax_cross_entropy​

​​

​tf.keras.losses.CategoricalCrossentropy​

​ 都是不行的

kl_compute = tf.keras.losses.KLDivergence(
    reduction=losses_utils.ReductionV2.NONE,
    name='KL_divergence')
                        
# logits 是 [batch_size, class_number]
# gt_labels 是 [batch_size, class_number], 裡面有多個1,其餘為0
# gt_label_weights 是 [batch_size, class_number], 裡面有上面每個1的權重,其餘為0            
loss = tf.reduce_mean(kl_compute(gt_labels * tf.sigmoid(gt_label_weights), 
    tf.nn.softmax(logits)))      

繼續閱讀