読者です 読者をやめる 読者になる 読者になる

MXNetでmulti-input/multi-output

皆さんMXNet使っていますか? 年度初に著名データサイエンティストの記事が相次いで盛り上がった感がありましたが、もうChainerなりTensorFlowなりに移ってしまったのでしょうか…

MXNetはDeep Learningフレームワークの比較でドキュメントが弱いことをよく指摘されてるので、1ユーザとして草の根でお役立ち情報を発信していきたいです。

やりたいこと

形式の異なる複数のデータを入力として、複数の値を出力するモデルを学習したい。*1

  • Keras: Functional APIなるものを使って実現できるそうです。

Functional APIのガイド - Keras Documentation

  • Chainer: サポートされている模様。手続き的なフレームワークだと関数に通すだけなので難しいことは少なそう。

Google グループ

  • TensorFlow: サポートされていない模様。MXNetと同様に改造すればいけそう。

How to handle multi tensors input? · Issue #9 · tensorflow/serving · GitHub

参考にしたもの

MXNetで困ったら、まずはexampleを調べることをお勧めします。 英文を含めてもブログ記事の情報量は少なく、Issueには答えが書いていないものも多いです。

私もIssueを読み込みましたが、答えはexampleにありました。

mxnet/example/multi-task at master · dmlc/mxnet · GitHub

サンプルコード

github.com

  • 3ブロック目: 解像度の異なる画像2つと、ベクトル1つをインプットにしています。
x1 = np.zeros((num_train, 1, 8, 8))
x2 = np.zeros((num_train, 1, 16, 32))
x3 = np.zeros((num_train, 10))
y = np.zeros((num_train, num_cls))
z = [y[:,ii] for ii in range(y.shape[1])]
  • 6ブロック目: 公式のDataIterを使うとシンボルの名前が嫌になります。インプットのデフォルト名はdataで、アウトプットのデフォルト名はsoftmax_labelになります。出力層の名前を自分で勝手に決めるとエラーが出るのはこいつが原因です。
print(mx_dat0.provide_data)
print(mx_dat0.provide_label)
  • 8ブロック目: DataIterを改造して、名前を付けやすくしています。

  • 9ブロック目: 適当なネットワークを組みました。

data0 = mx.sym.Variable('input0')
data1 = mx.sym.Variable('input1')
data2 = mx.sym.Variable('input2')
def get_symbol(sym, prefix=''):
    net = mx.sym.BatchNorm(sym,
                           name=prefix+'_bn')
    net = mx.sym.FullyConnected(net,
                                name=prefix+'_fc',
                                num_hidden=3)
    return net
net0 = mx.sym.Flatten(data0)
net1 = mx.sym.Flatten(data1)
net2 = get_symbol(data2, 'i2')
netc  = mx.sym.Concat(*[net0, net1, net2])
fc   = []
out  = []
for ii in range(num_cls):
    fc.append(mx.sym.FullyConnected(netc,
                                    name='fc'+str(ii),
                                    num_hidden=2))
    out.append(mx.sym.SoftmaxOutput(fc[ii],
                                    name='clf'+str(ii)))
net = mx.sym.Group(out)

10ブロック目でプロットしたネットワーク図でinput1input2が見えませんが、インプットのデータを弄るとエラーが出るので、これで大丈夫のはずです。

注意

公式でもmulti-inputやmulti-outputのためのIF改善がTODOに挙げられています。 また、MXNet自体がNNVMとの機能分割で大工事中ですし、将来のバージョンアップで諸々変わってしまう可能性があります。

v1.0 Stable Release TODO List · Issue #2944 · dmlc/mxnet · GitHub

その他

MXNetのLogisticRegressionOutputは名前から想像するような動きをしてくれないので、SoftmaxOutputを使った方が良さそうです。

*1:Deep Learningは人間の脳を模倣した仕組みなので、複数種類のインプット(視覚・聴覚・記憶等)から複数のアウトプット(話す内容・身振り・手振り等)を得るためのルールを学習できます。(嘘)