天天看點

theano-scan

Theano的scan解析

前面講過function函數,可以利用它實作梯度更新,重點就是裡面的updates參數,并且還可以根據模型的輸入inputs和givens參數以及updates的shared_variable也就是模型參數得到模型輸出也就是cost。而theano另外一個很重要的函數就是scan,它具有loop的效果。從函數名稱可以看出來這個函數的作用就是掃描,是對提供的參數進行某種掃描操作,也就是不斷的更新。

1. 這是一段截取自RBM示例程式的有關scan的代碼

(
            [
                pre_sigmoid_nvs,
                nv_means,
                nv_samples,
                pre_sigmoid_nhs,
                nh_means,
                nh_samples
            ],
            updates  # is a subclass of dictionary specifying the update rules
                     # for all shared variables used in scan
        ) = theano.scan(
            self.gibbs_hvh, # fn
            # the None are place holders, saying that
            # chain_start is the initial state corresponding to the
            # 6th output
            outputs_info=[None, None, None, None, None, chain_start],
            #chain_start's tap value is [-1]
            n_steps=k,
            name="gibbs_hvh"
        )
           

scan函數中最主要的參數有fn,sequences,outputs_info,non_sequences和n_steps,這段示例代碼中沒有sequences和non_sequences這兩個參數,下面仔細讨論一下這幾個參數的作用。

  • fn: 這是scan函數最重要的一個參數,即掃描函數,它描述了一步(one_step)scan所做的操作,它會根據輸入構造一組輸出,而其餘的參數sequences,outputs_info,non_sequences都可以作為fn的輸入,一次運作也就是fn對這些參數做一次掃描。而則和個scan函數之是以厲害,是因為它不僅僅可以儲存一次scan的結果,前幾次scan的結果它也可以儲存,是以所有這些參數輸入fn的順序為:
    1. 第一個sequence的所有時間片
    2. 第二個sequence的所有時間片
    3. ……
    4. 最後一個sequence的所有時間片
    5. 第一個outputs_info的所有過去的時間片
    6. 第二個outputs_info的所有過去的時間片
    7. ……
    8. 最後一個outputs_info的所有過去的時間片
    9. 所有的non_sequences

      sequences中所有參數的順序是和輸入給fn的順序一樣的,scan的輸出變量的順序是和outputs_info中參數的順序一樣的,看下面一個示例程式:

scan(fn, sequences = [ dict(input= Sequence1, taps = [-3,2,-1])
                                 , Sequence2
                                 , dict(input =  Sequence3, taps = 3) ]
                   , outputs_info = [ dict(initial =  Output1, taps = [-3,-5])
                                    , dict(initial = Output2, taps = None)
                                    , Output3 ]
                   , non_sequences = [ Argument1, Argument2])
           

fn會有得到一下順序的輸入:

1.

Sequence1[t-3]

1.

Sequence1[t+2]

1.

Sequence2[t]

1.

Sequence3[t+3]

1.

Output1[t-3]

1.

Output1[t-5]

1.

Output3[t-1]

1.

Argument1

1.

Argument2

  • sequences:sequences是一個清單或者一系列字典,scan就會在這個sequences上面疊代掃描,如果sequences是一系列字典的話,這些字典應該有兩個關鍵字一個是input,另一個是taps,可以參考上面的代碼段,如果不設定tap的話,預設是0,這個參數對于循環神經網絡至關重要。
  • outputs_info: 這是一個清單或者字典,它主要用來在每次step中反複初始化scan的輸出值,當然這個值也一直在變化,可以參考第一個代碼段中,outputs_info是一個六個元素的清單,前5個主要用來占位,第六個用來初始化輸出值并且作為fn的輸入,當下一次scan時,這個list和輸出scan的輸出清單是對應的,每進行一步scan操作,outputs_info中的數值會被上一次疊代的輸出值更新掉
  • non_sequences: 這個參數也是一個存放輸入到fn資料的清單,每一步scan這些資料都會輸入到fn中。
  • n_steps:這個參數表示scan執行的次數。
  • 輸出參數:scan函數主要有兩個輸出一個就是fn函數的輸出,是一個清單的形式,如代碼段1中的gibbs_hvh函數輸出了6個變量,還有一個就是updates,這是一個updates的dictionary(that tells how to update any shared variable after each iteration step)指明了scan中用到的所有shared variables的更新規則。

繼續閱讀