脳汁portal

アメリカ在住(だった)新米エンジニアがその日学んだIT知識を書き綴るブログ

[deep learning] caffeとCIFAR-10を使って画像判別テスト

こちらのAITC様のスライドに従って画像の判別テストをしてみました
http://www.slideshare.net/yasuyukisugai/deep-learningcaffe

環境

  • CUDA 7.5
  • Caffe 1.0.0rc3

手順

1. 学習用イメージのダウンロード

CIFAR-10は10個のカテゴリーに分類された画像集で、画像認識の学習に最適
CIFAR-10 and CIFAR-100 datasets

$ cd ${CAFE_ROOT}
$ cd data/cifar10
$ ./get_cifar10.sh
$ ll
total 180080
-rw-r--r-- 1       61 Jun  5  2009 batches.meta.txt
-rwxrwxr-x 1      504 Feb 23 13:01 get_cifar10.sh
-rw-r--r-- 1 30730000 Jun  5  2009 data_batch_1.bin
-rw-r--r-- 1 30730000 Jun  5  2009 data_batch_2.bin
-rw-r--r-- 1 30730000 Jun  5  2009 data_batch_3.bin
-rw-r--r-- 1 30730000 Jun  5  2009 data_batch_4.bin
-rw-r--r-- 1 30730000 Jun  5  2009 data_batch_5.bin
-rw-r--r-- 1       88 Jun  5  2009 readme.html
-rw-r--r-- 1 30730000 Jun  5  2009 test_batch.bin
  • cifar-10-binary.tar.gzをダウンロード
  • ダウンロードしたファイルを解凍
  • 画像はpngとかじゃなくてbinary状態でダウンロードされる
2. Caffe用に画像を加工

cifarのデータをcaffe用に変換する

$ cd ${CAFE_ROOT}
$ ./example/cifar10/create_cifar10.sh
 
# 以下の3ディレクトリ(ファイル)が作成される
$ ls -ltr examples/cifar10/ | tail -3
drwxr--r-- 2  4096 Feb 24 16:33 cifar10_train_lmdb # 訓練用データ(訓練だから容量でかい)
drwxr--r-- 2  4096 Feb 24 16:33 cifar10_test_lmdb  # 検証用データ(検証用にまったく知らないデータを残しておく・容量小さい)
-rw-r--r-- 1 12299 Feb 24 16:33 mean.binaryproto   # たぶん平均画像
  • data/cifar10のデータを使ってlmdbにデータ格納
    • 過去に作成したデータは削除される(rm -rf $EXAMPLE/cifar10_train_$DBTYPE $EXAMPLE/cifar10_test_$DBTYPE)
    • data/cifar10のデータをcifar用データに変換(./build/examples/cifar10/convert_cifar_data.bin $DATA $EXAMPLE $DBTYPE)
  • 平均画像作成(./build/tools/compute_image_mean -backend=$DBTYPE $EXAMPLE/cifar10_train_$DBTYPE $EXAMPLE/mean.binaryproto)
3. 学習開始

caffe用に変換したcifarの画像セットを使って学習を開始する

$ build/tools/caffe train --solver examples/cifar10/cifar10_quick_solver.prototxt
  # => 設定ファイルとかはデフォルトのものを使う
  .
  .
  .
  I0224 16:35:43.400415 25580 solver.cpp:338] Iteration 0, Testing net (#0)
  I0224 16:35:46.022946 25580 solver.cpp:406]     Test net output #0: accuracy = 0.1008 # 正解率(段々あがっていく)
  I0224 16:35:46.023041 25580 solver.cpp:406]     Test net output #1: loss = 2.30249 (* 1 = 2.30249 loss) # 予測と正解が一致すると小さくなる(段々減っていく)
  .
  .
  .
  I0224 16:40:10.214978 25580 solver.cpp:318] Iteration 4000, loss = 0.597383
  I0224 16:40:10.215018 25580 solver.cpp:338] Iteration 4000, Testing net (#0)
  I0224 16:40:12.848549 25580 solver.cpp:406]     Test net output #0: accuracy = 0.7091 # 正答率 70%くらい
  I0224 16:40:12.848611 25580 solver.cpp:406]     Test net output #1: loss = 0.863376 (* 1 = 0.863376 loss)
  I0224 16:40:12.848623 25580 solver.cpp:323] Optimization Done.
  I0224 16:40:12.848631 25580 caffe.cpp:222] Optimization Done.
 
$ ls -ltr example/cifar10/ | tail -2
  # => 学習データはファイルとして記録される(サーバとかrebootしても再学習しなくてよい)
  # => ファイル名は[prefix+何回目の学習か+拡張子]
  -rw-r--r-- 1 600032 Feb 24 17:07 cifar10_quick_iter_4000.caffemodel.h5
  -rw-r--r-- 1 590064 Feb 24 17:07 cifar10_quick_iter_4000.solverstate.h5 



ちなみに学習プロセスのログは/tmpディレクトリに保存されるのであとでも確認できる

$ ls -trl /tmp | tail -2

  lrwxrwxrwx 1    59 Feb 24 17:03 caffe.INFO -> caffe.titanx2.log.INFO.20160224-170316.25715
  -rw-r--r-- 1 37620 Feb 24 17:07 caffe.titanx2.log.INFO.20160224-170316.25715
 
$ cat /tmp/caffe.INFO | grep "accuracy = " | tail
  I0224 17:03:20.035547 25715 solver.cpp:406]     Test net output #0: accuracy = 0.0882
  I0224 17:03:51.368964 25715 solver.cpp:406]     Test net output #0: accuracy = 0.5498
  I0224 17:04:22.786237 25715 solver.cpp:406]     Test net output #0: accuracy = 0.6397
  I0224 17:04:54.381880 25715 solver.cpp:406]     Test net output #0: accuracy = 0.6787
  I0224 17:05:26.168318 25715 solver.cpp:406]     Test net output #0: accuracy = 0.6914
  I0224 17:05:58.029224 25715 solver.cpp:406]     Test net output #0: accuracy = 0.6905
  I0224 17:06:29.850028 25715 solver.cpp:406]     Test net output #0: accuracy = 0.693
  I0224 17:07:01.673884 25715 solver.cpp:406]     Test net output #0: accuracy = 0.6881
  I0224 17:07:33.700891 25715 solver.cpp:406]     Test net output #0: accuracy = 0.6933
4. 検証

作成したcaffeの学習モデルを利用して検証を行う

$ wget https://upload.wikimedia.org/wikipedia/commons/0/0c/Green_flog_lithobates_clamitans.jpg
  # => google画像検索で適当に再利用許可の画像を検索してダウンロード
 
$ python cifar10_classifier.py Green_flog_lithobates_clamitans.jpg
  .
  .
  [[  1.49867356e-05   1.31496192e-09   3.60405451e-04   7.75128836e-04
      1.04495348e-05   1.60585751e-05   9.98820961e-01   2.31061165e-07
      1.51424626e-06   2.83154435e-07]]
    # => 各カテゴリ毎の類似度を示す

  6:frog
    # 6番目のfrogである確率が99.8%以上と一番高い
  • CIFAR10のカテゴリは以下の通りです

f:id:portaltan:20160225094053p:plain

    • 0: "airplane"
    • 1: "automobile"
    • 2: "bird"
    • 3: "cat"
    • 4: "deer"
    • 5: "dog"
    • 6: "frog"
    • 7: "hourse"
    • 8: "ship"
    • 9: "truck"
  • 検証に利用した画像は以下です

f:id:portaltan:20160225093514p:plain

  • cifar10_classifier.pyのスクリプトはスライドをご確認ください