脳汁portal

アメリカ在住(だった)新米エンジニアがその日学んだIT知識を書き綴るブログ

[deep learning] caffeとCIFAR-10を使って画像判別テスト(詳細)

昨日のポストの詳細版です。
portaltan.hatenablog.com
こちらではcifarが用意してくれているscriptで一気にやってくれているところがおおかったので、そこらへんをひとつひとつのコマンドにわけて実行していきます。
また、今回は自分のhomeディレクトリでプロジェクトを作成します

1. 画像の用意(CIFAR-10のダウンロード)

学習に使うための教師データを作るには、その元となる画像データが必要です
この画像データを作成するのがdeep learningではとても難しいのですが、今回はdeep learning用に10種類のカテゴリに分類された画像集のCIFAR-10を利用します

$ cd ~
$ mkdir test_cifar

$ wget --no-check-certificate http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
$ tar -xf cifar-10-binary.tar.gz
$ ll
drwxr-xr-x 2      4096 Jun  5  2009 cifar-10-batches-bin
-rw-r--r-- 1 170052171 Jun  5  2009 cifar-10-binary.tar.gz
$ rm -i cifar-10-binary.tar.gz # もういらないので
$ ll cifar-10-batches-bin/
total 180080
-rw-r--r-- 1       61 Jun  5  2009 batches.meta.txt
-rw-r--r-- 1 30730000 Jun  5  2009 data_batch_1.bin
-rw-r--r-- 1 30730000 Jun  5  2009 data_batch_2.bin
-rw-r--r-- 1 30730000 Jun  5  2009 data_batch_3.bin
-rw-r--r-- 1 30730000 Jun  5  2009 data_batch_4.bin
-rw-r--r-- 1 30730000 Jun  5  2009 data_batch_5.bin
-rw-r--r-- 1       88 Jun  5  2009 readme.html
-rw-r--r-- 1 30730000 Jun  5  2009 test_batch.bin
  • データはjpg等ではなくbinaryとしてダウンロードされます

2. 画像の加工

CIFAR-10のデータをcaffeに利用できるように変換します

訓練データとテストデータの作成

コマンドusage
### convert_cifar_data.bin $DATA $EXAMPLE $DBTYPE
  # $DATA: dataファイルのあるディレクトリ
  # $EXAMPLE: 教師データ出力先
  # $DBTYPE: lmdbかleveldbを選択
  • 内部で読み込むライブラリが相対パスになっているので、これはコピーせずにcaffeディレクトリのスクリプトを実行します
実際の手順
$ ${CAFFE_ROOT}/build/examples/cifar10/convert_cifar_data.bin cifar-10-batches-bin ./ lmdb
$ ll
drwxr-xr-x 2  4096 Jun  5  2009 cifar-10-batches-bin
drwxr--r-- 2  4096 Feb 26 17:50 cifar10_test_lmdb
drwxr--r-- 2  4096 Feb 26 17:50 cifar10_train_lmdb

$ rm -rf cifar-10-batches-bin # もういらないので
  • cifar10_train_lmdb(訓練用画像データ)が作成されました
  • cifar10_test_lmdb(テスト用画像データ)が作成されました

平均画像の作成

コマンドusage
### compute_image_mean [FLAGS] INPUT_DB [OUTPUT_FILE]
  # $INPUT_DB: 訓練データ
  # $OUTPUT_FILE: 平均画像出力先
実際の手順
$ ${CAFFE_ROOT}/build/tools/compute_image_mean -backend=lmdb cifar10_train_lmdb mean.binaryproto
$||
drwxr--r-- 2  4096 Feb 26 17:50 cifar10_test_lmdb
drwxr--r-- 2  4096 Feb 26 17:50 cifar10_train_lmdb
-rw-r--r-- 1 12299 Feb 26 17:54 mean.binaryproto
  • mean.binaryproto(平均画像データ)が作成されました

3. 学習

学習パラメータ設定ファイルの作成

$ cp -ip ${CAFFE_ROOT}/examples/cifar10/cifar10_quick_solver.prototxt solver.prototxt
  • 今回はデフォルトのものを使用する
  • ファイルを開いてパスとかをカレントに変更してください

ネットワーク定義ファイルの作成

  • 今回はデフォルトのものを使用する
$ cp -ip ${CAFFE_ROOT}/examples/cifar10/cifar10_quick_train_test.prototxt train_test.prototxt
  • 今回はデフォルトのものを使用する
  • ファイルを開いてパスとかをカレントに変更してください

学習

コマンドusage
usage: caffe <command> <args>
commands:
  train           train or finetune a model
  test            score a model
  device_query    show GPU diagnostic information
  time            benchmark model execution time
  • caffeファイル
    • =>(利用)学習パラメーター設定ファイル(solver.prototxt)
      • =>(利用)ネットワーク定義ファイル(train_test.prototxt)
        • =>(利用)平均画像(mean.binaryproto)
        • =>(利用)訓練データ(***_train_lmdb)
        • =>(利用)検証データ(***_test_lmdb)
実際の手順
${CAFFE_ROOT}/build/tools/caffe train --solver solver.prototxt
### GPU環境で数分かかります ###
$ ll
-rw-r--r-- 1 600032 Feb 26 18:11 cifar10_quick_iter_4000.caffemodel.h5  # 学習モデル
-rw-r--r-- 1 590064 Feb 26 18:11 cifar10_quick_iter_4000.solverstate.h5
drwxr--r-- 2   4096 Feb 26 17:50 cifar10_test_lmdb
drwxr--r-- 2   4096 Feb 26 17:50 cifar10_train_lmdb
-rw-r--r-- 1  12299 Feb 26 17:54 mean.binaryproto
-rwxrwxr-x 1    833 Feb 26 18:00 solver.prototxt
-rwxrwxr-x 1   3020 Feb 26 18:03 train_test.prototxt

$ rm -rf cifar10_train_lmdb # もういらないので
$ rm -rf cifar10_test_lmdb  # もういらないので
  • caffemodel(学習データ)が作成されました

4. 学習データの利用

判別したい画像ファイルの用意
$ wget -O flog.jpg https://upload.wikimedia.org/wikipedia/commons/0/0c/Green_flog_lithobates_clamitans.jpg
  • 適当に探したものをダウンロードしてきます
ネットワーク定義のコピー
$ cp -ip ${CAFFE_ROOT}/examples/cifar10/cifar10_quick.prototxt .
判別プログラムの作成
vi classify.py
=============================================
#!/usr/bin/env python

import sys
import caffe
from caffe.proto import caffe_pb2
import numpy # for array

net_path = 'cifar10_quick.prototxt'
model_path = 'cifar10_quick_iter_4000.caffemodel.h'
mean_path = 'mean.binaryproto'

cifar_map = {
        0: "airplane",
        1: "automobile",
        2: "bird",
        3: "cat",
        4: "deer",
        5: "dog",
        6: "frog",
        7: "hourse",
        8: "ship",
        9: "truck"
}


mean_blob = caffe_pb2.BlobProto()

with open(mean_path) as f:
    mean_blob.ParseFromString(f.read())

mean_array = numpy.asarray(
    mean_blob.data,
    dtype=numpy.float32
).reshape(
    (mean_blob.channels, mean_blob.height, mean_blob.width)
)

classifier = caffe.Classifier(
    net_path,
    model_path,
    mean=mean_array,
    raw_scale=255)

# sys.argv[0] is script name
# sys.argv[1] is image file
image = caffe.io.load_image(sys.argv[1])

# predict(target_image, oversample=True|False)
# oversample's default value is True
predictions = classifier.predict([image], oversample=False)
answer = numpy.argmax(predictions) # get max value's index

# RESULT
print("====================================")
print("possibility of each categoly")
for index, prediction in enumerate(predictions[0]):
    print (str(index)+"("+cifar_map[index]+"): ").ljust(15) + str(prediction)
print("====================================")
print("I guess this image is [" + cifar_map[answer] + "]"
=============================================
実行

f:id:portaltan:20160301113811j:plain:w200

$ python classify.py cat.jpg
・
・
・
====================================
possibility of each categoly
0(airplane):   6.39476e-06
1(automobile): 6.76981e-05
2(bird):       0.0031082
3(cat):        0.63886
4(deer):       0.0967892
5(dog):        0.135272
6(frog):       0.0638987
7(hourse):     0.0619933
8(ship):       1.39324e-06
9(truck):      3.50305e-06
====================================
I guess this image is [cat]

正しく猫だと分類されました

ちなみにこれは犬に分類されました
f:id:portaltan:20160301113929j:plain:w200]

====================================
possibility of each categoly
0(airplane):   1.28793e-08
1(automobile): 2.60279e-09
2(bird):       0.000465949
3(cat):        0.0388079
4(deer):       0.00235302
5(dog):        0.956003
6(frog):       0.00146607
7(hourse):     0.000904057
8(ship):       2.6331e-09
9(truck):      6.08218e-09
====================================
I guess this image is [dog]


こちらは猫の確率が69%。でも犬要素も27%と結構高めになってます
f:id:portaltan:20160301114026j:plain:w200

====================================
possibility of each categoly
0(airplane):   1.71014e-06
1(automobile): 3.18347e-06
2(bird):       1.51628e-06
3(cat):        0.693822
4(deer):       0.000347697
5(dog):        0.276983
6(frog):       0.00078709
7(hourse):     0.0280534
8(ship):       3.01918e-08
9(truck):      3.12128e-07
====================================
I guess this image is [cat]

以上です