ブログのとさか

技術的な話をしたりしなかったり

ロジスティック回帰とElo Ratingの関係

はじめに

対戦ゲームのレーティングシステムとして多く採用されているElo Ratingですが, その計算式を見ると内部で行っていることはロジスティック回帰とほとんど一致することがわかります. この記事ではロジスティック回帰とElo Ratingについて簡単に説明し,それらの関係について見ていきます.

また,ついでにこの事実を応用した格闘ゲームのキャラ相性解析のアイデアについて紹介したいと思います.

ロジスティック回帰

ロジスティック回帰は2値分類問題の推論や分析に利用される一般化線形モデルの一つです. ロジスティック回帰ではロジット(対数オッズ)を線形モデルで予測します.*1

このことは予測確率をp,線形モデルの出力をy,ロジスティック回帰の重みベクトルをw,バイアスをb,入力ベクトルをxとした時以下の式で表されます.


y=logit(p)=\log{\dfrac{p}{1-p}}=w^T x+b

予測確率の計算

予測確率pは以下の式で求まります.*2


p=logit^{-1}(y)=\dfrac{1}{1+e^{-y}}

更新式

ロジスティック回帰のパラメータw,bは一般的に最尤推定によって決定します. 最尤推定に勾配法を用いる場合,更新式は予測確率をp,正解ラベルをt,学習率(定数)を\etaとして以下のようになります.


w\leftarrow w - \eta(p-t)x

Elo Rating

Elo Ratingは結果が勝敗で表される対戦競技においてプレイヤーの相対的な実力を表す指標の一つです. チェスや将棋のような古典的なゲームや,スポーツ,デジタルゲームにおいてレーティングシステムとして利用されています. 強さを表す単純な指標としては勝率がありますが,勝率と異なりElo Ratingは対戦相手の強さを考慮して決定する数値のためより正確に実力を評価できます.*3

以下では概要のみ説明して詳しい導出等は行わないので,気になる方は以下のサイトを参考にしてください.

イロ・レーティングの意味と求め方を完全解説 | ワイズ

Elo Ratingは内部レートと表示レートに分けて解釈することができます. 内部レートはレートを計算するための数値で,表示レートは内部レートを人間が解釈しやすいように変換した数値です.*4 これは内部レートをr,表示レートをRとして以下の式で表せます.


R = 400 r + 1500

さて,内部レートrが計算しているものは,平均的強さのプレイヤーと対戦した時の勝率の常用対数オッズです. このことは平均的強さのプレイヤーとの勝率をpとした時以下の式で表せます.


r = \log_{10} \dfrac{p}{1-p}

なお,平均的強さのプレイヤーとは内部レートが0のプレイヤーのことを指します.

勝率の計算

まず,勝敗比は積によって推移するという仮定を置きます. つまり,プレイヤーXがプレイヤーYに勝率する確率をp_{XY}のように表し,プレイヤーX,Y,Zがいる時,以下の式が成り立つということです.


\dfrac{p_{XZ}}{1-p_{XZ}} = \dfrac{p_{XY}}{1-p_{XY}}\dfrac{p_{YZ}}{1-p_{YZ}}

この仮定と内部レートの定義を合わせると,プレイヤーAの内部レートをr_A,プレイヤーBの内部レートをr_Bとした時プレイヤーAの勝率p_{AB}は以下の式で計算できます.*5


p_{AB} =\dfrac{1}{1+10^{r_B-r_A}}

更新式

Elo Ratingの更新式はレートから予想される勝率をp,実際の勝敗をt(勝利時1,敗北時0),適当な定数をKとした時,プレイヤーAの内部レートr_Aの更新は以下の式で行われます.


r_A\leftarrow r_A + K(t_{AB}-p_{AB})

ロジスティック回帰とElo Ratingの関係

さて,ここまで読んだ方はお気づきでしょうが,ロジスティック回帰とElo Ratingの内部レートは対数オッズを計算しているという点で非常によく似ています.

実際ロジスティック回帰においてA,Bを自然数としてプレイヤーAとプレイヤーBの対戦結果p_{AB}を予測するとき,バイアスb=0かつベクトルxを以下のように設定すると,両者はほとんど同じものになります.


x_i = \left\{
\begin{array}{ll}
1 & (i=A)\\
-1 & (i=B)\\
0 & (otherwise)
\end{array}
\right.

ロジスティック回帰の予測勝率は\dfrac{1}{1+e^{w_B-w_A}},Elo Ratingによる予測勝率は\dfrac{1}{1+10^{r_B-r_A}}なので,重みwがレートを表すベクトルであるとみなすと,両者の違いは指数の底(eまたは10)のみです. またロジスティック回帰のとElo Ratingの更新式は予測勝率と勝敗結果が等しく,K=\etaの時完全に一致します.

以上のことをまとめると,(指数の差は定数倍の差でしかないので無視した場合)Elo Ratingが本質的に行なっていることはロジスティック回帰による勝敗予測であり,強さの指標として重みを用いていると解釈できます.*6

格闘ゲームのキャラ相性解析

バイアス0のロジスティック回帰はElo Ratingの拡張としてみなせることがわかりました. このことを利用して格闘ゲームのキャラクター相性の解析とその解釈を行う方法を考えたいと思います. 具体的にはP,Q,X,Yを自然数として,プレイヤーPがキャラクターXを用い,プレイヤーQがキャラクターYを用いて対戦した場合,入力xを以下のように設定することでバイアス0のロジスティック回帰で勝率予測を行います. なお,Cをキャラクターの総数を表す定数とします.


x_i = \left\{
\begin{array}{ll}
1 & (i=X\times C + Y)\\
1 & (i=C\times C + P)\\
-1 & (i=C\times C + Q)\\
0 & (otherwise)
\end{array}
\right.

これらの重みは以下のように解釈できます.

  • w_{C\times C+P}:キャラクター性能と分離したプレイヤーPの強さ
  • w_{C\times C+Q}:キャラクター性能と分離したプレイヤーQの強さ
  • w_{X\times C+ Y}:キャラクターXとキャラクターYの相性

ここで重要なのはキャラクター相性の数値w_{X\times C+ Y}です. この数値の大小や勝率への変換により相性表を作ることも可能ですが,ユーザーがElo Ratingに馴染み深いという前提では以下の変換により相性値を表現することが直感につながると思われます.


M_{XY} = \dfrac{400}{\log_{e} 10}w_{X\times C+ Y}

この式では底の変換公式によりElo Ratingの内部レートと同じスケールに変換し,Elo Ratingと同様に表示レートに変換しています. これにより「同等の実力のプレイヤー同士の場合キャラクターX,Yの相性はレート差M_{XY}分である」という解釈が可能になります.*7

おわりに

この記事ではロジスティック回帰とElo Ratingの関係について説明し,それを応用したアイデアについて説明しました. 実は今回解説した内容について思いついたのはスマブラSPの対戦サイトの統計分析をしようと思ったことがきっかけだったので, 対戦データが手に入り次第実際に説明したアイデアで分析した結果についても紹介したいと思います.

*1:確率ではなく対数オッズを予測する理由は,確率が0から1の値に収まらなければいけないのに対して,対数オッズは-∞から+∞の値を取って良いため線形モデルにとって都合が良いからです.

*2:このロジット関数の逆関数のことをロジスティック関数と言います.NNでお馴染みのシグモイド関数と同じものです.

*3:勝率は実力が高い相手と多く対戦するプレイヤーと実力が低い相手と多く対戦するプレイヤーの強さを比較しようとした場合に不正確な結果を招きます.

*4:内部レートや表示レートという言い方はこの記事で説明しやすいように便宜的に名付けたものであることに注意してください.

*5:ちなみによく強者の指標として用いられるレート2000はレート1500に対して95%程度の勝率となります.

*6:なお,Elo Ratingでは一度の勝敗結果によって一度だけレートを更新するので,勾配法で1エポック更新しているとみなせます.

*7:同じ変換でキャラクター性能と分離したプレイヤーの強さの表示もできますが,この数値については議論の余地があります.

bitsetの便利機能・速度・ランダムアクセスの仕様

APG4b1のビット演算の解説で必要になったので、bitsetの「便利機能」、「実行速度」および「ランダムアクセス方法の違い」について調べてみました。

bitsetの便利機能

bitsetは通常のビット演算では少し面倒な操作を簡単に行えるようになっています。以下に便利機能の表を示します。

記法 動作
[k], .test(k)/.set()(/.reset(k)) ランダムアクセス
.flip(k) k番目のビットを反転
cout << bitsetの値 ビットの形式で標準出力
.to_string() ビットの形式で文字列に変換

bitsetの実行速度

ビット数が少ない場合

1つの整数型で収まる程度にビット数が少ない場合、整数型を直接操作するのと比べて実行速度がどれくらい変わるのかをビット数N=24でシンプルなビット全探索を行うプログラムで調べました。計測に利用したプログラムは記事の末尾に置いておきます。

結果

  • unsigned int:409ms
  • bitset:409ms

どちらも実行速度は変わらないという結果になりました。実質ランダムアクセスのみの比較ではありますが、便利さと速度のトレードオフにはなっておらず、常にbitsetを利用しても問題なさそうです。

ビット数が多い場合

ビット数が多い場合に関しては既に比較している記事があったので、リンクと概要を紹介します。

  • ビット演算の速度比較
    • ビット数1000000の場合、シフト演算はarray<bool, N>vector<char>等の他のコンテナと比べて2倍~3000倍速く、OR演算は190倍~1000倍速い。
  • ランダムアクセスの比較
    • bool[]よりは3倍近く遅いが、vector<bool>よりは少し速い。速度が重要であり、bitsetの便利関数やビット演算を行わない場合はbool[]を使ったほうが良い。

ランダムアクセス方法の違い

bitsetはk番目のビットを操作する場合、.test(k)/.set(k, b)または[k]を用います。動作の違いを以下の表に示します。

記法 動作 境界チェック
.test(k) k番目のビットの読み込み あり
.set(k, b) k番目のビットをbに変更 あり
[k] k番目のビットへの読み書き なし

[k]は便利ですが境界チェックがなく、kが境界外のときは未定義動作になります。しかもその検出は厄介なことがあるので、複雑な添字操作を行う場合は[]よりもtestsetを利用したほうが良いでしょう。

検出が厄介な理由

GCCの実装を見るとbitsetはスタック領域上の配列でビット列を管理しており、operator[]では間接的にその配列へアクセスしています。そのためoperator[]を利用して境界外アクセスが発生した場合でも、もともとのビット数が十分多ければAddressSanitizerやgdbでエラーを検知できます。
ただし、ビット数がunsigned longひとつで表現できる場合(通常64以下)は特殊化されており、単にunsigned longを扱っているのと同じような実装になります。そのため不正な添字でもランダムアクセスの際には単なるビット演算しか行われずSegmentation Faultが発生しないので、ツールが検知してくれません。

結論

bitsetは速い・安全・便利。C++でビット演算を行うときは毎回bitsetを利用してもいいかもしれない。
ただし安全に使いたいなら.test(k)/.set(k, b)を使うべし。

計測に使ったプログラム

計測に使ったプログラムは以下のよいうになります。
実行時間はAtCoderのコードテストC++14 (GCC5.4.1)上で複数回実行した結果の平均を取っています。

// integer
#include <bits/stdc++.h>
using namespace std;

int main() {
  const int N = 24;
  unsigned int sum = 0;
  for (int i = 0; i < (1 << N); i++) {
    unsigned int bit = i;
    for (int i = 0; i < N; ++i) {
      if (bit & (1 << i)) {
        sum += i;
      }
    }
  }
  cout << sum << endl;
}
// bitset
#include <bits/stdc++.h>
using namespace std;

int main() {
  const int N = 24;
  unsigned int sum = 0;
  for (int i = 0; i < (1 << N); i++) {
    bitset<N> bit(i);
    for (int i = 0; i < N; ++i) {
      if (bit.test(i)) {
        sum += i;
      }
    }
  }
  cout << sum << endl;
}

  1. 僕とあるごんさんで書いてるC++の入門教材

GAでDNNのPruningしてみた

大学で学部3年生以前から研究らしきことができる実験科目を履修し,「遺伝的アルゴリズムを用いたディープニューラルネットワークのモデル圧縮」というテーマに取り組みました. 1月ごろに取り組んでいたものではありますが,書いたレポートを公開したいと思います.

注意

このブログ記事はあくまで取り組んだことの記録として書いています. なにか有用な情報が書いてあるわけではありません.

レポート

このレポートは提出時にわかっていた情報で書きましたが,このレポートで主張している有用性は後の実験で否定されています.

レポートのリンク

実験で用いた技術

この実験ではPython, Chainer, Cupyを用いました.特にGAの計算はGPUで行わないと非常に時間がかかるので,Cupyにとても助けられました.  

ソースコードは公開準備中です...*1

再学習を行うとこの手法は無意味

このレポートではPruning後に再学習をしない設定で精度を比較していました. 元々「手持ちのデータが少ない」という問題設定だったので,再学習をしても精度はそれほど変わらない/悪化してしまうだろうと考えてのことでした.

しかしそれは完全に間違いでした.実際に再学習を行ってみると,Pruning率80%で再学習に各クラス1枚のみを用いた場合では,Test Accuracyは80%程度まで上昇しました.
再学習前のTest AccuracyがMagnitude BasedでもGAでも60%後半であることを考えると,これは無視できない影響の大きさです.

しかも,どちらの手法でPruningしても,再学習後のTest Accuracyはほとんど変わりませんでした.*2

この手法の発展

上述の実験から,データが少なくても再学習を考慮する必要があることがわかりました.

レポートでも紹介しているIterative Pruningを参考にすると,「GA Pruning→再学習→GA Pruning→...」と繰り返していく手法が考えられます.

ここで,Pruningに用いているマスク行列をハイパーパラメータのようなものだと考えると,GA Pruningと再学習を繰り返す手法はPopulation Based Trainingに類似した手法であると考えることができます.*3

また,GAで最適化するものをマスク行列ではなく,各層の計算方法などにした場合,アーキテクチャ探索系の手法になります.


補足:この実験に対する批判

再学習関連が最も重要な情報ですが,それ以外にもこのレポートの手法には様々な批判があると思うので,一応補足しておきます.

  1. 実験が粗い
  2. 実際には「精度が過度に落ちた状態でよりマシな精度になる」手法は使えない
  3. GAのチューニングが足りない(GAの集団を同一のものにするべきではない)
  4. 他のメタヒューリスティクス系の手法と比較するべき

1点目についてはその通りだとしか言えません.実験結果が出るまで多少時間がかかるので,レポート提出期限までに十分な回数実験を行えませんでした...*4

2点目について,実際に使える手法を開発する場合は「一定の精度を達成することを条件として,よりパラメータ数が少ないモデルを探索」という問題設定にするべきだと思います.その点でこのレポートはあまり良くないです.

3点目について,初期集団に関して「重みが大きいものが指数的に選ばれやすい」というようなものに関しても実験しましたが,あまり上手く行きませんでした.*5

4点目について,SAを用いたPruningについても軽く実験しましたが,あまりうまくいきませんでした. 工夫すれば上手くいく方法もあるかもしれませんが,良い近傍を模索する必要があると感じたので,今回は素直に適用できるGAのみの結果をレポートにしました.
なお,GAがランダムサーチより優れていることは確認しました.

*1:大学の研究室の計算機をお借りして作業しており,諸事情ありソースコードがその計算機上にしか無いのですが,学期が変わった関係かその計算機にアクセスする権限が無くなっていました...GW明けに再度権限をいただけるようお願いしてみます...

*2:再学習用のデータが少ない場合,どのデータが再学習用データに選ばれるかによってAccuracyは±5%程度変動します.しっかりとデータを取ると平均的にはどちらか一方の手法が優れている可能性はありますが,パラメータを変えて色々と試していた限りでは,はっきりとした有意差は見られませんでした.

*3:PBTはGAというより進化戦略のような方法ですが

*4:Deep系実験あるあるかもしれません...

*5:正直これに関してはバグがある可能性はあります...

m項間漸化式の第n項までの和を$O(m ^ 2 log n)$で

この記事ではm項間漸化式の第n項までの和をO (m ^ 2 log n)で求める方法について説明します。

@mt_caret がnth Fibonacci number in O(logn)という記事を書いていたのを見て、以前ブログに書こうと思っていて完全に忘却していたネタを思い出したので書きました。

イントロ

m項間漸化式の第nO (m ^ 3 log n)

m項間の線形漸化式*1
{ \displaystyle
a_{n+m} = \sum_{i=0}^{m-1} c_i a_{n+i}
}

は、次のように行列で表現し、繰り返し二乗法を使ってn乗を計算することでO (m ^ 3 log n)で求めることができます。

$$A= \begin{pmatrix} c_{m -1} & \cdots & c_{1} & c_{0} \\ 1 & \cdots & 0 & 0 \\ \vdots & \ddots & \vdots & \vdots \\ 0 & \cdots & 1 & 0 \end{pmatrix} $$

$$ \begin{pmatrix} a_{n+m -1} \\ a_{n+m -2} \\ \vdots \\ a_{n} \\ \end{pmatrix} =A \begin{pmatrix} a_{n+m -2} \\ a_{n+m -3} \\ \vdots \\ a_{n -1} \\ \end{pmatrix} =A ^ n \begin{pmatrix} a_{m -1} \\ a_{m -2} \\ \vdots \\ a_{0} \\ \end{pmatrix} $$

このことは、はじめに紹介した記事や蟻本*2で説明されています。

m項間漸化式の第n項までの和 O (m ^ 3 log n)

似たような方法で第n項までの和を求めることもできます。

$$ S_i = I + A + \ldots + A ^ {i -1} $$

とすると、ブロック行列を用いて

$$ \begin{pmatrix} A ^ n \\ S_n \\ \end{pmatrix} =\begin{pmatrix} A & 0 \\ I & I \\ \end{pmatrix}\begin{pmatrix} A ^ {n -1} \\ S_{n -1} \\ \end{pmatrix} =\begin{pmatrix} A & 0 \\ I & I \\ \end{pmatrix} ^ n \begin{pmatrix} I \\ 0 \\ \end{pmatrix} $$

と書くことができるので、繰り返し二乗法を適用するとO (m ^ 3 log n)S_{n+1}が求められます。

後は S_{n+1}
\begin{pmatrix}
a_{m-1}, a_{m-2} , \ldots , a_0
\end{pmatrix} ^ {\mathrm{T}}
を掛ければ、第n項までの和a_{0} + a_{1} + \ldots + a_{n}が求まります。

$$ S_{n+1} \begin{pmatrix} a_{m -1} \\ a_{m -2} \\ \vdots \\ a_{0} \\ \end{pmatrix} = ( I + A + \ldots + A ^ {i -1} + A ^ i ) \begin{pmatrix} a_{m -1} \\ a_{m -2} \\ \vdots \\ a_{0} \\ \end{pmatrix} $$

$$ = \begin{pmatrix} a_{m -1} \\ a_{m -2} \\ \vdots \\ a_{0} \\ \end{pmatrix} + \begin{pmatrix} a_{m} \\ a_{m -1} \\ \vdots \\ a_{1} \\ \end{pmatrix} + \begin{pmatrix} a_{n+m -1} \\ a_{n+m -2} \\ \vdots \\ a_{n} \\ \end{pmatrix} $$

$$ = \begin{pmatrix} a_{m -1} + a_{m -2} + \ldots + a_{n+m -1} \\ a_{m -2} + a_{m -3} + \ldots + a_{n+m -2} \\ \vdots \\ a_{0} + a_{1} + \ldots + a_{n} \\ \end{pmatrix} $$

このことは蟻本で行列の累乗和を求める方法として説明されています。

kitamasa法によるm項間漸化式の第n項の計算 O (m ^ 2 log n)

m項間漸化式の第n項はより高速に計算する方法があります。
日本の競技プログラミング界隈ではkitamasa法と呼ばれます。*3

ここでは詳しい説明はしません。
気になる人はm項間漸化式の高速なアルゴリズム - 競技プログラミングをするんだよ高速 Kitamasa 法 - みさわめも等のブログ記事を読んで下さい。

ここで重要な点は、kitamasa法を用いるとm項間漸化式の第n項がO (m ^ 2 log n)で計算できるという点です。*4

本題

m項間漸化式の第n項までの和 O (m ^ 2 log n)

m項間漸化式の第n項までの和は、行列の累乗和を陽に求めなくても計算できます。

イデアは簡単で、第n項までの総和を

{ \displaystyle
s_{n} = \sum_{i=0}^{n} a_{n}
}

として、m項間漸化式から、総和の漸化式

{ \displaystyle
s_{n+m'} = \sum_{i=0}^{m'} c'_i s_{n+i}
}

を構成することを考えます。

結論だけ言うと、次のような関係になります。*5

m'=m+1つまりm+1項間漸化式s_{n+m+1}について、

c'_0=-c_0
c'_{m}=2 c_{m-1}
c'_i=c_i - c_{i+1}

この関係は{ \displaystyle
a_{n+m} = \sum_{i=0}^{m-1} c_i a_{n+i}
}s_{n+1}=s_{n}+a_{n+1}から導くことができるので、計算してみてください。

これにより、m項間漸化式の第n項までの和を求める問題を、m+1項間漸化式の第n項を求める問題に変換できました。

すでに説明したとおり、m項間漸化式の第n項はkitamasa法によりO (m ^ 2 log n)で求めることができるので、m項間漸化式の第n項までの和もO (m ^ 2 log n)で求められるようになりました。*6

実験

m項間漸化式の第n項までの和を求める計算を以下の3つの方法で行い、実行速度を比較しました。

  • 前から順番に求めて足していく O (m n)
  • 行列の累乗和 O (m ^ 3 log n)
  • 総和の漸化式+kitamasa法 O (m ^ 2 log n)

言語はC++です。答えは非常に大きくなりうるので10 ^ 9 + 7でmodを取っています*7。その他の詳細な設定はリンク先で確認してください。

kitamasa法のソースコードこちらからお借りしました。
累乗和の実装は蟻本を参考にしています。

結果は以下の表のようになりました。

$$m=3,n=10 ^ 6$$ $$m=10,n=10 ^ 6$$ $$m=100,n=10 ^ 8$$
前から足していく 21.5772ms 75.9941 ms -(時間がかかりすぎるので省略)
行列の累乗和 0.024ms 0.6107 ms 1764.7 ms
総和の漸化式+kitamasa 0.0153ms 0.0991 ms 12.0769 ms
詳細 https://wandbox.org/permlink/uVmogwgBVcEINwYf https://wandbox.org/permlink/6HbnojdzEU61JMgP https://wandbox.org/permlink/GYXVyPbczVkOEen4

粗い測定ですが、mが小さいときから大きいときまで、一貫して総和の漸化式+kitamasa法が高速であることが確認できました。

一応gistにもソースコードを上げておきます。ブログの説明とは細かい添字が違ったりするので注意してください。

https://gist.github.com/tosaka2/e1b4e9dcb7892568d16570075d85941a

余談 O (m + n)

c_{i}=m - iである場合、総和の漸化式への変換を使ってm項間漸化式の第n項をO (m + n)で求めることができます。

これに限らず係数がある条件を満たす場合は、総和を経由することにより、第n項をO (m + n)で求められます。*8

このアルゴリズムm \leq 10 ^ 5, n \leq 10 ^ 6のように、 mが大きく、 m nのスケールが近いとき、kitamasa法を使うよりも高速な解法になり得ます。

所感

コンテストでこれが必要になることは無いと思うので、まぁ小ネタです。
息抜きに書くつもりがはてなブログの数式の仕様と格闘していたら思いの外時間がかかってしまいました...

何か間違い等あればご指摘ください。

*1:この書き方だとm+1項間漸化式な気もしますが、蟻本に合わせてこの式をm項間漸化式と呼ぶことにします。

*2:プログラミングコンテストチャレンジブック 第2版 P.180~

*3:コンパニオン行列のべき乗

*4:高速kitamasa法を用いるとO (m log m log n)まで計算量が落ちます

*5:はてなで上手く数式がかけないのでこの書き方になりました...

*6:c'_0, c'_1, \ldots, c'_{m}s_0 , s_1 , \ldots , s_{m}を前処理で計算する必要がありますが、これはどちらもO (m)で計算できます

*7:特に総和の漸化式+kitamasa法では係数が負になるので、元の数列の係数が全て正でも負のmodの処理をしてあげる必要があります

*8:総和に変換→ O(1)で次の項が求められる式に変形(これができることが条件)→s_0からs_{n}まで順に求める→a_{n}=s_{n}-s_{n -1}n項目が求まる

AtCoder Programming Guide for beginners 公開開始!

この記事は Competitive Programming Advent Calendar 2017 - Adventar 12/20の記事です。(少し遅刻しました。)

みなさん1年前のこの記事を覚えているでしょうか。

tosaka2.hatenablog.com

あれから何の音沙汰も無かったこの件ですが、ついに今日から公開になりました!!!!!!

AtCoder Programming Guide for beginners (APG4b) - AtCoder

といっても順次公開という形で、現在見られるのはトップページと「はじめに」だけになります。
説明はもっと先まで書き終わっていますが、調整しながら少しずつ公開していきます。

ドシドシご意見ご要望お待ちしておりますので、みなさんよろしくお願いします!

ご意見ご要望は公式(?)twitterアカウントに送ってください↓

twitter.com

(エアリプだと気付かないので、できるだけリプかDMで送ってね)

ChainerでPruning - ニューラルネットの軽量化

「Chainer Pruning」で検索してもすぐにコードが出てこなかったので、実装したついでに簡単な解説記事を書きました。

モデル圧縮

ニューラルネットの研究分野の一つに「モデル圧縮」という物があります。
モデル圧縮では、精度をある程度保ったままニューラルネットのモデルのパラメータ数を削減することで、メモリ使用量を小さくします。場合によっては高速化も目的とします。
ディープニューラルネットのパラメータ数は非常に多く、数MBから数百MB分にもなるため、モデル圧縮の技術はニューラルネットを実用するときに重要になります。

Pruning

pruning(枝刈り)はモデル圧縮の手法の一つで、ニューラルネットの結合重みの一部を0にする(疎行列化する)ことで、パラメータ数を削減します。ニューラルネット中のノードを削減する方のpruningもありますが、この記事では扱いません。
下図はそのイメージです。ここで扱うのは下図の"pruning synapses"の方です。*1
f:id:tosaka2:20171117155016p:plain

どの結合重みを削除するかは一種の組み合わせ最適化問題と捉えることができ、様々な方法が考えられます。
一般的には「重みの絶対値が小さいものを優先的に削除する」というシンプルな手法が用いられ、性能も良いとされています。*2
この手法はmagnitude-based pruningと呼ばれることがあります。この記事で実装しているpruningもこの手法になります。

再訓練

pruningしただけでは精度は落ちてしまいますが、その後再訓練することで精度を取り戻すことができます。
通常pruningは再訓練とセットで行われます。
pruning+再訓練を行ったモデルは、タスクにもよっては精度を落とすこと無くパラメータ数を80%から90%減らすこともできます。以下のグラフは画像キャプション生成を行うモデルに対しての実験です。*3
f:id:tosaka2:20171117155105p:plain

Chainerによる実装

pruning+再訓練をChainer(3.0.0)で実装します。
以下の実装はパラメータを学習することが目的であり、実際にメモリ使用量を削減するには疎行列(テンソル)用の別の実装が必要になることに注意してください。

単に特定の重みを0にしただけでは再訓練時にパラメータが更新され0でなくなってしまうので、pruning時に重み固定用のmask行列を作成し、パラメータが更新される度にpruningされる重みを0に設定し直します。

また、Chainerでは.namedlinks()でChainやLinkが持っているLinkとその名前をセットで取って来ることができるので、それを利用しています。

pruningを実装すれば後はextensionでイテレーション毎に重みを変更するだけで、他は通常どおりです。

使用例です。

この実装ではpruningする層をConvolution2DとLinearに限定しています。変更したい場合はcreate_model_mask関数の

if type(link) not in (L.Convolution2D, L.Linear):

の部分を変更してください。また、少しいじれば.W以外の重みもpruningできるようになると思います。

コード全体はこちら github.com

実験

モデルはVGG16、データセットはCIFAR-100で実験します。*4

以下のグラフはpruning率を40%から90%まででそれぞれ訓練し、最終的なテストデータに対する精度(accuracy)をプロットしたものです。 pruning無しで300 epoch訓練した後、pruningして300 epoch再訓練しています。
f:id:tosaka2:20171214173018p:plain

具体的な数値は [0.6874004602432251, 0.6823248267173767, 0.6668989062309265, 0.6478901505470276, 0.590664803981781, 0.19446656107902527] となっています。

pruning前の精度は0.693869411945343なので、50%までなら1%程度の精度劣化で抑えられることがわかります。

また、どれだけpruningできるかはニューラルネットの構造やデータセットに依存するということも、先の画像キャプション生成での結果との比較からわかります。

なお、pruning率50%において、「訓練→pruning→再訓練」ではなく、「初期の重みによりpruning→訓練」で600 epoch学習した場合の精度は 0.6592356562614441 だったので、このpruning手法が有効であることも簡単にですが確認できました。

応用

  • Iterative Pruning*5
    pruning→再訓練→pruning→再訓練...と繰り返しながらpruningするパラメータ数を増やしていくことで、より多くのパラメータをpruningできるようになります。

  • Dence-Sparse-Dence Training*6
    「普通に訓練」→「pruning+再訓練」→「pruningしたパラメータの0固定を解除し再訓練」という学習手法を適用すると、多くのモデルの性能を少し上げることができます。

*1:引用 https://arxiv.org/abs/1506.02626

*2:https://arxiv.org/abs/1510.00149

*3:引用 http://cs231n.stanford.edu/slides/2017/cs231n_2017_lecture15.pdf

*4:細かいパラメータは上述のリポジトリにあるtrain_cifar.pyのデフォルトのものを使っています。再訓練時に学習係数を設定し直す等のことは行っていません。

*5:https://arxiv.org/abs/1506.02626

*6:https://arxiv.org/abs/1607.04381

#MakeGirlsMoe でもっと遊べるプログラムを書いた(実質百合画像生成)

MakeGirls.moe

二次元キャラを自動生成してくれるMakeGirls.moeが今話題になっています。
f:id:tosaka2:20170815164557p:plain

make.girls.moe

今までの画像生成手法と比べてもかなり綺麗な画像が出力されるので驚きました。
PFNでアルバイトしている方が作ったそうです。流石...

ネットの評判(ニコニコの某動画)を見てみると、「ただパーツを組み合わせただけだろ」とか、「データベースからランダムに選んでるだけでしょ」といったコメントをしている人がいて、 いかに高度な画像生成ができているかがわかります。
技術的な詳細は公式ブログにまとまっています。

MakeGirls.moe Official Blog

ここでGANとは~~みたいな話をしてもいいのですが、今回はこのWebサービスでもっと遊ぶためのプログラムを書いたので導入方法と使い方を紹介します。

このプログラムを導入すると「2つのイラストの中間のイラスト」を出力することができます。
イメージとしては以下のツイートの通り。(実質百合では?)

導入方法

画像で説明していきますが、この画像を作ったときより少しアップデートされていて微妙に差があります。
Google Chromeでしか動作確認していません。MakeGirls.moeのアップデートですぐに動かなくなるかもしれないので注意してください。

工程1

※同じ階層にあるmain(なんちゃら).jsを開けばOKです。
f:id:tosaka2:20170815165154p:plain

工程2 (8/16 編集)

プログラムを更新したので画像と説明が少し異なります。

24661行目ではなく、return t.generate()から始まる行の左の数字をクリックしたください。(2017/11/25 現在 31760行目)
Ctrl+Fを押してreturn e.generate()で検索すれば一箇所だけヒットすると思います。
また、Consoleが表示されてない人は一度Consoleタブに切り替えても良いです。(Consoleタブの場所は工程3の画像に書いてあります。)
f:id:tosaka2:20170815165201p:plain
以下のプログラムをコピペしてください。

// return e.generate()の行にブレークポイント
var getObject = () => e;

var getState = () => getObject().state;
var getOption = () => getObject().getOptions();
var getNoise = () => getState().gan.noise;
var getNoiseOrigin = () => getState().gan.noiseOrigin;
var isRunning = () => getState().gan.isRunning;

// NoiseをFixedに
var fixNoise = () => getOption().noise.random = false;
// NoiseをRandomに
var randomizeNoise = () => getOption().noise.random = true;

var setRandomOption = (op, random) => {
    for (let param in op) {
        if (op[param]["random"] !== undefined) {
            op[param].random = random;
        }
    }
}

var fixOption = (op = null) => setRandomOption(op || getOption(), false);
var randomizeOption = (op = null) => setRandomOption(op || getOption(), true);

// Noiseを出力
var printNoise = () => "[" + getNoise().join(',') + "]";

// https://github.com/makegirlsmoe/makegirls.moe_web/blob/22272ae7fad6ef3a24a463ea497432a3b6913ead/src/utils/ImageEncoder.js
// GNU GENERAL PUBLIC LICENSE Version 3, 29 June 2007 https://github.com/makegirlsmoe/makegirls.moe_web/blob/master/LICENSE.txt
var encodeNoiseOrigin = noiseOrigin => {
    let canvas = document.createElement('canvas');
    let canvasWidth = 128;
    let canvasHeight = 34;
    canvas.width = canvasWidth;
    canvas.height = canvasHeight;
    let ctx = canvas.getContext("2d");
    let canvasData = ctx.getImageData(0, 0, canvasWidth, canvasHeight);

    function drawLine(x, color) {
        for (var i = x * 4; i < canvasData.data.length; i += canvasWidth * 4) {
            canvasData.data[i] = color.r;
            canvasData.data[i + 1] = color.g;
            canvasData.data[i + 2] = color.b;
            canvasData.data[i + 3] = color.a;
        }
    }

    function getColor(x) {
        return {
            r: 255,
            g: Math.floor((1 - x[1]) * 256),
            b: Math.floor((1 - x[0]) * 256),
            a: 254
        };
    }

    function updateCanvas() {
        ctx.putImageData(canvasData, 0, 0);
    }

    for (let i = 0; i < canvasWidth; i++) {
        drawLine(i, getColor(noiseOrigin[i]));
    }

    updateCanvas();

    return canvas.toDataURL();
}

var setNoiseOrigin = a => {
    if (isRunning()) return;
    // setNoiseOriginは無くなってる
    document.querySelector('.noise-canvas').firstElementChild.src = encodeNoiseOrigin(a);
    
    cn = getState().gan.noiseOrigin;
    for (let i = 0; i < cn.length; i++)
        for (let j = 0; j < cn[i].length; j++)
            cn[i][j] = a[i][j];

    // 上と下のどちらかで良い?
    //getOption().noise.value = a;
}

// Noiseを設定
var setNoise = a => { 
    if (isRunning()) return; 
    setNoiseOrigin(noiseToNoiseOrigin(a));
};

var setOption = op => {
    if (isRunning()) return;
    Object.assign(getOption(), op);
}

var cloneOption = op => {
    let obj = {};
    for (let param in op) {

        obj[param] = op[param]["random"] !== undefined
            ? Object.assign({}, op[param])
            : op[param];
            
        if (param === "noise") {
            arr = op.noise.value.map(x => x.map(y => y));
            obj.noise.value = arr;
        }
    }
    return obj;
}

// 中断フラグ
var _isAborted = false;
var _selected = [{option:null, img:null},{option:null, img:null}];

// Generate
var generate = async () => { 
    if (isRunning()) return;
    await getObject().generate();
};
// NoiseをRandomにしてGenerate
var generateByRandom = async () => { randomizeNoise(); await generate(); };
// 引数で指定したNoiseでGenerate
var generateBy = async a => { fixNoise(); setNoise(a); await generate() };
var cancel = () => { _isAborted = true; };

// ベクトルの操作
var add = (a, b) => a.map((x, i) => x + b[i]);
var sub = (a, b) => a.map((x, i) => x - b[i]);
var times = (a, t) => a.map((x, i) => x * t);
var norm = a => Math.sqrt(a.reduce((s, a) => s + a**2))
var interpolate = (a, b, p) => Array.isArray(a)
    ? (a.map((x, i) => x * (1-p) + b[i] * p))
    : (a * (1-p) + b * p);

// そこそこ誤差あるけど見てわかるほど影響は出無さそう?
var calculatebackToPixel = v => {
    let b = 255;
    let u = Math.sqrt( -2.0 * Math.log( 1 - b/ 256) );
    let tmp = v / u;

    // 大きすぎるノイズのとき
    if (Math.abs(tmp) > 1) tmp = Math.sign(tmp);

    let pg = (1 - (Math.acos(tmp) / (2 * Math.PI))) * 256;
    let g = Math.min(Math.floor(pg + 0.5), 255); //四捨五入
    return [b, g];
}

var noiseToPixels = a => a.map(calculatebackToPixel);
var pixelToNoiseOrigin = ps => ps.map(x => [1 - x[0]/256, 1 - x[1]/256]);
var noiseToNoiseOrigin = a => pixelToNoiseOrigin(noiseToPixels(a));
var noiseOriginToNoise = noiseOrigin =>
    noiseOrigin.map(([u, v]) => Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v));

var downloadImage = (img, name) => {
    img = img || document.body.querySelector(".result-canvas").firstChild;
    let a = document.createElement("a");
    a.href = img.src;
    a.target = "_blank";
    a.download = name;
    a.click();
}

// 生成済みの画像を下に表示
var addImg = (src, option) => {
    let results = document.body.querySelector(".imgs");
    if (!results) {
        results = document.createElement("div");
        results.className = "row imgs";
        document.body.querySelector(".App").appendChild(results);
    }
    
    let img = document.createElement("img");
    
    img.src = src;
    let op = cloneOption(option);
    
    img.onclick = () => {
        console.log(op);
        console.log(img);
        if (_selected[1].img && _selected[0].img != _selected[1].img) {
            _selected[1].img.style.border = "none";
        }
        
        //更新
        _selected = [{option: op, img: img}, _selected[0]];
        
        if (_selected[1].img){
            _selected[1].img.style.border = "dashed";
        }
        _selected[0].img.style.border = "solid";

        if (_selected[1].img && _selected[0].img == _selected[1].img) {
            downloadImage(img, "mgm.png");
        }
    }

    results.insertBefore(img, results.firstChild);
    return img;
}

// 下に表示してある画像をクリア
var clearImgs = () => {
    let imgs = document.body.querySelector(".imgs");
    for (let x of imgs.children) {
        x.onclick = null;
        x.src = "";
    }
    imgs.parentNode.removeChild(imgs);
}

// 下に表示してある画像を一枚消す
var deleteImg = img => {
    img.onclick = null;
    img.src = "";
    img.parentNode.removeChild(img);
}

// RandomなNoiseでn枚画像生成し,下に表示.
var generateRandomImages = async n => {
    if (isRunning()) return;
    let result = document.body.querySelector(".result-canvas").firstChild;

    for (let i = 0; i <= n - 1; i++) {
        if (_isAborted) {
            _isAborted = false;
            break;
        }
        await generateByRandom();
        addImg(result.src, getOption());
    }
}

// 2つのオプションの内分点を計算
var interpolateOption = (op1, op2, p) => {
    let obj = { };
    for (let param in op1) {
        if (param === "amount" || param === "currentModel") {
            obj[param] = op1[param];
        }
        else if (param === "noise") {
            let a = noiseOriginToNoise(op1.noise.value);
            let b = noiseOriginToNoise(op2.noise.value);
            let newNoise = interpolate(a, b, p);

            obj.noise = {random: false, value: noiseToNoiseOrigin(newNoise)};
            continue;
        }
        else {
            obj[param] = {random: false, value: interpolate(op1[param].value, op2[param].value, p)};
        }
    }
    
    return obj;
};

// オプションレベルの補完画像を生成
var generateInterpolations = async (op1, op2, n) => {
    if (isRunning()) return;
    let result = document.body.querySelector(".result-canvas").firstChild;
    let moto = cloneOption(getOption());

    for(let i = 0; i < n; i++) {
        if (_isAborted) {
            _isAborted = false;
            break;
        }
        let op = interpolateOption(op1, op2, i / (n - 1));
        setOption(op);
        setNoiseOrigin(op.noise.value);
        
        await generate();
        addImg(result.src, op);
    }
    
    setOption(moto);
}

// ボタン追加処理
var addButton = (text, func) => {
    let buttons = document.body.querySelector(".exbtns");
    if (!buttons) {
        buttons = document.createElement("div");
        buttons.className = "row exbtns";
        document.body.querySelector(".options-container").lastChild.appendChild(buttons);
    }
    let b = document.body.querySelector(".btn-primary").cloneNode();
    b.textContent = text;
    b.onclick = func;
    buttons.appendChild(b);
}

(() => {
    let buttons = document.body.querySelector(".exbtns");
    if (buttons) {
        buttons.parentNode.removeChild(buttons);
    }
    
    addButton("生成10", () => generateRandomImages(10));
    addButton("100", () => generateRandomImages(100));
    addButton("1000", () => generateRandomImages(1000));
    addButton("∞", () => generateRandomImages(100000000000000000));
    addButton("補間", () => generateInterpolations(cloneOption(_selected[0].option), cloneOption(_selected[1].option), 10));
    addButton("百合", () => generateInterpolations(cloneOption(_selected[0].option), cloneOption(_selected[1].option), 3));
    addButton("中断", () => { cancel() });
    // isRunning弾かないとgetNoise()でバグる
    addButton("下へ", () => {
        if (isRunning()) return;
        addImg(document.body.querySelector(".result-canvas").firstChild.src, getOption());
    });
    addButton("上へ", async () => {
        let op = cloneOption(_selected[0].option);
        fixOption(op);
        setOption(op);
        await generate();
    });
    // 選択したやつなのか、表示してる画像なのかわかりにくいけど許して
    addButton("口パク", () => {
        let op = _selected[0].option;
        if (!op) op = getOption();
        let op1 = cloneOption(op);
        let op2 = cloneOption(op);
        op2.open_mouth = {random: false, value: 1};
        op1.open_mouth = {random: false, value: -1};
        generateInterpolations(op1, op2, 10);
    });
    
    addButton("笑顔", () => {
        let op = _selected[0].option;
        if (!op) op = getOption();
        let op1 = cloneOption(op);
        let op2 = cloneOption(op);
        op2.smile = {random: false, value: 2};
        op1.smile = {random: false, value: -1};
        generateInterpolations(op1, op2, 10);
    });

    addButton("照れ", () => {
        let op = _selected[0].option;
        if (!op) op = getOption();
        let op1 = cloneOption(op);
        let op2 = cloneOption(op);
        op1.open_mouth = {random: false, value: 0};
        op1.blush = {random: false, value: -1};
        op2.open_mouth = {random: false, value: 1};
        op2.blush = {random: false, value: 2};
        generateInterpolations(op1, op2, 10);
    });

    addButton("クリア", () => clearImgs());
    addButton("1枚削除", () => {
        if (_selected[0].img) deleteImg(_selected[0].img);
        _selected[0] = _selected[1];
    });

})();

工程3

同じく24661行目ではなくなっていますが、前の工程でクリックした場所と同じ場所をクリックしてください。
f:id:tosaka2:20170815165137p:plain

工程4

最後の工程です。
この画像を作ったときから少しプログラムを変えたのでボタンの数が増えています。 f:id:tosaka2:20170815165145p:plain

画像が重複して表示されてしまう人(8/16 追記)

プログラム中のwait_sec = 7となっている場所(2箇所あり)を変えてください。
これは生成を何秒待つ必要があるかを秒単位で指定するパラメータです。手元の実行環境の生成速度に合わせてデフォルトで7秒にしていますが、これより生成が速い場合は短く、遅い場合は長くしてください。

ちなみにプログラムを更新するときははじめからやり直す必要はなく、最初の2行だけを飛ばしてConsoleタブにコピペすれば大丈夫です。

この問題は解決しました(8/16)

既知のバグ

- 「補間」で生成した画像を元に補間ができない。(ノイズが正しく_tmpsに入っていない。)→多分直った(8/16)

その他の機能

「中断」を押すと画像の生成を中断できます。
生成した画像の中から好きな画像を2枚選んで「補間」を押すと10枚の間のイラストを生成します。
f:id:tosaka2:20170815180203j:plain

「複製」は元の画像表示領域にある画像を下に表示させます。
「クリア」は下の画像領域の画像をすべて消します。
「百合」は選択した2つのイラストの中点の画像を1枚だけ生成します。(左右に元画像を表示します。)

「Hat x2」等ではHatオプションを強調できます。通常通りHat Onにしただけでは上手く生成できなかったイラストにも、これを押してから生成すればはっきり帽子が出ることが多いです。
公式Expert Modeの追加により削除しました。
- Hat On
f:id:tosaka2:20170818173539j:plain
- Hat x3
f:id:tosaka2:20170818173519j:plain

「下へ」を押すと下の画像表示領域に画像が複製されます。 「上へ」を押すと上の元の画像表示場所に画像が再生成されます。(「Current Noise更新」はここに吸収されました。)

下の画像をダブルクリックすると画像がダウンロードできます。

また、Consoleタブからプログラムを入力すれば他にも色々な操作が可能です。詳しくは上記プログラムを見てみてください。(適当なJavaScriptの知識で書いてあるのは許して)

生成した画像

とりあえず100枚生成して気に入った2枚を選んで補間すると良いのが出たりします。
f:id:tosaka2:20170815173205p:plainf:id:tosaka2:20170815173330p:plainf:id:tosaka2:20170815172920p:plainf:id:tosaka2:20170815173018p:plainf:id:tosaka2:20170815172348p:plainf:id:tosaka2:20170815172350p:plainf:id:tosaka2:20170815172352p:plainf:id:tosaka2:20170815172354p:plain

他にもたくさん良いのがあるんですがとりあえずこの辺で。

補足

自動で100枚も生成させたらサーバの負荷になるだろ!と言われそうなので一応捕捉しておくと、MakeGirls.moeの画像生成処理は全てブラウザ上で行われています。
WebAssemblyでモデルを動かしてくれるWebDNNというライブラリが使われています。また、Macの場合はGPUを利用して100倍高速に実行できるそうです。うらやましい...
WebGPUまたはWebGLが動作する環境ではそちらを使って超高速生成ができるようになりました(9/15)

詳しくは以下のFaster generationを見てください。
make.girls.moe

ひとりごと

Chromeデバッガを使わないといけない行はgetObject = () => t;の部分だけなので、このオブジェクトがデバッガ無しで取ってこれればもっと楽に導入できるんだけどなぁ。
どなたか即時関数内で定義されているオブジェクトを取得する方法を知っている人がいたらおしえてください。

wait_secは環境によって異なる値なのに直書きしてしまったのでハマってる人がいるみたいです。アレなプログラムで申し訳ない。改良するかも?
本当はPromiseオブジェクトとかを取ってきてちょうどいい時間待てれば良いんですが、それが簡単に出来るのかは未検証です。
→できました(8/16)

8/16 追記

GIFにしている方がいるようです。

口パク生成に関してはこちら

qiita.com

口パクボタン追加しました(8/16)
補間ボタンと同じように下に表示されている画像を選択してから実行してください。
f:id:tosaka2:20170816171053p:plain
他にも笑顔/メガネ/1枚削除(下の表示から)/ダブルクリックで保存機能をつけました。

gif生成サービスと組み合わせると楽しい!
GIFアニメーション画像作成ツール - フォトコンバイン f:id:tosaka2:20170816200705g:plain

8/29 追記

本家のアップデートに伴い、こちらの拡張プログラムで不具合が発生しているようです。
ちょっと今時間が取れないので、対応は先になると思います。(Expertモードでoptionの仕様が変わったのでそのあたりかなあとは思います。)

9/1 追記

現在のバージョンでも正しく動作するように修正しました。

9/12 追記

補完ボタンを押すとオプションレベルで補完するように変更しました。
オプション固定/ランダム化ボタンを追加しました。(ボタンを押してもUIには反映されませんが、generateや生成ボタンを押すと反映されていることがわかります。)