天天看點

Tensorflow自定義激活函數/函數/梯度Tensorflow自定義激活函數/函數/梯度

Tensorflow自定義激活函數/函數/梯度

(對于Tensorflow 1.x)

最近剛做完一個paper,需要自定義激活函數,記錄一下心得,順便幫助下有需要的小夥伴。大刀闊斧,直接上解決方案:

1、對于分段(激活)函數,代碼分開寫

2、使用自帶自定義梯度

詳解

Tensorflow是自動求導(不懂就百度),是以我們不需要定義梯度,但大家可能會遇到和我一樣的問題(在訓練模型的時候loss爆炸),是以大家才會來查吧。

自定義激活函數/函數直接定義就可以,比如:

output = tf.exp(input)
output = tf.log(input)
           

但為什麼有時候會梯度爆炸?

因為激活函數大多是參照relu進行修改,故大多是分段函數,分段函數在tensorflow中使用

tf.where(tf.greater(input, [0.0]),function1,function2)
           

funtion1計算大于0的數,function2計算小于等于0的數,但這就導緻我構造的激活函數loss爆炸。原因不詳,猜測是先計算所有輸入都參與function1和function2的計算。

我使用了tensorflow定義swish的例子定義函數:

def _swish_shape(op):
  return [op.inputs[0].shape]

@function.Defun(shape_func=_swish_shape, func_name="swish_grad", noinline=True)
def _swish_grad(features, grad):
  sigmoid_features = math_ops.sigmoid(features)
  activation_grad = (
      sigmoid_features * (1.0 + features * (1.0 - sigmoid_features)))
  return grad * activation_grad

@tf_export("nn.swish")
@function.Defun(
    grad_func=_swish_grad,
    shape_func=_swish_shape,
    func_name="swish",
    noinline=True)
def swish(features):
  features = ops.convert_to_tensor(features, name="features")
  return features * math_ops.sigmoid(features)
           

加入可訓練參數的話,因為涉及tensorflow自帶函數,較為複雜,可以用上述形式,定義多個函數實作。

下面是一個簡單例子

#F2
@function.Defun(shape_func=shape, func_name="F2_grad", noinline=True)
def F2_grad(features, grad):
    return grad * tf.where(tf.greater(features, features*0), features*0+1, tf.exp(features))


@tf_export("F2")
@function.Defun(
    grad_func=F2_grad,
    shape_func=shape,
    func_name="F2",
    noinline=True
)
def F2(features):
    return tf.where(tf.greater(features, features*0), features, tf.exp(features))




def my_best(features, a1, a2, a3):
    f1 = F1(a2*features)
    f2 = F2(a3*features)
    return tf.where(tf.greater(features, features*0), features, a1*(f1-features*f2))
           

繼續閱讀