[deep learning] caffeとCIFAR-10を使って画像判別テスト(詳細)
昨日のポストの詳細版です。
portaltan.hatenablog.com
こちらではcifarが用意してくれているscriptで一気にやってくれているところがおおかったので、そこらへんをひとつひとつのコマンドにわけて実行していきます。
また、今回は自分のhomeディレクトリでプロジェクトを作成します
1. 画像の用意(CIFAR-10のダウンロード)
学習に使うための教師データを作るには、その元となる画像データが必要です
この画像データを作成するのがdeep learningではとても難しいのですが、今回はdeep learning用に10種類のカテゴリに分類された画像集のCIFAR-10を利用します
$ cd ~ $ mkdir test_cifar $ wget --no-check-certificate http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz $ tar -xf cifar-10-binary.tar.gz $ ll drwxr-xr-x 2 4096 Jun 5 2009 cifar-10-batches-bin -rw-r--r-- 1 170052171 Jun 5 2009 cifar-10-binary.tar.gz $ rm -i cifar-10-binary.tar.gz # もういらないので $ ll cifar-10-batches-bin/ total 180080 -rw-r--r-- 1 61 Jun 5 2009 batches.meta.txt -rw-r--r-- 1 30730000 Jun 5 2009 data_batch_1.bin -rw-r--r-- 1 30730000 Jun 5 2009 data_batch_2.bin -rw-r--r-- 1 30730000 Jun 5 2009 data_batch_3.bin -rw-r--r-- 1 30730000 Jun 5 2009 data_batch_4.bin -rw-r--r-- 1 30730000 Jun 5 2009 data_batch_5.bin -rw-r--r-- 1 88 Jun 5 2009 readme.html -rw-r--r-- 1 30730000 Jun 5 2009 test_batch.bin
- データはjpg等ではなくbinaryとしてダウンロードされます
2. 画像の加工
CIFAR-10のデータをcaffeに利用できるように変換します
訓練データとテストデータの作成
コマンドusage
### convert_cifar_data.bin $DATA $EXAMPLE $DBTYPE # $DATA: dataファイルのあるディレクトリ # $EXAMPLE: 教師データ出力先 # $DBTYPE: lmdbかleveldbを選択
実際の手順
$ ${CAFFE_ROOT}/build/examples/cifar10/convert_cifar_data.bin cifar-10-batches-bin ./ lmdb $ ll drwxr-xr-x 2 4096 Jun 5 2009 cifar-10-batches-bin drwxr--r-- 2 4096 Feb 26 17:50 cifar10_test_lmdb drwxr--r-- 2 4096 Feb 26 17:50 cifar10_train_lmdb $ rm -rf cifar-10-batches-bin # もういらないので
- cifar10_train_lmdb(訓練用画像データ)が作成されました
- cifar10_test_lmdb(テスト用画像データ)が作成されました
平均画像の作成
コマンドusage
### compute_image_mean [FLAGS] INPUT_DB [OUTPUT_FILE] # $INPUT_DB: 訓練データ # $OUTPUT_FILE: 平均画像出力先
実際の手順
$ ${CAFFE_ROOT}/build/tools/compute_image_mean -backend=lmdb cifar10_train_lmdb mean.binaryproto $|| drwxr--r-- 2 4096 Feb 26 17:50 cifar10_test_lmdb drwxr--r-- 2 4096 Feb 26 17:50 cifar10_train_lmdb -rw-r--r-- 1 12299 Feb 26 17:54 mean.binaryproto
- mean.binaryproto(平均画像データ)が作成されました
3. 学習
学習パラメータ設定ファイルの作成
$ cp -ip ${CAFFE_ROOT}/examples/cifar10/cifar10_quick_solver.prototxt solver.prototxt
- 今回はデフォルトのものを使用する
- ファイルを開いてパスとかをカレントに変更してください
ネットワーク定義ファイルの作成
- 今回はデフォルトのものを使用する
$ cp -ip ${CAFFE_ROOT}/examples/cifar10/cifar10_quick_train_test.prototxt train_test.prototxt
- 今回はデフォルトのものを使用する
- ファイルを開いてパスとかをカレントに変更してください
学習
コマンドusage
usage: caffe <command> <args> commands: train train or finetune a model test score a model device_query show GPU diagnostic information time benchmark model execution time
- caffeファイル
- =>(利用)学習パラメーター設定ファイル(solver.prototxt)
- =>(利用)ネットワーク定義ファイル(train_test.prototxt)
- =>(利用)平均画像(mean.binaryproto)
- =>(利用)訓練データ(***_train_lmdb)
- =>(利用)検証データ(***_test_lmdb)
- =>(利用)ネットワーク定義ファイル(train_test.prototxt)
- =>(利用)学習パラメーター設定ファイル(solver.prototxt)
実際の手順
${CAFFE_ROOT}/build/tools/caffe train --solver solver.prototxt ### GPU環境で数分かかります ### $ ll -rw-r--r-- 1 600032 Feb 26 18:11 cifar10_quick_iter_4000.caffemodel.h5 # 学習モデル -rw-r--r-- 1 590064 Feb 26 18:11 cifar10_quick_iter_4000.solverstate.h5 drwxr--r-- 2 4096 Feb 26 17:50 cifar10_test_lmdb drwxr--r-- 2 4096 Feb 26 17:50 cifar10_train_lmdb -rw-r--r-- 1 12299 Feb 26 17:54 mean.binaryproto -rwxrwxr-x 1 833 Feb 26 18:00 solver.prototxt -rwxrwxr-x 1 3020 Feb 26 18:03 train_test.prototxt $ rm -rf cifar10_train_lmdb # もういらないので $ rm -rf cifar10_test_lmdb # もういらないので
- caffemodel(学習データ)が作成されました
4. 学習データの利用
判別したい画像ファイルの用意
$ wget -O flog.jpg https://upload.wikimedia.org/wikipedia/commons/0/0c/Green_flog_lithobates_clamitans.jpg
- 適当に探したものをダウンロードしてきます
ネットワーク定義のコピー
$ cp -ip ${CAFFE_ROOT}/examples/cifar10/cifar10_quick.prototxt .
判別プログラムの作成
vi classify.py ============================================= #!/usr/bin/env python import sys import caffe from caffe.proto import caffe_pb2 import numpy # for array net_path = 'cifar10_quick.prototxt' model_path = 'cifar10_quick_iter_4000.caffemodel.h' mean_path = 'mean.binaryproto' cifar_map = { 0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "hourse", 8: "ship", 9: "truck" } mean_blob = caffe_pb2.BlobProto() with open(mean_path) as f: mean_blob.ParseFromString(f.read()) mean_array = numpy.asarray( mean_blob.data, dtype=numpy.float32 ).reshape( (mean_blob.channels, mean_blob.height, mean_blob.width) ) classifier = caffe.Classifier( net_path, model_path, mean=mean_array, raw_scale=255) # sys.argv[0] is script name # sys.argv[1] is image file image = caffe.io.load_image(sys.argv[1]) # predict(target_image, oversample=True|False) # oversample's default value is True predictions = classifier.predict([image], oversample=False) answer = numpy.argmax(predictions) # get max value's index # RESULT print("====================================") print("possibility of each categoly") for index, prediction in enumerate(predictions[0]): print (str(index)+"("+cifar_map[index]+"): ").ljust(15) + str(prediction) print("====================================") print("I guess this image is [" + cifar_map[answer] + "]" =============================================
実行
$ python classify.py cat.jpg ・ ・ ・ ==================================== possibility of each categoly 0(airplane): 6.39476e-06 1(automobile): 6.76981e-05 2(bird): 0.0031082 3(cat): 0.63886 4(deer): 0.0967892 5(dog): 0.135272 6(frog): 0.0638987 7(hourse): 0.0619933 8(ship): 1.39324e-06 9(truck): 3.50305e-06 ==================================== I guess this image is [cat]
正しく猫だと分類されました
ちなみにこれは犬に分類されました
]
==================================== possibility of each categoly 0(airplane): 1.28793e-08 1(automobile): 2.60279e-09 2(bird): 0.000465949 3(cat): 0.0388079 4(deer): 0.00235302 5(dog): 0.956003 6(frog): 0.00146607 7(hourse): 0.000904057 8(ship): 2.6331e-09 9(truck): 6.08218e-09 ==================================== I guess this image is [dog]
こちらは猫の確率が69%。でも犬要素も27%と結構高めになってます
==================================== possibility of each categoly 0(airplane): 1.71014e-06 1(automobile): 3.18347e-06 2(bird): 1.51628e-06 3(cat): 0.693822 4(deer): 0.000347697 5(dog): 0.276983 6(frog): 0.00078709 7(hourse): 0.0280534 8(ship): 3.01918e-08 9(truck): 3.12128e-07 ==================================== I guess this image is [cat]
以上です