Theano的scan解析
前面讲过function函数,可以利用它实现梯度更新,重点就是里面的updates参数,并且还可以根据模型的输入inputs和givens参数以及updates的shared_variable也就是模型参数得到模型输出也就是cost。而theano另外一个很重要的函数就是scan,它具有loop的效果。从函数名称可以看出来这个函数的作用就是扫描,是对提供的参数进行某种扫描操作,也就是不断的更新。
1. 这是一段截取自RBM示例程序的有关scan的代码
(
[
pre_sigmoid_nvs,
nv_means,
nv_samples,
pre_sigmoid_nhs,
nh_means,
nh_samples
],
updates # is a subclass of dictionary specifying the update rules
# for all shared variables used in scan
) = theano.scan(
self.gibbs_hvh, # fn
# the None are place holders, saying that
# chain_start is the initial state corresponding to the
# 6th output
outputs_info=[None, None, None, None, None, chain_start],
#chain_start's tap value is [-1]
n_steps=k,
name="gibbs_hvh"
)
scan函数中最主要的参数有fn,sequences,outputs_info,non_sequences和n_steps,这段示例代码中没有sequences和non_sequences这两个参数,下面仔细讨论一下这几个参数的作用。
- fn: 这是scan函数最重要的一个参数,即扫描函数,它描述了一步(one_step)scan所做的操作,它会根据输入构造一组输出,而其余的参数sequences,outputs_info,non_sequences都可以作为fn的输入,一次运行也就是fn对这些参数做一次扫描。而则和个scan函数之所以厉害,是因为它不仅仅可以保存一次scan的结果,前几次scan的结果它也可以保存,所以所有这些参数输入fn的顺序为:
- 第一个sequence的所有时间片
- 第二个sequence的所有时间片
- ……
- 最后一个sequence的所有时间片
- 第一个outputs_info的所有过去的时间片
- 第二个outputs_info的所有过去的时间片
- ……
- 最后一个outputs_info的所有过去的时间片
-
所有的non_sequences
sequences中所有参数的顺序是和输入给fn的顺序一样的,scan的输出变量的顺序是和outputs_info中参数的顺序一样的,看下面一个示例程序:
scan(fn, sequences = [ dict(input= Sequence1, taps = [-3,2,-1])
, Sequence2
, dict(input = Sequence3, taps = 3) ]
, outputs_info = [ dict(initial = Output1, taps = [-3,-5])
, dict(initial = Output2, taps = None)
, Output3 ]
, non_sequences = [ Argument1, Argument2])
fn会有得到一下顺序的输入:
1.
Sequence1[t-3]
1.
Sequence1[t+2]
1.
Sequence2[t]
1.
Sequence3[t+3]
1.
Output1[t-3]
1.
Output1[t-5]
1.
Output3[t-1]
1.
Argument1
1.
Argument2
- sequences:sequences是一个列表或者一系列字典,scan就会在这个sequences上面迭代扫描,如果sequences是一系列字典的话,这些字典应该有两个关键字一个是input,另一个是taps,可以参考上面的代码段,如果不设置tap的话,默认是0,这个参数对于循环神经网络至关重要。
- outputs_info: 这是一个列表或者字典,它主要用来在每次step中反复初始化scan的输出值,当然这个值也一直在变化,可以参考第一个代码段中,outputs_info是一个六个元素的列表,前5个主要用来占位,第六个用来初始化输出值并且作为fn的输入,当下一次scan时,这个list和输出scan的输出列表是对应的,每进行一步scan操作,outputs_info中的数值会被上一次迭代的输出值更新掉
- non_sequences: 这个参数也是一个存放输入到fn数据的列表,每一步scan这些数据都会输入到fn中。
- n_steps:这个参数表示scan执行的次数。
- 输出参数:scan函数主要有两个输出一个就是fn函数的输出,是一个列表的形式,如代码段1中的gibbs_hvh函数输出了6个变量,还有一个就是updates,这是一个updates的dictionary(that tells how to update any shared variable after each iteration step)指明了scan中用到的所有shared variables的更新规则。