教師あり学習の基本

333
NO IMAGE

以前、社内教育の一環として機械学習コンテストを行いました。
機械学習に馴染みのないメンバーもいるため、クラスを未経験者クラス、経験者クラスの二つに分けました。

未経験者クラスのお題は、教師あり学習の基本中の基本であるアヤメの分類です。
アヤメの分類とは、アヤメの4つの特徴量(がく片の長さ、がく片の幅、花びらの長さ、花びらの幅)から、アヤメの品種(Setosa、Versicolour、Virginica)を推定する課題です。

今回は、このアヤメの分類のお題を使って教師あり学習を行います。

教師あり学習

教師あり学習では、データの持つ特徴量と正解の組を入力として学習を行います。
教師あり学習には、分類と回帰の2つがありますが、今回は分類を扱います。
分類を簡単にイメージで説明すると、下図のようなデータを分類する境界線(通常データは2次元以上なので線ではなく超平面)の学習です。

この境界線を学習することにより、新規のデータに対してその分類を推論することができるようになります。

classify.png

このようなデータの分類を行うアルゴリズムには色々なものがありますが、今回は代表的なものであるサポートベクトルマシン(SVM)を使います。

プログラム

データの作成

まずはじめに、UCI Machine Learning Repositoryから、Iris Data Setをダウンロードします。

ダウンロードしたファイルは、以下のような150行のCSVファイルで、各列が

  • がく片の長さ
  • がく片の幅
  • 花びらの長さ
  • 花びらの幅
  • 品種
    となっています。
// data.csv

5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
...
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
...
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica

このデータから、

  • 学習に使う教師データ
  • 検証に使うテストデータ
  • 検証に使う正解データ
    の3つのファイルを生成します。

データの作成は、scikit-learnのtrain_test_splitを使えば簡単にできるのですが、大したプログラムではないので今回は自作しました。

# create_training_data.py

import sys
import csv
import random
import argparse

def main(args):
    datas = _read_input(args.input_file)
    random.shuffle(datas)
    _create_training_data(datas, args.training_data_cnt,
                        args.training_file, args.test_file, args.answer_file)

def _read_input(path):
    datas = []
    with open(path, 'r') as f:
        reader = csv.reader(f)
        for row in reader:
            datas.append(row)
    return datas

def _create_training_data(datas, train_cnt, path_training, path_test, path_answer):
    # 学習データ
    with open(path_training, 'w') as f:
        writer = csv.writer(f, lineterminator='\n')
        for i in range(train_cnt):
            writer.writerow(datas[i])
    # テストデータ
    with open(path_test, 'w') as f:
        writer = csv.writer(f, lineterminator='\n')
        for i in range(train_cnt, len(datas)):
            writer.writerow(datas[i][:-1])
    # 正解データ
    with open(path_answer, 'w') as f:
        writer = csv.writer(f, lineterminator='\n')
        for i in range(train_cnt, len(datas)):
            writer.writerow(datas[i][-1:])

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='create data')
    parser.add_argument('-i', '--input_file', required=True,
                        help='input data path')
    parser.add_argument('-t', '--training_file', required=True,
                        help='training data path')
    parser.add_argument('-e', '--test_file', required=True,
                        help='test data path')
    parser.add_argument('-a', '--answer_file', required=True,
                        help='answer data path')
    parser.add_argument('-n', '--training_data_cnt', required=True,
                        type=int,
                        help='training data count')
    args = parser.parse_args()
    main(args)

学習と検証

機械学習のためのライブラリは多く存在していますが、今回はその中の一つであるscikit-learnを使います。

使い方は非常に簡単で、学習するときは

clf = svm.SVC(kernel='linear')
clf.fit(x_train, y_train)

推論するときは

pred = clf.predict(x_test)

とするだけです。

実際には様々なパラメーターを指定できますが、今回は簡単な課題なので、ほぼデフォルト値で実行します。パラメーターの詳細はこちらで確認できます。

学習、検証プログラムの全体はこのようになります。

# train_test.py

from sklearn import svm

import csv

def main(args):
    # 特徴ベクトルと正解ラベル
    x_train, y_train = _load_train(args.training_file)
    # 線形なSVMによる分類器(ソフトマージンにおける定数はC=1.)
    clf = svm.SVC(kernel='linear')
    # 訓練データによる学習(超平面の決定)
    clf.fit(x_train, y_train)
    # テストデータの分類を推論
    x_test = _load_test(args.test_file)
    pred = clf.predict(x_test)
    # 結果を出力
    with open(args.result_file, "w") as f:
        for x in pred:
            f.write("{}\n".format(x))

def _load_train(path):
    attrs = []
    classes = []
    with open(path, 'r') as f:
        reader = csv.reader(f)
        for row in reader:
            attrs.append(row[:-1])
            classes += row[-1:]
    return attrs, classes

def _load_test(path):
    attrs = []
    with open(path, "r") as f:
        reader = csv.reader(f)
        for row in reader:
            attrs.append(row)
    return attrs

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='training and test')
    parser.add_argument('-t', '--training_file', required=True,
                        help='training data path')
    parser.add_argument('-e', '--test_file', required=True,
                        help='test data path')
    parser.add_argument('-r', '--result_file', required=True,
                        help='result data path')
    args = parser.parse_args()
    main(args)

結果

コンテストでは、アルゴリズムの指定は特にしなかったのですが、ほぼ全員がSVMを使って実装をし、簡単なデータセットだったこともあり正解率も100%でした(1名だけk-meansを使って実装していました)。

ちなみに経験者クラスには、被験者のもつ279個の属性(年齢、性別、身長、体重、心電図から得られる様々な特徴)を元に、その人が不整脈(14種類)、不整脈ではない、未分類のいずれかを推定するお題を出しましたが、こちらはデータが複雑なため100%の正解率には至りませんでした。

UCI Machine Learning Repositoryには様々なデータセットがあるので、いろいろと遊んでみるのも面白いかもしれません。