ブログのとさか

技術的な話をしたりしなかったり

ChainerでPruning - ニューラルネットの軽量化

「Chainer Pruning」で検索してもすぐにコードが出てこなかったので、実装したついでに簡単な解説記事を書きました。

モデル圧縮

ニューラルネットの研究分野の一つに「モデル圧縮」という物があります。
モデル圧縮では、精度をある程度保ったままニューラルネットのモデルのパラメータ数を削減することで、メモリ使用量を小さくします。場合によっては高速化も目的とします。
ディープニューラルネットのパラメータ数は非常に多く、数MBから数百MB分にもなるため、モデル圧縮の技術はニューラルネットを実用するときに重要になります。

Pruning

pruning(枝刈り)はモデル圧縮の手法の一つで、ニューラルネットの結合重みの一部を0にする(疎行列化する)ことで、パラメータ数を削減します。ニューラルネット中のノードを削減する方のpruningもありますが、この記事では扱いません。
下図はそのイメージです。ここで扱うのは下図の"pruning synapses"の方です。*1
f:id:tosaka2:20171117155016p:plain

どの結合重みを削除するかは一種の組み合わせ最適化問題と捉えることができ、様々な方法が考えられます。
一般的には「重みの絶対値が小さいものを優先的に削除する」というシンプルな手法が用いられ、性能も良いとされています。*2
この手法はmagnitude-based pruningと呼ばれることがあります。この記事で実装しているpruningもこの手法になります。

再訓練

pruningしただけでは精度は落ちてしまいますが、その後再訓練することで精度を取り戻すことができます。
通常pruningは再訓練とセットで行われます。
pruning+再訓練を行ったモデルは、タスクにもよっては精度を落とすこと無くパラメータ数を80%から90%減らすこともできます。以下のグラフは画像キャプション生成を行うモデルに対しての実験です。*3
f:id:tosaka2:20171117155105p:plain

Chainerによる実装

pruning+再訓練をChainer(3.0.0)で実装します。
以下の実装はパラメータを学習することが目的であり、実際にメモリ使用量を削減するには疎行列(テンソル)用の別の実装が必要になることに注意してください。

単に特定の重みを0にしただけでは再訓練時にパラメータが更新され0でなくなってしまうので、pruning時に重み固定用のmask行列を作成し、パラメータが更新される度にpruningされる重みを0に設定し直します。

また、Chainerでは.namedlinks()でChainやLinkが持っているLinkとその名前をセットで取って来ることができるので、それを利用しています。

pruningを実装すれば後はextensionでイテレーション毎に重みを変更するだけで、他は通常どおりです。

使用例です。

この実装ではpruningする層をConvolution2DとLinearに限定しています。変更したい場合はcreate_model_mask関数の

if type(link) not in (L.Convolution2D, L.Linear):

の部分を変更してください。また、少しいじれば.W以外の重みもpruningできるようになると思います。

コード全体はこちら github.com

実験

モデルはVGG16、データセットはCIFAR-100で実験します。*4

以下のグラフはpruning率を40%から90%まででそれぞれ訓練し、最終的なテストデータに対する精度(accuracy)をプロットしたものです。 pruning無しで300 epoch訓練した後、pruningして300 epoch再訓練しています。
f:id:tosaka2:20171214173018p:plain

具体的な数値は [0.6874004602432251, 0.6823248267173767, 0.6668989062309265, 0.6478901505470276, 0.590664803981781, 0.19446656107902527] となっています。

pruning前の精度は0.693869411945343なので、50%までなら1%程度の精度劣化で抑えられることがわかります。

また、どれだけpruningできるかはニューラルネットの構造やデータセットに依存するということも、先の画像キャプション生成での結果との比較からわかります。

なお、pruning率50%において、「訓練→pruning→再訓練」ではなく、「初期の重みによりpruning→訓練」で600 epoch学習した場合の精度は 0.6592356562614441 だったので、このpruning手法が有効であることも簡単にですが確認できました。

応用

  • Iterative Pruning*5
    pruning→再訓練→pruning→再訓練...と繰り返しながらpruningするパラメータ数を増やしていくことで、より多くのパラメータをpruningできるようになります。

  • Dence-Sparse-Dence Training*6
    「普通に訓練」→「pruning+再訓練」→「pruningしたパラメータの0固定を解除し再訓練」という学習手法を適用すると、多くのモデルの性能を少し上げることができます。

*1:引用 https://arxiv.org/abs/1506.02626

*2:https://arxiv.org/abs/1510.00149

*3:引用 http://cs231n.stanford.edu/slides/2017/cs231n_2017_lecture15.pdf

*4:細かいパラメータは上述のリポジトリにあるtrain_cifar.pyのデフォルトのものを使っています。再訓練時に学習係数を設定し直す等のことは行っていません。

*5:https://arxiv.org/abs/1506.02626

*6:https://arxiv.org/abs/1607.04381