from keras.layers.core import Lambda
keras.layers.core.Lambda(function, output_shape=None, mask=None, arguments=None)
Lambda函數接受兩個參數,第一個是輸入張量對輸出張量的映射函數,第二個是輸入的shape對輸出的shape的映射函數。
參數
- function:要實作的函數,該函數僅接受一個變量,即上一層的輸出
- output_shape:函數應該傳回的值的shape,可以是一個tuple,也可以是一個根據輸入shape計算輸出shape的函數
- mask:
- arguments:可選,字典,用來記錄向函數中傳遞的其他關鍵字參數
keras Lambda自定義層實作資料的切片,Lambda傳參數
1、代碼如下:
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation,Reshape
from keras.layers import merge
from keras.utils.visualize_util import plot
from keras.layers import Input, Lambda
from keras.models import Model
def slice(x,index):
return x[:,:,index]
a = Input(shape=(4,2))
x1 = Lambda(slice,output_shape=(4,1),arguments={‘index‘:0})(a)
x2 = Lambda(slice,output_shape=(4,1),arguments={‘index‘:1})(a)
x1 = Reshape((4,1,1))(x1)
x2 = Reshape((4,1,1))(x2)
output = merge([x1,x2],mode=‘concat‘)
model = Model(a, output)
x_test = np.array([[[1,2],[2,3],[3,4],[4,5]]])
print model.predict(x_test)
plot(model, to_file=‘lambda.png‘,show_shapes=True)
2、注意Lambda 是可以進行參數傳遞的,傳遞的方式如下代碼所述:
def slice(x,index):
return x[:,:,index]
如上,index是參數,通過字典将參數傳遞進去.
x1 = Lambda(slice,output_shape=(4,1),arguments={‘index‘:0})(a)
x2 = Lambda(slice,output_shape=(4,1),arguments={‘index‘:1})(a)