chakokuのブログ(rev4)

テック・コミック・DTM・・・ごくまれにチャリ

【DLfS】俺NNの学習状況→10時間稼働でaccuracy:0.4〜0.5

DLfS本を勉強しながら、数値的微分による勾配法で俺NNを実装して学習させています。数値的微分による勾配降下法はアルゴリズムが理解しやすく簡単に実装できる反面、時間がかかるのはしょうがない。それは十分承知の上とはいえ、おおよそ10時間ぐらい走らせてトレーニング回数は700回ぐらいで、損失値が1.2-1.4、精度(accuracy)が0.5ぐらい。左のグラフがその傾向を示した図で、青の線が損失(左側Y軸の目盛り)で緑の線が精度(右側Y軸の目盛り)。accuracy:0.5とは、10パターン読むと5パターン成功する精度です。(accuracyの算出は学習用には使っていない検証用データ100パターンをNNに与えてパターンを推定させて計算しています)。

下記は実際の学習時の表示例。学習に使ったデータの正解がans:で、prd:がNNによる推定結果。表示上 n of training:421とあるのは途中で中断してカウンターが0に戻ったから。実際は300+421ぐらい行っている。

この効率の悪い数値的微分による実装であまり頑張っても意味ないとはいえるが、希望としては認識率90%〜95%ぐらいまでにはなってほしい。MNISTのデータが人間が読んでも分かりづらいとはいえ。。

n of training:421
loss:1.225226
ans:[9, 5, 1, 5, 5, 7, 9, 9, 0, 8]
prd:[9, 5, 1, 5, 9, 9, 9, 9, 0, 9]
accuracy:0.390000
---------------
n of training:422
loss:1.450918
ans:[6, 8, 2, 8, 1, 9, 5, 7, 4, 8]
prd:[8, 8, 8, 8, 1, 7, 8, 7, 4, 8]
accuracy:0.460000
---------------
n of training:423
loss:1.236774
ans:[8, 8, 5, 6, 7, 1, 0, 7, 8, 3]
prd:[8, 8, 8, 8, 7, 1, 0, 7, 1, 8]
accuracy:0.420000

■追記(161120)
その後、学習させているが、、数日走らせてもaccuracy 0.9に到達しそうにない。0.9の前で飽和の(というのか、いわゆるサチっている)ように思える。

■追記(151127)
日中家に居ないので、夜間連続稼働させて学習させた。約5000回学習させた結果、accuracyが0.9でサチっている。多分このネットワーク構造ではこれぐらいが上限ではないかと思われる。ちなみに、1回の学習はMNIST10パターンをバッチで与えて学習させている。1回の学習に74秒程度かかるので、5000回学習させるのに連続稼働で4.2日
(74s * 5000 = 370000s)
(370000S =6167M =103H --> 4.2Days)
途中からaccuracyが下がっているのは検証パターンを100パターンから1000パターンに増やしたため。増やすとなぜ下がるのか??詳細は分からないけど、さらにばらつきが出たということか。。

パターンの識別状況を確認
5千回学習させて、おおよそaccuracy0.8後半まで到達したNNでパターン認識の動作を確認した。用いたのはMNISTの検証用集合。例として正解4を確認
左は正しく4と認識されたパターンの例、4の数値は9との読み間違いが多いと考えられる。

左は誤って認識されたパターンの例
上の行は、正解が7,8,9のずれかなのに4と誤って判断したパターン。下2段目から5段目は、それぞれ、4なのに0と誤判断、1と誤判断、6と誤判断、9と誤判断した例

人間だったら、4なのか9なのか?の識別は、○と縦棒は同じとして、4の場合は上から左に降りる斜めの線があるかどうか、縦棒から右に出る線があるかどうかで判断すると思いますが、俺NNにはそんな繋がり情報は持ち合わせないので、単にそのXY座標上に有効データがあるかどうかだけ。。過去の学習パターンにそのXY座標上に有効データがあったのなら該当文字とみなすしなければ見なさない。過去に見ていれば正解できるし、見たことがなければ正解できない。。
このような線と線との連続性が考慮されない限界を解消するのが、DLfS本後半に出てくる畳み込みレイヤーと理解してます。

ひとまず、数値微分での勾配降下法による学習は実装できたと判断、DLfS本、4章は終わりということで、、5章のバックプロパゲーションに進む予定。