読者です 読者をやめる 読者になる 読者になる

xiangze's sparse blog

機械学習、ベイズ統計、コンピュータビジョンと関連する数学について

Torch7のCNNのFPGA実装は可能か(絵に描いた餅編)

waifu2xの登場で注目されるTorchですが、様々なアーキテクチャでの実装を標榜しているようです。
http://torch.ch/
ではFPGA backendsと書かれていますが、誰かが実装したという話は聞いたことがありません。ので検討してみました。


まずwaifu2xのアルゴリズムとパラメータの傾向を把握します。
アルゴリズムpythonで書かれたものがわかりやすかったです。
https://marcan.st/transf/waifu2x.py
Y(輝度)成分だけに超解像処理を適用していて、入力は元画像のpixelを単純に繰り返したもののようです。
またjson化された係数の分布を見ると0になっている部分はほとんどないようです。
http://nbviewer.ipython.org/gist/xiangze/0a303baca4c99d463748

sparsenessが過学習の抑制や多層でのバックプロパゲーションが有効である理由はConvolutionが局所的であることで実現できていると理解できそうです。

CNNの構成は

  • l=layer(層)の数
  • n1=入力channel数
  • n2=出力channel数
  • f1 Convolutionの縦サイズ
  • f2 Convolutionの横サイズ
  • F=w*h=縦w,横hの画像サイズ

というパラメータで表されます。これをFPGAで並列実装することでリアルタイム超解像処理ができないかと考えました。

verilogで3x3の畳み込み(+ReLU)を書いたものが以下です。
https://github.com/xiangze/CNN_FPGA

module prodsum3(
            input           clk,
            input           resetn,
            input signed [9:0] w0,
            input signed [9:0] w1,
            input signed [9:0] w2,
            input signed [9:0] x0,
            input signed [9:0] x1,
            input signed [9:0] x2,
            input           clip,
            output signed      reg [9:0] out
            );

   wire [20:0]                 no;
   wire [9:0]                 no_clip;

   assign no=x0*w0+x1*w1+x2+w2;
   assign no_clip=$signed( (no[20]&(!&no[19:18]))?-512:(!no[20]&|no[19:18])?512:{no[20],no[17:9]});
  
   always@(posedge clk or negedge resetn)
     if(!resetn)
       out<=0;
     else if(clip)  
       out<=no_clip;
     else
       out<={no[20:11]};      

endmodule

module filter3x3(
          input               clk,
          input               resetn,
          input               relu,
          input [7:0]          relu_c,
         
          {% for i in fn %} {% for j in fn %}
          input signed [9:0] w{{j}}_{{i}},
          {% endfor %} {% endfor %}
          input signed [9:0] b,
          {% for i in fn %} {% for j in fn %}         
          input signed [9:0] x{{j}}_{{i}},
          {% endfor %} {% endfor %}
          output signed          reg [9:0] out
          );

   assign               clip=1;
  
   {% for i in fn %}
   wire [9:0]                 out{{i}};
   prodsum3 p0(.clk(clk), .resetn(resetn),.clip(clip),
            {% for j in fn %}
            .w{{j}}(w{{j}}_{{i}}[9:0]),
            .x{{j}}(x{{j}}_{{i}}[9:0]),
            {% endfor %}
            .out(out{{i}}[9:0]));
   {% endfor %}

   wire [20:0]                 no;
   wire [9:0]                 no_clip;
   wire [17:0]                 no_nrelu;
    
   assign no=out0+out1+out2+b;
   assign no_clip=$signed((no[20]&(!&no[19:18]))?-512:(!no[20]&|no[19:18])?512:{no[20],no[17:9]});
   assign no_nrelu=no_clip*relu_c;
  
   always@(posedge clk or negedge resetn)
     if(!resetn)
       out<=0;
     else if(relu)  
       out<=(no_clip[20]?0:no_clip)+(~no_clip[20]?0:no_nrelu[17:8]);
     else
       out<=no_clip;

endmodule

これを1演算単位とします。ReLuは計算が簡単なこともあり固定小数点で実装しました。
verilogのgenerate文を使って繰り返しを記述する事もできますが、機能に制限があるようなのでjinja2を使って書いています。
これを入力channel数の数だけ並列化したものが以下のようなコードになります。

module filter_n1(
            input                  clk,
            input                  resetn,
            input                  clip,
            input                  relu,
            input [7:0]                  relu_c,
                                  {% for k in n1 %}
                                  {% for i in fn %} {% for j in fn %}
            input signed [{{width-1}}:0] w{{k}}_{{j}}_{{i}},
                                {% endfor %} {% endfor %} {% endfor %}
                                {% for k in n1 %}
                                {% for i in fn %} {% for j in fn %}
            input signed [{{width-1}}:0] x{{k}}_{{j}}_{{i}},
                                {% endfor %} {% endfor %}
                                {% endfor %}
         
                                  {% for k in n1 %}
            input signed [{{width-1}}:0] b{{k}};
            {% endfor %}
            output signed reg [9:0] out
          );

   {% for k in n1 %}  
     wire [{{width-1}}:0] out{{k}};
     filter3x3 filter_{{k}}(
                   .clk(clk), .resetn(resetn),
                   .relu(0),.relu_c(0),
                     {% for i in fn %} {% for j in fn %}
                   .w_{{j}}_{{i}}(w{{k}}_{{j}}_{{i}}[{{width-1}}:0]),
                   .x_{{j}}_{{i}}(x{{k}}_{{j}}_{{i}}[{{width-1}}:0]),
                   {% endfor %}   {% endfor %}
                   .b(b{{k}}[{{width-1}}:0])
                   .out(out{{k}}[{{width-1}}:0]));
   {% endfor %}
    
   {% for l in ln1 %}  
     {% for k in ii[l] %}  
       reg [{{width-1}}:0]  out_{{l}}_{{k}};
       wire [{{width*2}}:0] nout_{{l}}_{{k}};
       wire [{{width-1}}:0] nout_clip_{{l}}_{{k}};
  
       assign nout_{{l}}_{{k}}
     =out_{{l-1}}_{{k*4}}+out_{{l-1}}_{{k*4+1}}+out_{{l-1}}_{{k*4+2}}+out_{{l-1}}_{{k*4+3}};
  
       assign no_clip_{{l}}_{{k}}
     =$signed((no[20]&(!&no[19:18]))?-512:(!no[20]&|no[19:18])?512:{no[20],no[17:9]});
  
       always@(posedge clk or negedge resetn)
     if(!resetn)
        out_{{l}}_{{k}}<=0;
     else if(clip)  
        out<=no_clip_{{l}}_{{k}}
     else
        out<={no_{{l}}_{{k}}[20:11]};      
   {% endfor %}
   {% endfor %}
         
   wire [{{width*2}}:0] no;
   wire [{{width-1}}:0] no_clip;
   wire [{{width-1+8}}:0] no_nrelu;
  
   assign no=out_0+out1+out2+out3;
   assign no_clip=$signed((no[20]&(!&no[19:18]))?-512:(!no[20]&|no[19:18])?512:{no[20],no[17:9]});
   assign no_nrelu=no_clip*relu_c;
  
   always@(posedge clk or negedge resetn)
     if(!resetn)
       out<=0;
     else if(relu)  
       out<=(no_clip[{{width-1}}]?0:no_clip)+(~no_clip[{{width-1}}]?0:no_nrelu[{{width-1+8}}:8]);
     else if(clip)  
       out<=no_clip;
     else
       out<={no[20:11]};      

endmodule

同様に出力channel数n2だけ並列化させたもの、layer数lだけパイプライン化した構成が考えられます。
https://github.com/xiangze/CNN_FPGA/blob/master/template/filter_n2.v

https://github.com/xiangze/CNN_FPGA/blob/master/template/filter_l.v
画像信号が1クロックに1ピクセル分だけ入ってくるとすると3x3フィルターの上下左右のピクセルの値を保持する必要があり、レジスタ、ラインバッファを使った記述になります。
https://github.com/xiangze/CNN_FPGA/blob/master/template/filter_n2_line.v

リアルタイム超解像処理には全層、全チャネルの並列化が必要です。すべての演算単位を同時に回路化しようとすると回路規模が膨大になってしまい、既存のFPGAには入りきらなくなってしまいます*1。より現実的な構成として以下のようなものが考えられます。

構成のバリエーション案

  • l*n1*n2構成

最初に考えた上記の構成です。画像サイズ*channel数の中間データを保存する必要のない構成ですが、演算単位がl*n1*n2だけ必要になってしまい非現実的です。畳み込みパラメータはハードコーディングすることが出来そうです。
縦横3x3のフィルタを適用するためにラインバッファ、レジスタが必要です。

  • n1*n2構成

filter_n2をl回使いまわす構成です。層ごとの演算結果をメモリに保存し、次の層の計算で読み出します。
1/lだけスループットが低下してしまいます。畳み込みパラメータは別途メモリまたはハードコーディングで用意し、計算している層に応じて切り替えて供給することになります。

  • n2構成

filter_n1をl*n2回使いまわす構成です。スループットは1/(l*n2)になります。
同じ入力で異なる出力を得られるもっとも並列化の恩恵が大きい構成化と思われます。

  • 1構成

channel並列化を行わない場合です。l*n1*n2回繰り返しを行わないといけません。普通のCPUでやったほうがよい構成です。

演算回路の並列化よりもメモリの帯域のほうが重要かもしれないのですがどうでしょう。

その他

メモリへの読み書きと状態を制御する回路を設計するのは面倒かもしれません。AlteraだとをUniPHYを使うのが良いかもしれません。
XilinxのVertexのほうがDSPが多く、より大きなネットワークが実現できるかもしれません。
ReLuは式が簡単なのでバックプロパゲーションFPGAで実施できるかもしれないと思いました。

Reference

waifu2x
http://ultraist.hatenablog.com/entry/2015/05/17/183436
http://waifu2x.udp.jp/
https://github.com/nagadomi/waifu2x

Chao Dong, Chen Change Loy, Kaiming He, Xiaoou Tang, "Image Super-Resolution Using Deep Convolutional Networks"
http://arxiv.org/pdf/1501.00092v2.pdf

waifu2xのpython実装
https://marcan.st/transf/waifu2x.py
パラメータ
https://marcan.st/transf/scale2.0x_model.json
使用例
http://kiito.hatenablog.com/entry/2015/05/30/181700

サルでも分かる waifu2x のアルゴリズム
https://drive.google.com/file/d/0B22mWPiNr-6-RVVpaGhZa1hJTnM/view

FPGA
CiNii 論文 -  同期シフトデータ転送に基づく Deep Convolutional Neural Network のFPGA実装

Hardware Accelerated Convolutional Neural Networks for Synthetic Vision Systems
http://yann.lecun.com/exdb/publis/pdf/farabet-iscas-10.pdf
Large-Scale FPGA-based Convolutional Networks
http://yann.lecun.com/exdb/publis/pdf/farabet-suml-11.pdf

MicrosoftがFPGAでDeepLearningしてた - SANMAN

Xilinx,Spartanで歩行者検出のデモを行っているものだそうです。ReLUを使用し、学習は別のデバイスで行ったそうです。

RNNでもReLuが使えるという話です。feedbackがあるので初期値が重要だそうです。


[1504.00941] A Simple Way to Initialize Recurrent Networks of Rectified Linear Units

*1:verilog codeとしても100万行近くある非現実的なものが出来上がってしまいました。