xiangze's sparse blog

機械学習、ベイズ統計、コンピュータビジョンと関連する数学について

theanoでwhileなど

sinhrksさんがtheanoでのloopについてまとめていたので
その落ち穂拾いです。sinhrks.hatenablog.com

while文

ある条件のときに繰り返しを停止する普通のプログラミング言語のwhile文のような機能は
theano.scan_module.until
を使います。theano.scan_module.untilを返す関数をscanに代入しなければいけないのでめんどくさいです。
例:スカラーを指定値になるまで倍々にしていく*1

def power_of_2(p, m):
    return p*2, theano.scan_module.until(p*2 > m)

max_value = T.scalar()
values, _ = theano.scan(power_of_2,
                        outputs_info = T.constant(1.),
                        non_sequences = max_value,
                        n_steps = 1024)

f = theano.function([max_value], values)
print f(45)

theano.scan_moduleにはこの他にも興味深い関数があるようです。
https://github.com/Theano/Theano/blob/master/theano/scan_module/scan_utils.py

theano.scanの引数

sinhrksさんの紹介にあった

  • fn 繰り返すべき関数
  • sequences 繰り返しによって更新される変数(iteration,map型の使い方で指定)
  • outputs_info 初期値
  • non_sequences 繰り返しの間変化させない入力変数(loop型の使い方で指定)
  • n_steps 繰り返す回数

の他にあまり使わないかもしれませんが、

  • truncate_gradient

RNNの更新規則であるBPTT(Back Propagation Through Time)の反復回数を指定する

  • go_backwards

sequenceを逆に辿ってscanを実行する。

  • allow_gc

ガーベージコレクションを有効にする。falseにするとGPUでの計算が速くなるらしい。
などの引数があります。

map,reduceなど

scanはやたら色々なことが出来るので逆に使いづらいかもしれません。他の言語同様にmap, reduce, foldl, foldrがあります。
例(map):

X=T.scalar('X')
def f(x,y):
     return x*y

fm,_=theano.map(f,
                sequences=np.array([3, 4, 5]), 
                non_sequences=X
                )

func = theano.function(inputs=[X], outputs=fm) 
print func(4) 

mapはnon_sequencesに定数列のsequencesを適応するのに使う(?)
定数じゃなくてもいいらしいです。
http://nbviewer.ipython.org/gist/xiangze/0f800305ee67b229577a/Theano_scan_loop_and_ite.ipynb

普通のpythonのreduceよりtheano.reduceのほうが速いという話です。
https://groups.google.com/forum/#!topic/theano-users/SCKFg-UltQo

一応ネストしたscanとか出来るらしいです。

X = T.matrix("X")

res0, updates = theano.scan(lambda x_i: (x_i ** 2).sum(), sequences=[X])
res,up2       = theano.scan(lambda x_i: (x_i ** 2).sum(), sequences=[res0])
resT=res0.sum()
norm = theano.function(inputs=[X], outputs=[res0,resT])
x = np.matrix([[1,2,3],[4,5,6],[1,2,3]], dtype=theano.config.floatX)
print norm(x)