天天看點

tf.map_fn( )的用法

map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
           swap_memory=False, infer_shape=True, name=None)
           

其中 fn 是一個可調用函數,可以使用 lambda 來表示,elems 是需要處理的 tensors, tf 将會從第一維開始展開,進行 map 操作,dtype 表示 fn 函數的輸出類型,如果 fn 傳回的類型和 elems 中的不同,那麼就必須顯示指定為和 fn 傳回類型相同的類型。

tf.map_fn( )的用法

可以看出 map_fn 是一個反複将可調用函數fn應用于 elems 元素序列的一個高階函數。

有很多用處

在處理圖檔時,是(batch_size,height,width,depth),batch_size是一次處理的多少,一個 batch 内同樣對圖檔進行處理,對視訊進行卷積操作時,視訊輸入是(batch_size,frames,height,width,depth),其中多了個 frames 幀數,肯定是不能對視訊進行卷積的,視訊的每個切片産生後,我們同樣是對每一幀進行卷積,是以采用map_fn 函數,對每個切片應用卷積操作,每個batch 之間沒有關聯,可以并行快速的處理。

tf.map_fn(fn=lambda x:tf.nn.conv2d(x,kernel,stride,padding='same'),elems=batch,dtype=tf.float32)