joppot

コピペで絶対動く。説明を妥協しない

プログラミング

scikit-learnの4つの関数で機械学習などのデータを前処理する

投稿日:

Pocket

概要

皆んさんこんにちはcandleです。今回はpythonの機械学習ライブラリ『scikit-learn』を使い、データの前処理をしてみます。
scikit-learnでは変換器と呼ばれるものを使い、入力されたデータセットをfit_transform()メソッドで変換することができます。
変換器はたくさんあるので、機械学習でよく使われる以下の4つの変換器を紹介します。

Imputer
StandardScaler
MinMaxScaler
OneHotEncorder


前提

Python3
scikit-learn 0.19.1

サンプルコードを動かす場合はこれとは別にnumpyが必要です。


Imputer

http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.Imputer.html#sklearn.preprocessing.Imputer

imuterはデータの中に含まれる欠損値(None)を指定した他の値に置換します。
引数には以下の値がデフォルトで設定されています。

Imputer(missing_values=’NaN’, strategy=’mean’, axis=0, verbose=0, copy=True)

missing_values はfloat型で、指定した値に該当するデータ内のすべての値を置換します。
Noneではない他の実数を置換したい場合に使います。
strategy はstr型で、mean (平均値)、median(中央値), mode(最頻値)を設定します。
Axis はint型で、0を指定すると列(縦)の平均値、1を指定すると行(横)の平均値で置換します。

試しに、適当なフォルダで、imputer_test.pyを作ります。

touch imputer_test.py

以下を書き込みましょう。
サンプルコード

from sklearn.preprocessing import Imputer
import numpy as np
data = np.array([[7, 2, 3],
                 [8, None, 3],
                 [3, 8, 5]])
imputer = Imputer()
new_data = imputer.fit_transform(data)
print(new_data)

実行します。

python3 imputer_test.py
[[ 7.  2.  3.]
 [ 8.  5.  3.]
 [ 3.  8.  5.]]

Noneであったところが5に置換されています。

StandardScaler

http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html#sklearn.preprocessing.StandardScaler

データを標準化します。引数には以下の値がデフォルトで設定されています。

StandardScaler(copy=True, with_mean=True, with_std=True)

ファイルを作成します。

touch ss.py

サンプルコード

from sklearn.preprocessing import StandardScaler
import numpy as np
data = np.array([[7., 2., 3.],
                 [8., 5., 3.],
                 [3., 8., 5.]])
standard_scaler = StandardScaler()
new_data = standard_scaler.fit_transform(data)
print(new_data)

実行する。

python3 ss.py

[[ 0.46291005 -1.22474487 -0.70710678]
 [ 0.9258201   0.         -0.70710678]
 [-1.38873015  1.22474487  1.41421356]]

MinMaxScaler

http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MinMaxScaler.html#sklearn.preprocessing.MinMaxScaler

指定した範囲にデータをマッピングします。
引数には以下の値がデフォルトで設定されています。

MinMaxScaler(feature_range=(0, 1), copy=True)

feature_rangeはタプルで、(最小値、最大値)のように指定します。
デフォルト値では0から1の間にマッピングされます。
ファイルを作成します。

touch mms.py

以下を記述します。

from sklearn.preprocessing import MinMaxScaler
import numpy as np
data = np.array([[0., 2.],
                 [3., 4.],
                 [10., 7.]])
standard_scaler = MinMaxScaler(feature_range=(0, 1))
new_data = standard_scaler.fit_transform(data)
print(new_data)

実行する。

python3 mms.py

[[ 0.   0. ]
 [ 0.3  0.4]
 [ 1.   1. ]]

実装コードの出力結果から見てわかるように、変換器の入力が2次元配列の場合、各列(axis=0)に対してそれぞれマッピングが行われます。

OneHotEncorder

http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html#sklearn.preprocessing.OneHotEncoder

整数値のラベルをone-hotのラベルに変換します。

OneHotEncoder(n_values=’auto’, categorical_features=’all’, dtype=<class ‘numpy.float64’>, sparse=True, handle_unknown=’error’)

ファイルを作成します。

touch ohe.py

以下を記述します。

サンプルコード

from sklearn.preprocessing import OneHotEncoder
import numpy as np
data = np.array([0, 2, 1, 1]).reshape(-1, 1)
one_hot = OneHotEncoder()
new_data = one_hot.fit_transform(data).toarray()
print(new_data)

実行する

python3 ohe.py

[[ 1.  0.  0.]
 [ 0.  0.  1.]
 [ 0.  1.  0.]
 [ 0.  1.  0.]]

おまけ

ラベルが文字列の場合、pandasのSeriesのメソッドfactorize()を使うと、整数値のラベル変換できます。

import pandas as pd
data = pd.Series(["apple", "orange", "banana", "banana"])
new_data, _ = data.factorize()
print(new_data)
>>
[0 1 2 2]

まとめ

前処理は機械学習を行う前に一手間かけるだけで、学習性能をあげることが期待できます。例えばMNIST(0~9までの数字が書かれた画像のデータセット)を10個のクラスへ分類する問題では、StandardScaler変換器を使用して標準化するだけで、精度を80数パーセントから90パーセント代まで引き上げることができました。ぜひ活用して見てください。

スポンサードリンク

「為になったなぁ」と思ったら、シェアお願いします。

-プログラミング
-

執筆者:


comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

関連記事

CakePHPでhelloworld

概要 CakePHPでプログラミングのお約束helloworldを行いましょう。 helloworldとは動作確認も含めた、一番最初に書くプログラムコードです。 だいたいはhello worldと単純 …

phpとmysqlでアカウント管理する時のテストユーザーのパスワードの暗号化はmysqlのsha1でもできる

by brewbooks 概要 みなさんこんにちはcandleです。最近はphpでサービスを書いたりしています。その中でテストユーザーのアカウント管理でパスワードを暗号化してデータベースに収めています …

processingで重複しないランダムな数を配列で取得する

概要 みなさんこんにちはcandleです。今回はprocessingで重複しないランダムな数を配列で取得する関数を作成したいとおもいます。 前提 なし

C++のopencvでhelloworld

概要 みなさんこんにちはcandleです。インストールできたopencvを使ってhelloworldを行いましょう。 opencvでhelloworldとはなんぞや、と感じるかもしれませんが、open …

railsのgonで別ページでリロード後turbolinksで移動したら変数がundefindする場合の対処

概要 (追記 2016/05/18 この方法を行うと、turbolinksで問題が起きました。 turbolinksで移動した回数だけ、javascriptが実行されてしまいました。 例えば、 < % ...

プロフィール


ベンチャー企業のCTOをやってます。大学時代にプログラミングを始め、javaから入門し、C++へて、PHPへと進み、会社ではRailsを使用。自動化が大好きなプログラマー

スポンサードリンク

アーカイブ