Recurrent Switching Linear Dynamical Systemsのstanでの再現
Stan Advent Calendar12/11の記事です。時系列の統計モデルに関する論文「Recurrent Switching Linear Dynamical Systems」の紹介とそのstanでの再現です、が現状うまくいっていません。。
モデル
この論文で提唱されているモデルRecurrent Switching State Space Model(rSSS)、その一種であるRecurrent Switching linear dynamical systems(rSLDS)は離散的な状態を再帰的に推定することが特徴とされています。
元となっているSwitching linear dynamical systems(SLDS)は線形の状態空間モデルを離散的な状態の数だけとってその間の遷移を許容するようなモデルです。
観測値をy_t,内部状態をx_tとする線形の状態空間モデルは
(MNIWはMatrix normal invert Wishart分布)
と書かれ、これがK個の状態z[k]_tごとに存在すると
と書かれます。rSLDSはこれに離散的な状態z_tの時間発展のモデル
を組み合わせることで構成されます。ここでπ_{SB}はsigmoid関数を用いてk成分目が
と書かれる関数でstick-breaking processと呼ばれています。条件分岐のようになっていて状態数が多い場合先のものから選ばれていくような形になっています。
論文では同じ著者らによる補助変数(augmentate variable)とPolya-Gamma分布
による補助変数のサンプリングを用いた手法(
DEPENDENT MULTINOMIAL MODELS MADE EASY: STICK BREAKING WITH THE POLYA-GAMMA AUGMENTATION )
(wが補助変数)で計算しており、大規模なデータに対しても適用可能ともしていますが、ここでは直接stanを用いて再現することを試みました。
コード
stanでは
functions { real sigmoid(real a){ return inv_logit(-a); } } data{ int T; int N; int M; int K; vector [T] y[N]; } parameters{ matrix[M,M] A[K]; vector [K] C; vector [M] R[K]; vector [M] b[K]; vector [K] r; vector [M] x[N]; real d[K]; real <lower=0> s; cholesky_factor_corr[M] corr_ch; vector<lower=0> [M] sv; } transformed parameters{ cholesky_factor_cov[M] cov_ch; cov_ch<-diag_pre_multiply(sv,corr_ch); } model{ vector [K] z; for(j in 1:K){ sv[j]~student_t(4,0,200); } corr_ch~lkj_corr_cholesky(2); for(t in 1:T){ for(i in 2:N){ for(k in 1:K){ target+=log(z[k])+multi_normal_cholesky_lpdf(x[i]|A[k]*x[i-1]+b[k],cov_ch) ; target+=log(z[k])+normal_lpdf(y[i,t]|C[k]*x[i]+d[k],s) ; } z[1]<-sigmoid(dot_product(R[1],x[i])+r[1]); for(k in 2:K){ z[k]<-sigmoid(dot_product(R[k],x[i])+r[k]); for(kp in k:K){ z[k]<-z[k]+sigmoid(-(dot_product(R[kp],x[i])+r[kp])); } } target+=log_sum_exp(log(z)); } } }
と書かれます。離散的な状態に値をassignする為modelブロックの中で定義しています。これを出力するにはgenerated quantitiesブロックで何か変換する必要がありそうです。
zをsimplexまたはupper,lowerの拘束があるvectorにしようとすると require unconstrained variable declaration とコンパイルエラーになってしまいます(追記)。
数値実験
論文では車の走行のデータベースNASCARで見られる軌跡、バスケットボールの試合での選手の動きの状態推定*1とLorenz方程式に従う時系列の状態推定を行っています。
ここではLorenz方程式
の出力結果の再現を試みました。
推定ではLorenz方程式そのものをモデル化することなく、蝶の羽のような形をしたアトラクタの左右どちらにいるかを2値の状態の観測結果としてそこから内部状態のダイナミクスを推測しています。すなわち観測モデルがベルヌーイ分布Bernを用いて
(tは時間に関する、nはサンプルに関する添字)と書けるとしています。各状態に対応するモデルは線形なのですが、左右のどちらかにいる場合は線形の微分方程式で近似できるような動きをしており、その間を遷移する場合に予測しがたい動きが生じるようになっていて、今回のモデルの使用に適しているようです*2。
ただし論文では1本ではなく同じ時系列に対する多数(100個の)観測結果を取得しています*3。また時系列の中の一部を隠した場合にも推測が可能であることを示しています(Fig 4 ( c)(d)(e)(f)の灰色部分)。内部状態をprobablistic PCAで初期化しているそうです*4。
計算結果
まず観測数T=20,時系列の長さN=500の場合をpythonで合成した時系列データから状態を推定しようとしました。
するとコンパイルが成功したのち
Initialization between (-2, 2) failed after 100 attempts.
Try specifying initial values, reducing ranges of constrained values, or reparameterizing the model.
となり、サンプリングが行われません。そこでまずstan マニュアルや「StanとRでベイズ統計モデリング」p. 192で言われているようにinvert wishert分布の代わりに
LKJ相関分布というものを使ってみました。また自動変分ベイズを使った場合も試してみました。しかしそれでも上記のメッセージの表示は変わらずまた
Rejecting initial value:
Log probability evaluates to log(0), i.e. negative infinity.
Stan can't start sampling from this initial value.
とあるので論文同様初期化を行うべきなのだと考えられます。そこで反則なのですがまず初期値として内部状態の値を使ってみました。
しかしながら
RuntimeError: Rejecting initial value:
Log probability evaluates to log(0), i.e. negative infinity.
Stan can't start sampling from this initial value.
が発生してしまいはまっている感じです。
以下が使用したコードです。
import pystan import numpy as np isvb=True isinit=True print(pystan.__version__) smodel=pystan.StanModel(file="RSLDS.stan") oN=500 oT=10 y=np.loadtxt("data.csv",delimiter=",") N=min(y.shape[0],oN) T=min(y.shape[1],oT) y=y[:N,:T] M=3 K=2 data={"T":T,"N":N,"M":M,"K":K,"y":y} if(isinit): xs=np.loadtxt("data_source.csv",delimiter=",") xs=xs[:N,:M] def initf(): return dict(x=xs, A=np.random.normal(size=(K,M,M)), C=np.random.normal(size=K), R=np.random.normal(size=(K,M)), b=np.random.normal(size=(K,M)), r=np.random.normal(size=K), d=np.random.normal(size=K), s=1, corr_ch=np.random.uniform(-1,1,size=(M,M)), sv=np.repeat(1.,M) ) else: initf="random" if(isvb): fit= smodel.vb(data=data,init=initf,seed=71) else: fit= smodel.sampling(data,init=initf,iter=31000, warmup=1000, thin=10, chains=3, seed=71) res=fit.extract() print res import draws draws.drawline()
参考
Polya-Gamma分布による離散状態の推定