hellkite 日記と雑記とメモ。

Shiki Kazamaの駄文と音楽と、時々技術な感じ

Chainerで画像カテゴリ分類(CIFAR-10を使った学習)


スポンサーリンク


ミスがあったので更新しました。 2016/10/25

すでにやりつくされている感がありますが、意外と学習から任意画像での評価までやっているサイトがなかったのでまとめておきます。学習させて用意されているテストデータを食べさせて何%の精度だったかより、ソリューション開発側としては、アプリに組み込んで動くかどうかの方が重要なので。

CIFAR-10データの準備と学習

というわけで、まずは、ChainerでCIFAR-10を学習してみます。基本は以下のサイトを参考にしました。
ai-programming.hatenablog.jp

ほぼそのままですが、少し修正して試してみます。

データの準備

まず、CIFAR-10を読み込みます。こちらは上のサイトと同じコードになりますが、読み込んだデータを画像として取り出してみました。画像として取り出すための処理は以下を参考にしました。
qiita.com

コードは以下の通りです。

import sklearn.datasets
import matplotlib
import matplotlib.pyplot as plt

import sys
import pickle
import os.path

def unpickle(file):
    fp = open(file, 'rb')
    data = pickle.load(fp, encoding='latin-1')
    fp.close()
    
    return data

X_train = None
y_train = []

for i in range(1,6):
    data_dic = unpickle(os.getcwd()+"\\data\\cifar-10\\data_batch_{0}".format(i))
    if i == 1:
        X_train = data_dic['data']
    else:
        X_train = np.vstack((X_train, data_dic['data']))
    y_train += data_dic['labels']
    
test_data_dic = unpickle(os.getcwd()+"\\data\\cifar-10\\test_batch")

# 評価用データ
X_test = test_data_dic['data']
X_test = X_test.reshape(len(X_test), 3, 32, 32)
y_test = np.array(test_data_dic['labels'])
X_train = X_train.reshape(len(X_train), 3, 32, 32)
y_train = np.array(y_train)

# ラベル名
batch_meta = unpickle(os.getcwd()+"\\data\\cifar-10\\batches.meta")
label_names = batch_meta['label_names']
label_names

# indexに入っている画像を描画(3×32×32のデータを32×32×3に変更)
index = 220
img = np.rollaxis(test_data_dic['data'][index].reshape((3,32,32)), 0, 3)
plt.title(label_names[y_test[index]])
plt.imshow(img, interpolation='none')

画像の出力結果はこんな感じ。

f:id:deis:20160923123641p:plain

学習

モデルと学習のコードは以下の通りです。ちょっとだけ修正しています。
話が逸れますがtqdmってライブラリがあるんですね。進捗が分かるようになっていたので感動しました。

import chainer
from chainer import cuda, Function, FunctionSet,gradient_check, Variable, optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
import time

plt.style.use('ggplot')

# ニューラルネットワークの順伝播を記述
class Model(Chain):
    def __init__(self):
        super(Model, self).__init__(
            conv1=F.Convolution2D(3, 32, 3, pad=1),
            conv2=F.Convolution2D(32, 32, 3, pad=1),
            conv3=F.Convolution2D(32, 32, 3, pad=1),
            conv4=F.Convolution2D(32, 32, 3, pad=1),
            conv5=F.Convolution2D(32, 32, 3, pad=1),
            conv6=F.Convolution2D(32, 32, 3, pad=1),
            l1=F.Linear(512, 512),
            l2=F.Linear(512, 10)
        )
    def __call__(self, x):
        #x = chainer.Variable(x_data)
        h1 = F.relu(self.conv1(x))
        h2 = F.max_pooling_2d(F.relu(self.conv2(h1)), 2)
        h3 = F.relu(self.conv3(h2))
        h4 = F.max_pooling_2d(F.relu(self.conv4(h3)), 2)
        h5 = F.relu(self.conv5(h4))
        h6 = F.max_pooling_2d(F.relu(self.conv6(h5)), 2)
        h7 = F.dropout(F.relu(self.l1(h6)))
        y = self.l2(h7)
        return y

# Classifier Chain
model = L.Classifier(Model())

# Optimizer
optimizer = optimizers.Adam()
optimizer.setup(model)

## 学習
from tqdm import tqdm

train_loss = []
train_acc = []
test_loss = []
test_acc = []

batchsize = 100

N = len(X_train)
N_test = len(X_test)

# 学習ループ
n_epoch = 30
for epoch in range(1, n_epoch+1):
    print("eopch", epoch)
    
    # training
    perm = np.random.permutation(N)
    sum_accuracy = 0
    sum_loss = 0
    # 0~Nまでのデータをバッチサイズごとに使って学習
    for i in tqdm(range(0, N, batchsize)):
        x = Variable(X_train[perm[i:i + batchsize]])
        t = Variable(y_train[perm[i:i + batchsize]])
        
        optimizer.update(model, x, t)
        
        sum_loss += float(model.loss.data) * len(t.data)
        sum_accuracy += float(model.accuracy.data) * len(t.data)

        train_loss.append(float(model.loss.data) * len(t.data))
        train_acc.append(float(model.accuracy.data) * len(t.data))
        
    # 訓練データの誤差と正解精度を表示
    print("train mean loss={0}, accuracy={1}".format(sum_loss / N, sum_accuracy / N))
    
    # 評価
    # テストデータの誤差と正解精度を算出し、汎化性能を確認
    sum_accuracy = 0
    sum_loss = 0
    for i in tqdm(range(0, N_test, batchsize)):
        x_batch = Variable(X_test[i:i+batchsize])
        y_batch = Variable(y_test[i:i+batchsize])

        loss = model(x, t)
        
        sum_loss += float(loss.data) * len(t.data)
        sum_accuracy += float(model.accuracy.data) * len(t.data)

        train_loss.append(float(loss.data) * len(t.data))
        train_acc.append(float(model.accuracy.data) * len(t.data))

    # テストデータでの誤差と正解精度を表示
    print("test mean loss={0}, accuracy={1}".format(sum_loss / N_test, sum_accuracy / N_test))

ちなみに、Trainerという機能を使った場合は、以下のコードになります。開発元のPFNの方が、Trainerを使ってください、とカンファレンスで言っていたので、できるだけこちらを利用した方がいいはず。ただ、残念ながら、私の環境では進捗がうまく表示されませんでしたが・・・コードのせいなのかな・・・。

from chainer import report, training, datasets, iterators
from chainer.training import extensions
from chainer.datasets import tuple_dataset

# データセットの準備
train = tuple_dataset.TupleDataset(X_train, y_train)
test = tuple_dataset.TupleDataset(X_test, y_test)

train_iter = iterators.SerialIterator(train, batch_size=100)
test_iter = iterators.SerialIterator(test, batch_size=100, repeat=False, shuffle=False)

updater = training.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (30, 'epoch'), out='result')

# trainerを設定
trainer.extend(extensions.Evaluator(test_iter, model))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())

trainer.run()

学習時のエポック数は30にしました。学習結果は精度が90%程度でした。元サイトの精度グラフの30あたりを確認すると90%となっているので、妥当な値になっているようです。
ちなみに学習時間はCPU利用で6時間程度。このレベルならまだCPUでがんばれます。。

学習ファイルの保存(2016/10/25更新)

以下のコードで学習したモデルを保存します。次からは、この学習データを使って評価することができます。

# 学習データの保存

# 古いやり方らしく、うまく動作しない
# model.to_cpu()
# with open('model.pkl', 'wb') as o:
#     pickle.dump(model, o)

serializers.save_npz('model_cifar10.model', model)
serializers.save_npz('model_cifar10.state', optimizer)

その他

ちなみに学習の最中にMemoryErrorが発生することが何度かありました。そのため、Jupyterを利用して検証しながらやっていましたが、最終的にはコマンドプロンプトで実行しました。それでも発生することがありましたが、エラーの頻度は減ったようです。原因はよくわかりませんが、PCを再起動するとでなくなったりします。

次は、学習ファイルを読み込んで任意の画像ファイルが認識できるか確認してみます。
↓更新しました。
hellkite.hatenablog.com