天天看點

Keras Lamba層keras Lambda自定義層實作資料的切片,Lambda傳參數

from keras.layers.core import Lambda
keras.layers.core.Lambda(function, output_shape=None, mask=None, arguments=None)
           

Lambda函數接受兩個參數,第一個是輸入張量對輸出張量的映射函數,第二個是輸入的shape對輸出的shape的映射函數。

參數

  • function:要實作的函數,該函數僅接受一個變量,即上一層的輸出
  • output_shape:函數應該傳回的值的shape,可以是一個tuple,也可以是一個根據輸入shape計算輸出shape的函數
  • mask: 
  • arguments:可選,字典,用來記錄向函數中傳遞的其他關鍵字參數

keras Lambda自定義層實作資料的切片,Lambda傳參數

1、代碼如下:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation,Reshape
from keras.layers import merge
from keras.utils.visualize_util import plot
from keras.layers import Input, Lambda
from keras.models import Model
 
def slice(x,index):
  return x[:,:,index]
 
a = Input(shape=(4,2))
x1 = Lambda(slice,output_shape=(4,1),arguments={‘index‘:0})(a)
x2 = Lambda(slice,output_shape=(4,1),arguments={‘index‘:1})(a)
x1 = Reshape((4,1,1))(x1)
x2 = Reshape((4,1,1))(x2)
output = merge([x1,x2],mode=‘concat‘)
 
model = Model(a, output)
x_test = np.array([[[1,2],[2,3],[3,4],[4,5]]])
print model.predict(x_test)
plot(model, to_file=‘lambda.png‘,show_shapes=True)
           

2、注意Lambda 是可以進行參數傳遞的,傳遞的方式如下代碼所述:

 def slice(x,index):

 return x[:,:,index]
           

 如上,index是參數,通過字典将參數傳遞進去.

x1 = Lambda(slice,output_shape=(4,1),arguments={‘index‘:0})(a)
x2 = Lambda(slice,output_shape=(4,1),arguments={‘index‘:1})(a)
           

3、上述代碼實作的是,将矩陣的每一列提取出來,然後單獨進行操作,最後在拼在一起。可視化的圖如下所示。

Keras Lamba層keras Lambda自定義層實作資料的切片,Lambda傳參數