Torch7のCNNのFPGA実装は可能か(絵に描いた餅編)
waifu2xの登場で注目されるTorchですが、様々なアーキテクチャでの実装を標榜しているようです。
http://torch.ch/
ではFPGA backendsと書かれていますが、誰かが実装したという話は聞いたことがありません。ので検討してみました。
まずwaifu2xのアルゴリズムとパラメータの傾向を把握します。
アルゴリズムはpythonで書かれたものがわかりやすかったです。
https://marcan.st/transf/waifu2x.py
Y(輝度)成分だけに超解像処理を適用していて、入力は元画像のpixelを単純に繰り返したもののようです。
またjson化された係数の分布を見ると0になっている部分はほとんどないようです。
http://nbviewer.ipython.org/gist/xiangze/0a303baca4c99d463748
sparsenessが過学習の抑制や多層でのバックプロパゲーションが有効である理由はConvolutionが局所的であることで実現できていると理解できそうです。
CNNの構成は
- l=layer(層)の数
- n1=入力channel数
- n2=出力channel数
- f1 Convolutionの縦サイズ
- f2 Convolutionの横サイズ
- F=w*h=縦w,横hの画像サイズ
というパラメータで表されます。これをFPGAで並列実装することでリアルタイム超解像処理ができないかと考えました。
verilogで3x3の畳み込み(+ReLU)を書いたものが以下です。
https://github.com/xiangze/CNN_FPGA
module prodsum3( input clk, input resetn, input signed [9:0] w0, input signed [9:0] w1, input signed [9:0] w2, input signed [9:0] x0, input signed [9:0] x1, input signed [9:0] x2, input clip, output signed reg [9:0] out ); wire [20:0] no; wire [9:0] no_clip; assign no=x0*w0+x1*w1+x2+w2; assign no_clip=$signed( (no[20]&(!&no[19:18]))?-512:(!no[20]&|no[19:18])?512:{no[20],no[17:9]}); always@(posedge clk or negedge resetn) if(!resetn) out<=0; else if(clip) out<=no_clip; else out<={no[20:11]}; endmodule module filter3x3( input clk, input resetn, input relu, input [7:0] relu_c, {% for i in fn %} {% for j in fn %} input signed [9:0] w{{j}}_{{i}}, {% endfor %} {% endfor %} input signed [9:0] b, {% for i in fn %} {% for j in fn %} input signed [9:0] x{{j}}_{{i}}, {% endfor %} {% endfor %} output signed reg [9:0] out ); assign clip=1; {% for i in fn %} wire [9:0] out{{i}}; prodsum3 p0(.clk(clk), .resetn(resetn),.clip(clip), {% for j in fn %} .w{{j}}(w{{j}}_{{i}}[9:0]), .x{{j}}(x{{j}}_{{i}}[9:0]), {% endfor %} .out(out{{i}}[9:0])); {% endfor %} wire [20:0] no; wire [9:0] no_clip; wire [17:0] no_nrelu; assign no=out0+out1+out2+b; assign no_clip=$signed((no[20]&(!&no[19:18]))?-512:(!no[20]&|no[19:18])?512:{no[20],no[17:9]}); assign no_nrelu=no_clip*relu_c; always@(posedge clk or negedge resetn) if(!resetn) out<=0; else if(relu) out<=(no_clip[20]?0:no_clip)+(~no_clip[20]?0:no_nrelu[17:8]); else out<=no_clip; endmodule
これを1演算単位とします。ReLuは計算が簡単なこともあり固定小数点で実装しました。
verilogのgenerate文を使って繰り返しを記述する事もできますが、機能に制限があるようなのでjinja2を使って書いています。
これを入力channel数の数だけ並列化したものが以下のようなコードになります。
module filter_n1( input clk, input resetn, input clip, input relu, input [7:0] relu_c, {% for k in n1 %} {% for i in fn %} {% for j in fn %} input signed [{{width-1}}:0] w{{k}}_{{j}}_{{i}}, {% endfor %} {% endfor %} {% endfor %} {% for k in n1 %} {% for i in fn %} {% for j in fn %} input signed [{{width-1}}:0] x{{k}}_{{j}}_{{i}}, {% endfor %} {% endfor %} {% endfor %} {% for k in n1 %} input signed [{{width-1}}:0] b{{k}}; {% endfor %} output signed reg [9:0] out ); {% for k in n1 %} wire [{{width-1}}:0] out{{k}}; filter3x3 filter_{{k}}( .clk(clk), .resetn(resetn), .relu(0),.relu_c(0), {% for i in fn %} {% for j in fn %} .w_{{j}}_{{i}}(w{{k}}_{{j}}_{{i}}[{{width-1}}:0]), .x_{{j}}_{{i}}(x{{k}}_{{j}}_{{i}}[{{width-1}}:0]), {% endfor %} {% endfor %} .b(b{{k}}[{{width-1}}:0]) .out(out{{k}}[{{width-1}}:0])); {% endfor %} {% for l in ln1 %} {% for k in ii[l] %} reg [{{width-1}}:0] out_{{l}}_{{k}}; wire [{{width*2}}:0] nout_{{l}}_{{k}}; wire [{{width-1}}:0] nout_clip_{{l}}_{{k}}; assign nout_{{l}}_{{k}} =out_{{l-1}}_{{k*4}}+out_{{l-1}}_{{k*4+1}}+out_{{l-1}}_{{k*4+2}}+out_{{l-1}}_{{k*4+3}}; assign no_clip_{{l}}_{{k}} =$signed((no[20]&(!&no[19:18]))?-512:(!no[20]&|no[19:18])?512:{no[20],no[17:9]}); always@(posedge clk or negedge resetn) if(!resetn) out_{{l}}_{{k}}<=0; else if(clip) out<=no_clip_{{l}}_{{k}} else out<={no_{{l}}_{{k}}[20:11]}; {% endfor %} {% endfor %} wire [{{width*2}}:0] no; wire [{{width-1}}:0] no_clip; wire [{{width-1+8}}:0] no_nrelu; assign no=out_0+out1+out2+out3; assign no_clip=$signed((no[20]&(!&no[19:18]))?-512:(!no[20]&|no[19:18])?512:{no[20],no[17:9]}); assign no_nrelu=no_clip*relu_c; always@(posedge clk or negedge resetn) if(!resetn) out<=0; else if(relu) out<=(no_clip[{{width-1}}]?0:no_clip)+(~no_clip[{{width-1}}]?0:no_nrelu[{{width-1+8}}:8]); else if(clip) out<=no_clip; else out<={no[20:11]}; endmodule
同様に出力channel数n2だけ並列化させたもの、layer数lだけパイプライン化した構成が考えられます。
https://github.com/xiangze/CNN_FPGA/blob/master/template/filter_n2.v
https://github.com/xiangze/CNN_FPGA/blob/master/template/filter_l.v
画像信号が1クロックに1ピクセル分だけ入ってくるとすると3x3フィルターの上下左右のピクセルの値を保持する必要があり、レジスタ、ラインバッファを使った記述になります。
https://github.com/xiangze/CNN_FPGA/blob/master/template/filter_n2_line.v
リアルタイム超解像処理には全層、全チャネルの並列化が必要です。すべての演算単位を同時に回路化しようとすると回路規模が膨大になってしまい、既存のFPGAには入りきらなくなってしまいます*1。より現実的な構成として以下のようなものが考えられます。
構成のバリエーション案
- l*n1*n2構成
最初に考えた上記の構成です。画像サイズ*channel数の中間データを保存する必要のない構成ですが、演算単位がl*n1*n2だけ必要になってしまい非現実的です。畳み込みパラメータはハードコーディングすることが出来そうです。
縦横3x3のフィルタを適用するためにラインバッファ、レジスタが必要です。
- n1*n2構成
filter_n2をl回使いまわす構成です。層ごとの演算結果をメモリに保存し、次の層の計算で読み出します。
1/lだけスループットが低下してしまいます。畳み込みパラメータは別途メモリまたはハードコーディングで用意し、計算している層に応じて切り替えて供給することになります。
- n2構成
filter_n1をl*n2回使いまわす構成です。スループットは1/(l*n2)になります。
同じ入力で異なる出力を得られるもっとも並列化の恩恵が大きい構成化と思われます。
- 1構成
channel並列化を行わない場合です。l*n1*n2回繰り返しを行わないといけません。普通のCPUでやったほうがよい構成です。
演算回路の並列化よりもメモリの帯域のほうが重要かもしれないのですがどうでしょう。
その他
メモリへの読み書きと状態を制御する回路を設計するのは面倒かもしれません。AlteraだとをUniPHYを使うのが良いかもしれません。
XilinxのVertexのほうがDSPが多く、より大きなネットワークが実現できるかもしれません。
ReLuは式が簡単なのでバックプロパゲーションもFPGAで実施できるかもしれないと思いました。
Reference
waifu2x
http://ultraist.hatenablog.com/entry/2015/05/17/183436
http://waifu2x.udp.jp/
https://github.com/nagadomi/waifu2x
Chao Dong, Chen Change Loy, Kaiming He, Xiaoou Tang, "Image Super-Resolution Using Deep Convolutional Networks"
http://arxiv.org/pdf/1501.00092v2.pdf
waifu2xのpython実装
https://marcan.st/transf/waifu2x.py
パラメータ
https://marcan.st/transf/scale2.0x_model.json
使用例
http://kiito.hatenablog.com/entry/2015/05/30/181700
サルでも分かる waifu2x のアルゴリズム
https://drive.google.com/file/d/0B22mWPiNr-6-RVVpaGhZa1hJTnM/view
FPGA
CiNii 論文 - 同期シフトデータ転送に基づく Deep Convolutional Neural Network のFPGA実装
Hardware Accelerated Convolutional Neural Networks for Synthetic Vision Systems
http://yann.lecun.com/exdb/publis/pdf/farabet-iscas-10.pdf
Large-Scale FPGA-based Convolutional Networks
http://yann.lecun.com/exdb/publis/pdf/farabet-suml-11.pdf
MicrosoftがFPGAでDeepLearningしてた - SANMAN
Xilinx,Spartanで歩行者検出のデモを行っているものだそうです。ReLUを使用し、学習は別のデバイスで行ったそうです。
デンソーによるDeep Neural NetworkのFPGA実装!#SSII2015 pic.twitter.com/Uxd9n1YaKU
— Hironobu Fujiyoshi (@hf149) June 11, 2015
RNNでもReLuが使えるという話です。feedbackがあるので初期値が重要だそうです。
RNNにReLUを使う、勾配が消失せず長期間、同じ状態を保てるように、回帰結合を単位行列で初期化するだけでLSTMと同じ精度が出せ、それより長期間の依存関係も学習できるという衝撃的な報告。言われればそうかと思うが誰もやらなかった http://t.co/o2Z9EfLRdJ
— 岡野原 大輔 (@hillbig) 2015, 4月 8
[1504.00941] A Simple Way to Initialize Recurrent Networks of Rectified Linear Units