
ValueError: Tensor conversion requested dtype int64 for Tensor with dtype float64: 'Tensor("loss/a

I rewrote the relative entropy (KL divergence) function in Keras with TensorFlow as the backend, and it raised an error.

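from keras import backend as K  # imports assumed; not shown in the original snippet
from keras import losses

def KL(y_true, y_pred):
    # Fails here: the int64 argmax is multiplied by the float64 log term.
    weights = K.sum(K.cast(K.argmax(y_true, axis=1)*K.log(K.argmax(y_true, axis=1)/K.argmax(y_pred, axis=1)),dtype='float32'))
    return weights * losses.categorical_crossentropy(y_true, y_pred)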
           

The error:

ValueError: Tensor conversion requested dtype int64 for Tensor with dtype float64: 'Tensor("loss/a
           

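The cause is that the log term

    K.log(K.argmax(y_true, axis=1)/K.argmax(y_pred, axis=1))

evaluates to float64, because dividing the two int64 index tensors is a true division that yields float64. The factor it is multiplied by,

    K.argmax(y_true, axis=1)

is int64. TensorFlow does not implicitly convert between the two dtypes, so the multiplication fails. The fix is to replace that factor with

    K.cast(K.argmax(y_true, axis=1),dtype='float64')

which converts it from int64 to float64.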
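The dtype chain behind this error can be inspected without TensorFlow. The following is a minimal sketch in NumPy, used here only as a stand-in: the arrays y_true and y_pred are made-up sample values, and unlike TensorFlow, NumPy silently promotes int64 to float64 instead of raising, so the sketch shows the dtypes involved rather than the error itself.

```python
import numpy as np

# Hypothetical one-hot labels and predicted probabilities (illustrative values only).
# The classes are chosen so that no argmax index is 0, which avoids division by zero.
y_true = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
y_pred = np.array([[0.1, 0.7, 0.2], [0.2, 0.2, 0.6]])

true_idx = np.argmax(y_true, axis=1)   # integer class indices (int64 on 64-bit platforms)
pred_idx = np.argmax(y_pred, axis=1)   # integer class indices

ratio = true_idx / pred_idx            # true division of integers yields float64
log_term = np.log(ratio)               # the log of a float64 array stays float64

# The explicit cast mirrors K.cast(..., dtype='float64') in the Keras version.
true_idx_f64 = true_idx.astype(np.float64)
product = true_idx_f64 * log_term      # both operands are now float64

print(true_idx.dtype, ratio.dtype, log_term.dtype, product.dtype)
```

In the Keras version, the same explicit cast is what keeps both operands of the multiplication in float64.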

The corrected relative entropy function:

def KL(y_true, y_pred):
    # The first argmax is cast to float64 so that it matches the dtype of the log term.
    weights = K.sum(K.cast(K.cast(K.argmax(y_true, axis=1),dtype='float64')*K.log(K.argmax(y_true, axis=1)/K.argmax(y_pred, axis=1)),dtype='float32'))
    return weights * losses.categorical_crossentropy(y_true, y_pred)
           
