天天看點

TensorFlow strides 參數讨論

更詳細地讨論見 stackoverflow:Tensorflow Strides Argument

卷積神經網絡(CNN)在 TensorFlow 實作時涉及的 tf.nn.con2d(二維卷積)、tf.nn.max_pool(最大池化)、tf.nn.avg_pool(平均池化)等操作都有關于

strides

(步長)的指定,因為無論是卷積操作還是各種類型的池化操作,都是某種形式的滑動視窗(sliding window)處理,這就要求指定從目前視窗移動下一個視窗位置的移動步長。

TensorFlow 文檔關于

strides

的說明如下:

strides: A list of ints that has length >= 4. The stride of the sliding window for each dimension of the input tensor.

首先要求 strides 為長度不小于 4 的整數構成的 list,

strides

參數表示的是滑窗在輸入張量各個次元上的移動步長。

而且一般要求

strides

的參數,

strides[0] = strides[3] = 1

具體什麼含義呢?

一般而言,對于輸入張量(input tensor)有四維資訊:[batch, height, width, channels](分别表示 batch_size, 也即樣本的數目,單個樣本的行數和列數,樣本的頻道數,rgb圖像就是三維的,灰階圖像則是一維),對于一個二維卷積操作而言,其主要作用在

height, width

上。

strides

參數确定了滑動視窗在各個次元上移動的步數。一種常用的經典設定就是要求,

strides[0]=strides[3]=1

  • strides[0] = 1,也即在 batch 次元上的移動為 1,也就是不跳過任何一個樣本,否則當初也不該把它們作為輸入(input)
  • strides[3] = 1,也即在 channels 次元上的移動為 1,也就是不跳過任何一個顔色通道;