ひらめの日常

日常のメモをつらつらと

ChainerCVでSSDを学習させる時の注意点

はじめに

この記事は、自分がChainerCVを用いてSSDを学習させる時にハマった点や注意した点を紹介しています。
ChainerCVとは、その名の通りChainerをベースにコンピュータビジョンのタスクに対応させたものとなります。

SSDとは

簡単にSSDの説明をすると、SSDとはSingle Shot MultiBox Detectorの略で、

①VGGをbase networkとして異なる大きさのフィルターをかけていき
②異なるアスペクト比、解像度のbounding boxに対応した
③精度74.3%, 速度59FPSを記録した高い精度で高速

な物体検出方法です。それまでの先行研究に比べて学習時に物体らしさのを物体の分類と同時に学習するために、今までよりも高速な学習が可能になっています。

論文はこちら。
[1512.02325] SSD: Single Shot MultiBox Detector

また、こちらのスライドまとめがSSD含めた物体検出(object detection)の理解に非常に役立ちました。

www.slideshare.net

ChainerCVとは

ChainerCVは、ディープラーニング用ライブラリChainerのコンピュータービジョンに対応したもので、物体検出の手法が数多く実装されています。
GitHub - chainer/chainercv: ChainerCV: a Library for Deep Learning in Computer Vision

参考にした学習コードは、ChainerCVのexamplesにあるtrain.pyです。
github.com

注意点

MultiProcessIteratorのデッドロック

よし、gpuで学習を始めよう!と思うわけですが、学習が一向に進みません。SerialProcellIteratorを用いた時には正常に動くので、MultiProcessIterator特有の挙動のようです。

ChainerCVのissueを漁っていると同じような状況ハマっている方を見つけました。
github.com
どうやらMultiProcessIteratorが内部でcv2.resize()を呼んでおり、これが別スレッドを立てるためにデッドロックに陥るとのことです。

コメント通り、util.pycv2.resize()の直前にcv2.setNumThreads(0)を追加したところ、学習が開始されました。

Validationの適切なログ出力

ここでいう「適切」とは、自身に秘帖なタイミングで適切に出力するように変更しようということです。

デフォルトでは、10,000iterationごとにtestデータセットを用いてvalidation accuracyをlog出力しています。学習を120,000iteration回すので、それでも良いかもしれませんが、自分は頻繁にログを確認したかったので次のように変更しました。

trainer.extend(
        DetectionVOCEvaluator(
            test_iter, model, use_07_metric=True,
            label_names=voc_bbox_label_names_with_hands),
        trigger=(1000, 'iteration'))  # 1000iterationごとにvalidationのログ出力

頻繁にvalidation accuracyを出力する時ですが、かなり時間がかかるのでそこは留意したほうが良いかと思います。

使用できる学習済みモデルはimagenet

ChainerCVで実装されているSSDはVOCデータセットの20クラス分類に対応しています。
ですので、自作のデータセット含めた21クラス分類の転移学習(fine-tuning)がVOCデータセットの重みを用いてできるかと思ったのですが、現状はできません

こちらのissueにもあるように、20クラス分類以外の場合はVOCデータセットの重みは用いることができず、現在実装中のようです。
github.com

自作データセットを用いて20クラス分類以外の学習をさせたい時には、pretrained_modelはimagenetのままで学習させましょう。

if args.model == 'ssd300':
        model = SSD300(
            n_fg_class=21,  # 用途に応じて変更
            pretrained_model='imagenet')
    elif args.model == 'ssd512':
        model = SSD512(
            n_fg_class=21,  # 用途に応じて変更
            pretrained_model='imagenet')