天天看点

theano.scan

theano.scan(fn, 
            sequences=None, 
            outputs_info=None, 
            non_sequences=None, 
            n_steps=None, 
            truncate_gradient=-1, 
            go_backwards=False, 
            mode=None, 
            name=None, 
            profile=False, 
            allow_gc=None, 
            strict=False)
           

主要参数的含义:

fn :一步 scan 所进行的操作

sequences :输入的序列

outputs_info:前一步输出结果的初始状态

non_sequences:非序列参数,即每次迭代都要用到此值,若为矩阵,每次用到此矩阵,和序列不一样,每次只用到序列的一个值

n_steps:迭代步数

go_backwards:是否从后向前遍历
           

输出为一个元组 (outputs, updates):

outputs:从初始状态开始,每一步 fn 的输出结果

updates:一个字典,用来记录 scan 过程中用到的共享变量更新规则,构造函数的时候,如果需要更新共享变量,将这个变量当作 updates 的参数传入。
           

fn 是一个函数句柄,对于这个函数句柄,它每一步接受的参数是由 sequences, outputs_info, non_sequence 这三个参数所决定的,默认情况下,在第 k 次迭代时,如果 sequences 和 outputs_info 中给定的值不是字典(dictionary)或者一个字典列表(list of dictionaries),那么

sequences 中的序列 seq 传入 fn 的是 seq[k] 的值
outputs_info 中的序列 output 传入 fn 的是 output[k-1] 的值
           

fn 的返回值有两部分 (outputs_list, update_dictionary),第一部分将作为序列,传入 outputs 中,与 outputs_info 中的初始输入值的维度一致(如果没有给定 outputs_info ,输出值可以任意。)

第二部分则是更新规则的字典,告诉我们如何对 scan 中使用到的一些共享的变量进行更新:

return [y1_t, y2_t], {x:x+1}

这两部分可以任意,即顺序既可以是 (outputs_list, update_dictionary), 也可以是 (update_dictionary, outputs_list),theano 会根据类型自动识别。

两部分只需要有一个存在即可,另一个可以为空。

具体例子见GitHub/loop with scan