天天看点

多标签分类,跑出合理结果的loss设计,以及每个label都有权重

首先发现

​​

​tf.losses.sigmoid_cross_entropy​

​​

​tf.losses.softmax_cross_entropy​

​​

​tf.keras.losses.CategoricalCrossentropy​

​ 都是不行的

kl_compute = tf.keras.losses.KLDivergence(
    reduction=losses_utils.ReductionV2.NONE,
    name='KL_divergence')
                        
# logits 是 [batch_size, class_number]
# gt_labels 是 [batch_size, class_number], 里面有多个1,其余为0
# gt_label_weights 是 [batch_size, class_number], 里面有上面每个1的权重,其余为0            
loss = tf.reduce_mean(kl_compute(gt_labels * tf.sigmoid(gt_label_weights), 
    tf.nn.softmax(logits)))      

继续阅读