天天看点

theano-scan

Theano的scan解析

前面讲过function函数,可以利用它实现梯度更新,重点就是里面的updates参数,并且还可以根据模型的输入inputs和givens参数以及updates的shared_variable也就是模型参数得到模型输出也就是cost。而theano另外一个很重要的函数就是scan,它具有loop的效果。从函数名称可以看出来这个函数的作用就是扫描,是对提供的参数进行某种扫描操作,也就是不断的更新。

1. 这是一段截取自RBM示例程序的有关scan的代码

(
            [
                pre_sigmoid_nvs,
                nv_means,
                nv_samples,
                pre_sigmoid_nhs,
                nh_means,
                nh_samples
            ],
            updates  # is a subclass of dictionary specifying the update rules
                     # for all shared variables used in scan
        ) = theano.scan(
            self.gibbs_hvh, # fn
            # the None are place holders, saying that
            # chain_start is the initial state corresponding to the
            # 6th output
            outputs_info=[None, None, None, None, None, chain_start],
            #chain_start's tap value is [-1]
            n_steps=k,
            name="gibbs_hvh"
        )
           

scan函数中最主要的参数有fn,sequences,outputs_info,non_sequences和n_steps,这段示例代码中没有sequences和non_sequences这两个参数,下面仔细讨论一下这几个参数的作用。

  • fn: 这是scan函数最重要的一个参数,即扫描函数,它描述了一步(one_step)scan所做的操作,它会根据输入构造一组输出,而其余的参数sequences,outputs_info,non_sequences都可以作为fn的输入,一次运行也就是fn对这些参数做一次扫描。而则和个scan函数之所以厉害,是因为它不仅仅可以保存一次scan的结果,前几次scan的结果它也可以保存,所以所有这些参数输入fn的顺序为:
    1. 第一个sequence的所有时间片
    2. 第二个sequence的所有时间片
    3. ……
    4. 最后一个sequence的所有时间片
    5. 第一个outputs_info的所有过去的时间片
    6. 第二个outputs_info的所有过去的时间片
    7. ……
    8. 最后一个outputs_info的所有过去的时间片
    9. 所有的non_sequences

      sequences中所有参数的顺序是和输入给fn的顺序一样的,scan的输出变量的顺序是和outputs_info中参数的顺序一样的,看下面一个示例程序:

scan(fn, sequences = [ dict(input= Sequence1, taps = [-3,2,-1])
                                 , Sequence2
                                 , dict(input =  Sequence3, taps = 3) ]
                   , outputs_info = [ dict(initial =  Output1, taps = [-3,-5])
                                    , dict(initial = Output2, taps = None)
                                    , Output3 ]
                   , non_sequences = [ Argument1, Argument2])
           

fn会有得到一下顺序的输入:

1.

Sequence1[t-3]

1.

Sequence1[t+2]

1.

Sequence2[t]

1.

Sequence3[t+3]

1.

Output1[t-3]

1.

Output1[t-5]

1.

Output3[t-1]

1.

Argument1

1.

Argument2

  • sequences:sequences是一个列表或者一系列字典,scan就会在这个sequences上面迭代扫描,如果sequences是一系列字典的话,这些字典应该有两个关键字一个是input,另一个是taps,可以参考上面的代码段,如果不设置tap的话,默认是0,这个参数对于循环神经网络至关重要。
  • outputs_info: 这是一个列表或者字典,它主要用来在每次step中反复初始化scan的输出值,当然这个值也一直在变化,可以参考第一个代码段中,outputs_info是一个六个元素的列表,前5个主要用来占位,第六个用来初始化输出值并且作为fn的输入,当下一次scan时,这个list和输出scan的输出列表是对应的,每进行一步scan操作,outputs_info中的数值会被上一次迭代的输出值更新掉
  • non_sequences: 这个参数也是一个存放输入到fn数据的列表,每一步scan这些数据都会输入到fn中。
  • n_steps:这个参数表示scan执行的次数。
  • 输出参数:scan函数主要有两个输出一个就是fn函数的输出,是一个列表的形式,如代码段1中的gibbs_hvh函数输出了6个变量,还有一个就是updates,这是一个updates的dictionary(that tells how to update any shared variable after each iteration step)指明了scan中用到的所有shared variables的更新规则。

继续阅读