参数
-
fn:
fn
是函数描述了在一步scan.fn
中所有的操作,这个个函数必须构造出描述一步迭代的输出的变量。同样还需要看成是 theano 的输入变量,表示输入序列的所有分片和过去的输出值,以及所有赋给 scan 的non_sequences
的这些其他参数。而 scan 依照如下顺序传递给fn
这些变量: - all time slices of the first sequence
- all time slices of the second sequence
- ...
- all time slices of the last sequence
- all time slices of the last sequence
- all past slices of the first output
- all past slices of the second otuput
- ...
- all past slices of the last output
- all other arguments (the list given as non_sequences to scan)
序列的顺序和在列表 sequences
一致。输出的顺序和 output_info
的顺序相同。对任意序列或者输出,时间分片的顺序和他们给定作为 taps 的序列相同。例如,如果代码写成下面这样:
scan(fn,
sequences = [ dict(input= Sequence1, taps = [-3,2,-1]) , Sequence2 , dict(input = Sequence3, taps = 3) ] ,
outputs_info = [ dict(initial = Output1, taps = [-3,-5]) , dict(initial = Output2, taps = None) , Output3 ] ,
non_sequences = [ Argument1, Argument2])
fn
期待的输入顺序是:
- Sequence1[t-3]
- Sequence1[t+2]
- Sequence1[t-1]
- Sequence2[t]
- Sequence3[t+3]
- Output1[t-3]
- Output1[t-5]
- Output3[t-1]
- Argument1
- Argument2
non_sequences
同样包含共享变量,只是 scan
可以将这些变量忽略。为了代码的清晰,我们推荐将这些变量传递给 scan
。在某种程度上,scan
可以确定其他即使没有传递给 scan
(但被 fn
使用的)non_sequences
(not shared) 变量。例如:
import theano.tensor as TT
W = TT.matrix()
W_2 = W**2
def f(x):
return TT.dot(x,W_2)
函数需要返回两个东西。一个是按照 outputs_info
顺序排列的输出列表,不同的是对每个输出初始状态只有一个一个输出变量对应(即使没有使用 tap 值)。第二个 fn
应当返回一个更新字典(告诉程序如何对共享变量进行每步的更新)。字典也可以以 tuple 的列表给出。对于这两个列表的顺序倒没有限制,fn
可以返回 (outputs_list, update_dictionary)
或者 (update_dictionary, outputs_list)
,或者就其中之一(另一个为空)。将 scan
当成是一个 while 循环,我们需要给 fn
增加一个退出循环的条件——将条件配置在一个 until
类中欧诺个。这些条件必须被返回为第三个元素,如下:
...
return [y1_t, y2_t], {x:x+1}, theano.scan_module.until(x<50)
注意,步数(最大迭代步骤数)仍然需要指定,即使一个有了一个条件
-
sequences: 序列是 Theano 变量或者字典的列表,告诉程序
scan
必须迭代的序列。如果序列以字典的形式给出,那么可选信息集合可给这个序列。字典应该包含如下的关键信息: -
input
(强制的)表示序列的 Theano 变量 -
taps
fn
需要的时间片。通常以整数列表的方式提供,其中k
的值表示一个迭代步骤t
scan 会 传递给fn
序列片t+k
。默认值为[0]
任何在 sequence
列表的 Theano 变量都会自动封装成一个字典,其 taps
被设置为 [0]
-
output_info
outputs_info
是 Theano 变量或者字典的列表,给出了递归计算的输出初始时的状态。当这个初始状态给定为字典时,说明了对应这些初始状态的输出的可选信息。字典应当包含下面的元素: -
initial
表示一个给定输出的初始状态的 Theano 变量。如果输出不是递归计算的(如 map)或者不需要初始状态,那么这里可以跳过。由fn
前面时间步的输出,初始状态应该拥有和输出的同样形状,并且不能够包含输出的数据类型的转换。如果使用多时间 tap,初始状态应当由额外的维度来覆盖所有可能的 tap。例如,如果我们使用-5, -2, -1
作为过去的 tap,在第 0 步,fn
会需要output[-5], output[-2] output[-1]
。这将由初始状态给出,这里的形状就是(5,)+ output.shape
。如果这个包含初始状态的变量称为init_y
那么init_y[0]
对应于output[-5]
。init_y[1]
对应于output[-4]
。init_y[2]
对应于output[-3]
。init_y[3]
对应于output[-2]
。init_y[4]
对应于output[-1]
。这个顺序可能看起来奇怪,不过这来自给定点的数组划分,也有相应的道理。假设我们有一个数组x
,选择k
为时间步0
。那么初始的状态就是x[:k]
,而输出就是x[k:]
。看看这个划分,在x[:k]
中的元素顺序和init_y
中完全一致。 -
taps
输出的时间 tap 将会被传递给fn
。他们是以负整数的列表给出,其中k
表示在迭代步t
scan 会将切片t+k
传递给fn
。
scan
会按照下面的规则进行: - 如果输出现在封装在一个字典中,
scan
将会按照你仅仅在输出的最后一步使用他这个前提封装它(即让你的 tap 值设置为[-1]
) - 如果你在一个字典中封装一个输出,并且你不提供任何的 tap 但是提供了一个初始状态,那么会假设你仅仅使用 tap 值为 -1.
- 如果你将输出封装进一个字典中,不过你没有提供任何的初始状态,那么会假设你不回使用任何形式的 tap
- 如果你提供
None
而非一变量或者一个空字典,那么scan
假设你将不会对这个输出使用任何 tap(就像在 map 中那样)
如果 outputs_info
是一个空列表或者 None
,scan
假设了没有 tap 用在任何输出上。如果信息仅仅针对输出的子集给出,那么会抛出一个异常(因为并没有给出 scan 如何映射信息给 fn
的输出的默认行为)
-
non_sequences
non_sequences
是在每一步被传递给fn
的参数的列表。我们可以可选择地将fn
中使用的变量用此列表剔除,尽管为了代码清晰不建议这么做。 -
n_steps
n_steps
是以 int 或者 Theano scalar 给出的迭代步数。如果任何输入序列没有足够的元素,scan 会给出一个错误。如果值为 0 输出将只有 0 行。如果值为负值,scan
会往回运行。如果go_backwards
flag 已经设置了,而且n_steps
是负值,scan
将会向前运行。如果n_steps
没有给出,scan
将在给定输入序列时就会搞清楚应当运行的步数。 -
truncate_gradient
truncate_gradient
是用在 truncated BPTT 上的步数。如果你通过 scan op 来计算梯度,他们会使用 BPTT 来计算。通过给定不同于 -1 的值,你将确定使用 truncated BPTT 而非经典的 BPTT -
go_backwards
go_backwards
是表示scan
是否往回走的标志。如果你将每个句子看做按照时间标记,让这个标志设置为 True 会让scan
按照时间往回扫描。 -
name 在对
scan
进行性能分析时,给每个scan
的实例进行命名是很重要的。性能分析器将产生整体的代码分析,甚至每个scan
实例步骤的分析。实例的name
则出现在这些分析中,提供了具有区分度的信息。 -
mode 推荐将此设置为 None,特别是对
scan
进行分析的时候(否则结果会不准确)。如果你倾向 scan 的某一步计算使用某种特殊的方式计算,可以使用 mode 来改变计算行为(参考theano.function
来看看可能的使用方式) -
profile Flag 或者 string。如果为 True,或者不同于一个空串,那么就会创建一个分析器对象,并绑定在 scan 的 inner 计算图上。如果
profile
设置为 True,该对象会有一个 scan 实例的名字,否则就使用传递的 string。分析器对象仅仅会在使用新 cvm链接器运行 inner 计算图的时候收集(并打印)信息(按照默认模式,对其他链接器,这个参数就是无用的) -
allow_gc 设置此项可以允许 scan 的内部计算图进行 gc。如果为 None,就会使用
config.scan.allow_gc
的值。 -
strict 如果设置为 True,
fn
中所有共享变量都必须作为non_sequences
或者sequences
的一部分提供。
返回值
形为 (outputs, updates) 的元组,outputs 是 Theano 的变量或者 Theano 变量的列表,表示 scan 的输出(按照 outputs_info
的顺序)。updates
是一个字典的子类指定了所有共享变量的更新规则。这个字典应该被传递给 theano.function
。不同于正常的字典的是我们验证这些 key 为 SharedVariable 并且确保这些字典的求和是一致的。
返回类型
元组(tuple)