機械学習 実践(ハイパーパラメータ調整)
本章では、モデルの予測精度を向上させるための重要な概念となるハイパーパラメータについて基礎的な概要から、具体的な調整方法と調整後の検証方法までの一連手順を解説します。
本章の構成
- ハイパーパラメータの概要と交差検証
- ハイパーパラメータの調整方法
ハイパーパラメータの概要と交差検証
ハイパーパラメータとは
前章までにも何度か単語としては紹介していたハイパーパラメータですが、具体的にどのようなモノを指しているのか説明していきます。
最初にパラメータとの違いを把握しておきましょう。パラメータとは、モデルの学習実行後に獲得される値を指しており、重みとも呼ばれます。
対してハイパーパラメータは、各アルゴリズムに付随して、アルゴリズムの挙動を制御するための値です。モデルの学習実行前にハイパーパラメータを調整することでモデルの性能向上や過学習の抑制、効率の良い学習などが期待できます。具体的に例を出すと、ロジスティック回帰においては (コストパラメータ)、 がロジスティック回帰のハイパーパラメータとなります。
K-分割交差検証 (K-fold cross-validation)
前章までは、与えられたデータセットを学習用データセット・テスト用データセットを 2 分割しました。それをホールドアウト法と呼びました。しかし、実際の開発時にはモデルの性能評価をより適切にするためにデータを 3 分割してモデルを評価することが一般的です。
データ名称 | 使用目的 |
---|---|
学習用データセット (train) | モデルを学習させるためのデータセット |
検証用データセット (validation) | ハイパーパラメータの調整が適切なのか検証するためのデータセット |
テスト用データセット (test) | 学習済みモデルの性能を評価するためのデータセット |
学習用データセットと検証用データセットは学習段階で用いられ、テスト用データセットは最終的なモデルの予測精度の確認のためにのみ使用するということを抑えておきましょう。
しかし、十分なデータ量が用意できない場合には 3 分割すると偏りが生じて適切な学習・検証が行われない可能性があります。そのようなデータの偏りを回避する方法として K-分割交差検証 (K-fold cross-validation) があります。
K-分割交差検証は下記の 3 つのステップから構成されており、視覚的に確認すると分かり易いため、図で解説していきます。
- データセットを 個に分割
- 分割したデータの 1 個を検証用データセットとし、残り 個を学習用データセットとして学習を実行
- 各検証の結果を平均して最終的な検証結果とする
それでは、それぞれのステップを確認していきましょう。前提として、K-分割交差検証は学習用データセットと検証用データセットの分割に用いることが多いです。そのため、下記の図ではテスト用データセットは既に別途分割している事とします。
第 1 ステップとして、データセットを 個に分割します。下記の例では分割数 を 5 にしています。
第 2 ステップとして、分割したデータの 1 個を検証用データセットとし、残り 個を学習用データセットとして学習を実行します。
ここで重要なポイントとして 1 回で学習を終わらせず、計 回の学習を行います。その際、既に検証用データセットに使ったデータを次は学習用データセットとして使用し、新たに検証用データセットを選択します。
第 3 ステップとして、各検証の結果を平均して最終的な検証結果とします。
このようにすれば、データに偏りなくハイパーパラメータのチューニングを行うことができます。
ハイパーパラメータの調整方法
ハイパーパラメータの概要と検証方法について理解できたので、続いて具体的な調整方法を見ていきましょう。調整方法については代表的な方法として以下の 4 つを紹介します。
- 手動での調整
- グリッドサーチ (Grid Search)
- ランダムサーチ (Random Search)
- ベイズ最適化 (Bayesian Optimization)
アルゴリズムには決定木を使用し、上のそれぞれの方法でハイパーパラメータの調整を行います。
手動での調整
最初に手動でハイパーパラメータの調整を行い、予測精度にどのような変化があるのかを確認しましょう。
import numpy as np
import pandas as pd
今回は scikit-learn に準備されている乳がんに関するデータセットを使用します。目標値が陰性か陽性かの 2 つの値である二値分類の問題設定になります。
# 乳がんに関するデータセットの読み込み
from sklearn.datasets import load_breast_cancer
dataset = load_breast_cancer()
t = dataset.target
x = dataset.data
x.shape, t.shape
まずは先ほど紹介した通りデータを学習用データセット・検証用データセット・テスト用データセットの 3 つに分割します。手順としては以下の通りです。
- 与えられたデータを「テスト用データセット:その他 = 20 : 80 」に分割
- 「その他」のデータを「検証用データセット:学習用データセット = 30 : 70 」に分割
データセットの分割割合については、データ数に依存することになるので決まりごとはありませんが、上記のような割合で分割することが一般的に多いです。
from sklearn.model_selection import train_test_split
x_train_val, x_test, t_train_val, t_test = train_test_split(x, t, test_size=0.2, random_state=1)
# 検証用データセット:学習用データセット= 30 : 70
x_train, x_val, t_train, t_val = train_test_split(x_train_val, t_train_val, test_size=0.3, random_state=1)
x_train.shape, x_val.shape, x_test.shape
データセットの準備が整ったので決定木の実装を行いましょう。ハイパーパラメータの調整を行わずに、デフォルトで設定されている値を使用して、学習を行い、予測精度を確認します。
from sklearn.tree import DecisionTreeClassifier
dtree = DecisionTreeClassifier(random_state=0)
dtree.fit(x_train, t_train)
print('train score : ', dtree.score(x_train, t_train))
print('validation score : ', dtree.score(x_val, t_val))
学習用データセットに対して 100%、検証用データセットに対して 92% の予測精度が確認できました。学習用データセットに対しての予測精度が高く、検証用データセットに対しては予測精度が低いという、少し過学習の傾向があることがわかります。
過学習を抑制するハイパーパラメータを調整を行い、再度モデルの学習を行いましょう。DecisionTreeClassifier()
メソッドの引数にハイパーパラメータの設定を記述します。
# ハイパーパラメータを設定して、モデルの定義
dtree = DecisionTreeClassifier(max_depth=10, min_samples_split=30, random_state=0)
dtree.fit(x_train, t_train)
print('train score : ', dtree.score(x_train, t_train))
print('validation score : ', dtree.score(x_val, t_val))
ハイパーパラメータの調整によって先ほどとは異なった結果が得られ、検証用データセットに対して 92% → 95.6% といった予測精度の向上が確認できました。テスト用データセットに対しても予測精度を検証してみましょう。
print('test score :', dtree.score(x_test, t_test))
グリッドサーチ
前節では手動で適当にハイパーパラメータの値を決めました。しかし、適当に入れた値が常に最適なハイパーパラメータである可能性は低いと言えるでしょう。最適なハイパーパラメータを獲得するにはある程度の探索(試行錯誤)を行う必要があります。
効率的に最適なハイパーパラメータを探索する方法はいくつかあり、その内の 1 つがグリッドサーチです。
グリッドサーチはまず、ハイパーパラメータを探索する範囲を決めます。例えば下記の図のように決定木の max_depth
と min_samples_split
の値を調整したい場合、5、10、15、20、25 のように範囲をそれぞれ決めます(範囲の指定に特に決まりはありません)。この場合のハイパーパラメータの組み合わせは 5 x 5 = 25 個になります。この 25 個のハイパーパラメータの組み合わせ全てを使用して、学習・検証を行います。そして、その結果から予測精度が最も高いハイパーパラメータを採用します。
しかし、グリッドサーチにはデメリットも存在します。実装方法を確認するために整理しておきましょう。
- メリット:指定した範囲を網羅するため、ある程度漏れがなくハイパーパラメータの探索を行うことができる
- デメリット:場合によっては、数十~数百パターンの組合せを計算するため学習に時間を要する
グリッドサーチの概要が理解できたので、実装を行います。グリッドサーチの実装は scikit-learn の中で準備されている GridSearchCV
クラスを用いて行います。
# GridSearchCV クラスのインポート
from sklearn.model_selection import GridSearchCV
GridSearchCV
クラスの使用には下記の 3 つを準備する必要があります。
- estimator :学習に使用するモデル
- param_grid :ハイパーパラメータを探索する範囲
- cv :K-分割交差検証の の値
まずは estimator
を定義します。estimator
はこれまでモデルの定義で定義していたモデルを指します。
# 学習に使用するアルゴリズムの定義
estimator = DecisionTreeClassifier(random_state=0)
ハイパーパラメータの探索する範囲を指定します。範囲の指定は辞書型で、調整するハイパーパラメータの名前を Key に、リスト型の探索する範囲を Value に格納します。調整するハイパーパラメータの名前を間違うとエラーになるため、確認して名前を記述するようにしましょう。
# 探索するハイパーパラメータと範囲の定義
param_grid = [{
'max_depth': [3, 20, 50],
'min_samples_split': [3, 20, 30]
}]
# データセット分割数を定義
cv = 5
Grid SearchCV
では K-分割交差検証が行われます。そのため、学習用データセットと検証用データセットに分割する前のデータセットである x_train_val
と t_train_val
を使用します。return_train_score=False
を設定することで学習に対する予測精度の検証が行われません。もし、検証を行う際には True
に変更します。False
にするメリットは計算コストを抑えることにあります。
# GridSearchCV クラスを用いたモデルの定義
tuned_model = GridSearchCV(estimator=estimator,
param_grid=param_grid,
cv=cv, return_train_score=False)
GridSearchCV
クラスでも、これまでと同様に fit()
メソッドでモデルの学習を行うことができます。
# モデルの学習&検証
tuned_model.fit(x_train_val, t_train_val)
学習結果は cv_results_
に保持されています。辞書型で格納されているため、pandas.DataFrame
型に変換して確認すると見やすく表示することができます。また、転置すると実行結果が列方向になり見やすくなります。
# 検証結果の確認
pd.DataFrame(tuned_model.cv_results_).T
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | |
---|---|---|---|---|---|---|---|---|---|
mean_fit_time | 0.00451441 | 0.0036581 | 0.00385442 | 0.00448565 | 0.00465679 | 0.0048944 | 0.0047523 | 0.00445013 | 0.00513897 |
std_fit_time | 0.00148618 | 1.10328e-05 | 0.000261475 | 0.00013105 | 0.000651413 | 0.000456741 | 0.00022005 | 0.000304309 | 0.000824373 |
mean_score_time | 0.000360966 | 0.000303745 | 0.000331211 | 0.000352001 | 0.000325346 | 0.000379038 | 0.000391293 | 0.000347948 | 0.000415516 |
std_score_time | 6.73275e-05 | 7.81198e-06 | 5.94357e-05 | 9.04693e-05 | 2.02886e-05 | 8.63773e-06 | 4.49345e-05 | 2.31875e-05 | 7.33988e-05 |
param_max_depth | 3 | 3 | 3 | 20 | 20 | 20 | 50 | 50 | 50 |
param_min_samples_split | 3 | 20 | 30 | 3 | 20 | 30 | 3 | 20 | 30 |
params | {'max_depth': 3, 'min_samples_split': 3} | {'max_depth': 3, 'min_samples_split': 20} | {'max_depth': 3, 'min_samples_split': 30} | {'max_depth': 20, 'min_samples_split': 3} | {'max_depth': 20, 'min_samples_split': 20} | {'max_depth': 20, 'min_samples_split': 30} | {'max_depth': 50, 'min_samples_split': 3} | {'max_depth': 50, 'min_samples_split': 20} | {'max_depth': 50, 'min_samples_split': 30} |
split0_test_score | 0.923077 | 0.912088 | 0.912088 | 0.956044 | 0.912088 | 0.912088 | 0.956044 | 0.912088 | 0.912088 |
split1_test_score | 0.901099 | 0.901099 | 0.901099 | 0.912088 | 0.901099 | 0.901099 | 0.912088 | 0.901099 | 0.901099 |
split2_test_score | 0.934066 | 0.934066 | 0.934066 | 0.923077 | 0.934066 | 0.934066 | 0.923077 | 0.934066 | 0.934066 |
split3_test_score | 0.945055 | 0.934066 | 0.934066 | 0.967033 | 0.945055 | 0.945055 | 0.967033 | 0.945055 | 0.945055 |
split4_test_score | 0.901099 | 0.901099 | 0.901099 | 0.956044 | 0.934066 | 0.901099 | 0.956044 | 0.934066 | 0.901099 |
mean_test_score | 0.920879 | 0.916484 | 0.916484 | 0.942857 | 0.925275 | 0.918681 | 0.942857 | 0.925275 | 0.918681 |
std_test_score | 0.0175824 | 0.0149062 | 0.0149062 | 0.0213085 | 0.0161505 | 0.017855 | 0.0213085 | 0.0161505 | 0.017855 |
rank_test_score | 5 | 8 | 8 | 1 | 3 | 6 | 1 | 3 | 6 |
ハイパーパラメータの種類が 2 つで、各 3 個ずつ値を指定したので 3 × 3 = 9 パターンの計算が行われています。また、 を 5 としたので 5 種類の結果 (split0_test_score ~ split4_test_score
) が出力されています。
それぞれの項目の概要は下記になります。
項目名 | 説明 |
---|---|
mean_fit_time | 学習時間の平均 |
std_fit_time | 学習時間の標準偏差 |
mean_score_time | 検証時間の平均 |
std_score_time | 検証時間の標準偏差 |
param_max_depth | max_depth の値 |
param_min_samples_split | min_samples_split の値 |
params | 調整しているハイパーパラメータの値 |
split0_test_score | 交差検証 1 回目の検証用データセットに対しての予測精度 |
split1_test_score | 交差検証 2 回目の検証用データセットに対しての予測精度 |
split2_test_score | 交差検証 3 回目の検証用データセットに対しての予測精度 |
split3_test_score | 交差検証 4 回目の検証用データセットに対しての予測精度 |
split4_test_score | 交差検証 5 回目の検証用データセットに対しての予測精度 |
mean_test_score | 検証用データセットに対しての予測精度の平均 |
std_test_score | 検証用データセットに対しての予測精度の標準偏差 |
rank_test_score | 検証用データセットに対しての予測精度の順位 |
mean_test_score
の値を確認するとそのモデルの予測精度の確認ができます。基本的にはこの値を確認し、どのハイパーパラメータが効果が強いのかを確認します。
その後、結果を参照して先ほどより狭い範囲でハイパーパラメータを調整します。これを何度か繰り返すことで徐々に予測精度が高くなるハイパーパラメータへと近づけて行きます。
estimator = DecisionTreeClassifier(random_state=0)
cv = 5
param_grid = [{
'max_depth': [5, 10, 15] ,
'min_samples_split': [10, 12, 15]
}]
# モデルの定義
tuned_model = GridSearchCV(estimator=estimator,
param_grid=param_grid,
cv=cv, return_train_score=False)
# モデルの学習
tuned_model.fit(x_train_val, t_train_val)
# 学習結果の確認
pd.DataFrame(tuned_model.cv_results_).T
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | |
---|---|---|---|---|---|---|---|---|---|
mean_fit_time | 0.00617938 | 0.00449882 | 0.00466156 | 0.00495281 | 0.00438828 | 0.00438123 | 0.00576015 | 0.00460849 | 0.00444918 |
std_fit_time | 0.00221583 | 0.000329802 | 0.000582937 | 0.000967085 | 0.000203726 | 0.000205368 | 0.0008076 | 0.0005006 | 0.000127528 |
mean_score_time | 0.000523567 | 0.000335741 | 0.000344658 | 0.000356102 | 0.000305843 | 0.000317764 | 0.000461054 | 0.000330877 | 0.000342655 |
std_score_time | 0.000262405 | 2.71234e-05 | 7.0749e-05 | 7.26917e-05 | 9.91936e-06 | 4.66082e-06 | 0.000106698 | 2.38749e-05 | 2.66996e-05 |
param_max_depth | 5 | 5 | 5 | 10 | 10 | 10 | 15 | 15 | 15 |
param_min_samples_split | 10 | 12 | 15 | 10 | 12 | 15 | 10 | 12 | 15 |
params | {'max_depth': 5, 'min_samples_split': 10} | {'max_depth': 5, 'min_samples_split': 12} | {'max_depth': 5, 'min_samples_split': 15} | {'max_depth': 10, 'min_samples_split': 10} | {'max_depth': 10, 'min_samples_split': 12} | {'max_depth': 10, 'min_samples_split': 15} | {'max_depth': 15, 'min_samples_split': 10} | {'max_depth': 15, 'min_samples_split': 12} | {'max_depth': 15, 'min_samples_split': 15} |
split0_test_score | 0.967033 | 0.923077 | 0.912088 | 0.967033 | 0.923077 | 0.912088 | 0.967033 | 0.923077 | 0.912088 |
split1_test_score | 0.912088 | 0.901099 | 0.901099 | 0.912088 | 0.901099 | 0.901099 | 0.912088 | 0.901099 | 0.901099 |
split2_test_score | 0.923077 | 0.934066 | 0.934066 | 0.923077 | 0.934066 | 0.934066 | 0.923077 | 0.934066 | 0.934066 |
split3_test_score | 0.956044 | 0.956044 | 0.945055 | 0.956044 | 0.956044 | 0.945055 | 0.956044 | 0.956044 | 0.945055 |
split4_test_score | 0.967033 | 0.967033 | 0.934066 | 0.967033 | 0.967033 | 0.934066 | 0.967033 | 0.967033 | 0.934066 |
mean_test_score | 0.945055 | 0.936264 | 0.925275 | 0.945055 | 0.936264 | 0.925275 | 0.945055 | 0.936264 | 0.925275 |
std_test_score | 0.0230507 | 0.0234661 | 0.0161505 | 0.0230507 | 0.0234661 | 0.0161505 | 0.0230507 | 0.0234661 | 0.0161505 |
rank_test_score | 1 | 4 | 7 | 1 | 4 | 7 | 1 | 4 | 7 |
グリッドサーチ 2 回目の結果を確認できました。このように、最初はある程度大きな幅を持ってグリッドサーチを行い、徐々に範囲を狭めてより予測精度の高いハイパーパラメータを探していきます。最後にテストデータを用いて、グリッドサーチで学習させたモデルの予測精度を確認しましょう。
# 最も予測精度の高かったハイパーパラメータの確認
tuned_model.best_params_
best_estimator_
で最も検証用データセットに対しての予測精度が最も高かったハイパーパラメータで学習したモデルを取得することができます。取得したモデルを新たに best_model
という変数に格納します。
# 最も予測精度の高かったモデルの引き継ぎ
best_model = tuned_model.best_estimator_
# モデルの検証
print(best_model.score(x_train_val, t_train_val))
print(best_model.score(x_test, t_test))
手動でハイパーパラメータの調整を行ったモデルのテスト用データセットに対する予測精度より精度が向上していることが確認できます。今回はたまたま精度が向上するハイパーパラメータを発見できただけかもしれませんが、手動で地道に調整するよりは分かりやすく調整することができました。
ランダムサーチ
グリッドサーチの 1 つの欠点として、グリッド上にしか探索できないという点にあります。
そこで、ランダムサーチは指定した範囲のハイパーパラメータをランダムに抽出し、学習・検証を行います。この方法により、広い範囲を探索することがより効率的に可能になりました。しかし、もちろん全てのハイパーパラメータを探索するわけではないため、そのハイパーパラメータが最適かは判断が難しい点がランダムサーチの欠点と言えるでしょう。
文献の中では、経験的にグリッドサーチと比較して、ランダムサーチの方が効率的にハイパーパラメータを探索することができるケースもあると説明しているものもあります。ランダムサーチである程度の範囲を絞ったあとに、グリッドサーチで局所的に探索するという方法もあるかもしれません。
実装方法はグリッドサーチと似ているので、すぐに実装することができます。
# RandomizedSearchCV クラスのインポート
from sklearn.model_selection import RandomizedSearchCV
# 学習に使用するアルゴリズム
estimator = DecisionTreeClassifier(random_state=0)
ハイパーパラメータを探索する範囲の指定します。指定方法はグリッドサーチと同様になりますが、今回はランダムサーチの挙動を確認するために、範囲を少し広げて指定します。
範囲の指定に range(開始値, 終了値, ステップ)
を使用します。例えば range(1, 10, 2)
の場合、1 から 10 までの値を 2 刻みで獲得できます。その値を list()
でリスト化しています。
list(range(1, 10, 2))
# ハイパーパラメータを探索する範囲の指定
param_distributions = {
'max_depth': list(range(5, 100, 2)),
'min_samples_split': list(range(2, 50, 1))
}
ランダムサーチはグリッドサーチと異なり、指定した範囲のハイパーパラメータをランダムに抽出し学習を行うため、何回学習を試行するかの回数を指定する必要があります。
# 試行回数の指定
n_iter = 100
RandomizedSearchCV
クラスでも K-分割交差検証が行われるため、 の値を指定します。
cv = 5
ランダムにハイパーパラメータが抽出されるため、再現性の確保のために乱数のシードの固定を行い、定義した値を用いてモデルの定義を行いましょう。
# モデルの定義
tuned_model = RandomizedSearchCV(
estimator=estimator,
param_distributions=param_distributions,
n_iter=n_iter, cv=cv,
random_state=0, return_train_score=False
)
# モデルの学習&検証
tuned_model.fit(x_train_val, t_train_val)
今回試行回数を 100 回に設定しているため、学習結果を検証用データセットに対しての順位を表す rank_test_score
の値を基準に昇順に並び替えて表示します。
# 学習結果の確認(スコアの高い順に表示)
pd.DataFrame(tuned_model.cv_results_).sort_values('rank_test_score').T
47 | 77 | 82 | 90 | 42 | 19 | 28 | 12 | 11 | 62 | 69 | 39 | 70 | 3 | 96 | 29 | 6 | 68 | 43 | 34 | 9 | 48 | 45 | 33 | 91 | 32 | 25 | 37 | 44 | 46 | 36 | 52 | 54 | 57 | 59 | 61 | 63 | 66 | 76 | 75 | ... | 97 | 78 | 83 | 92 | 79 | 89 | 88 | 80 | 81 | 87 | 86 | 85 | 93 | 0 | 49 | 71 | 2 | 5 | 7 | 16 | 17 | 20 | 21 | 22 | 23 | 26 | 72 | 30 | 35 | 38 | 40 | 41 | 98 | 50 | 55 | 58 | 60 | 67 | 31 | 99 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
mean_fit_time | 0.004392 | 0.00440807 | 0.00442195 | 0.00440397 | 0.00444493 | 0.00448847 | 0.00438666 | 0.0044693 | 0.00443282 | 0.00452209 | 0.00441437 | 0.00442162 | 0.00441089 | 0.00449786 | 0.00450063 | 0.00441008 | 0.00447145 | 0.00439453 | 0.00443196 | 0.00442343 | 0.00473528 | 0.0044848 | 0.00441575 | 0.00437322 | 0.00435953 | 0.00441637 | 0.00489454 | 0.00438333 | 0.00436053 | 0.00434899 | 0.00436254 | 0.00435524 | 0.00440626 | 0.0044055 | 0.00433431 | 0.00440831 | 0.00437479 | 0.00434923 | 0.00437369 | 0.00434961 | ... | 0.00434093 | 0.00425792 | 0.00437932 | 0.00426431 | 0.00434828 | 0.00431495 | 0.00426273 | 0.00434732 | 0.00426693 | 0.00426288 | 0.00428667 | 0.00425739 | 0.00425711 | 0.00691209 | 0.00427871 | 0.00432868 | 0.00471344 | 0.00429893 | 0.00482011 | 0.00429759 | 0.00441604 | 0.0043901 | 0.00432558 | 0.00433555 | 0.00429206 | 0.00430884 | 0.00432935 | 0.00427389 | 0.00431142 | 0.00428219 | 0.00430303 | 0.00431805 | 0.00429115 | 0.00429873 | 0.0042963 | 0.00429978 | 0.00429058 | 0.00426645 | 0.00429425 | 0.00428686 |
std_fit_time | 0.000216881 | 0.000210771 | 0.000191937 | 0.00020275 | 0.000203917 | 0.000221445 | 0.000217712 | 0.000246065 | 0.000200565 | 0.000150701 | 0.000194392 | 0.000210716 | 0.000201556 | 0.000211967 | 0.00018613 | 0.000173295 | 0.000204783 | 0.000191711 | 0.000199555 | 0.00018899 | 0.000478904 | 0.000168038 | 0.000193723 | 0.000215826 | 0.000210294 | 0.000239198 | 0.000996196 | 0.000211181 | 0.000174822 | 0.000204838 | 0.000231535 | 0.000184793 | 0.000165923 | 0.000187781 | 0.000221183 | 0.000238855 | 0.000245671 | 0.000270185 | 0.000231199 | 0.000227636 | ... | 0.000230947 | 0.000232963 | 0.000156944 | 0.000229298 | 0.000245686 | 0.000241585 | 0.00022761 | 0.000230197 | 0.000245143 | 0.000228639 | 0.000224447 | 0.000232696 | 0.000237102 | 0.00513629 | 0.000227837 | 0.000218776 | 0.000504718 | 0.0002491 | 0.00106995 | 0.000265608 | 0.000296536 | 0.000241779 | 0.000231495 | 0.000270263 | 0.000254026 | 0.000229142 | 0.000249078 | 0.000238116 | 0.000210643 | 0.000230255 | 0.000199076 | 0.000244708 | 0.000264486 | 0.000247363 | 0.000239105 | 0.000221817 | 0.000219544 | 0.000229738 | 0.000208532 | 0.000258263 |
mean_score_time | 0.00030756 | 0.000301409 | 0.000318384 | 0.000304747 | 0.000336647 | 0.000322866 | 0.000305033 | 0.000314856 | 0.000302267 | 0.000347471 | 0.000293589 | 0.000307703 | 0.000308561 | 0.000346327 | 0.000348711 | 0.000296879 | 0.000339651 | 0.000315046 | 0.000306749 | 0.00031004 | 0.000354576 | 0.000327969 | 0.000300312 | 0.000312948 | 0.000304985 | 0.000312328 | 0.000317097 | 0.000318146 | 0.000310612 | 0.00029707 | 0.000302744 | 0.000309086 | 0.000330544 | 0.000315428 | 0.000310278 | 0.000326061 | 0.000322723 | 0.000314856 | 0.000314045 | 0.000301933 | ... | 0.000316668 | 0.000296688 | 0.000299883 | 0.000301647 | 0.000321722 | 0.000305653 | 0.000308275 | 0.000326109 | 0.000299263 | 0.000298214 | 0.000301361 | 0.000298595 | 0.000304317 | 0.000374985 | 0.000301504 | 0.000317144 | 0.000370216 | 0.000312042 | 0.000481749 | 0.000304985 | 0.000351667 | 0.000337362 | 0.000312901 | 0.000326395 | 0.000310946 | 0.000310659 | 0.00031209 | 0.000294971 | 0.00029974 | 0.000309038 | 0.000324154 | 0.000314474 | 0.00031395 | 0.000304556 | 0.000319052 | 0.000323153 | 0.000325298 | 0.000312138 | 0.000304937 | 0.000312996 |
std_score_time | 1.02126e-05 | 5.17626e-06 | 2.10849e-05 | 8.58757e-06 | 1.41998e-05 | 8.48904e-06 | 1.07173e-05 | 2.06964e-05 | 9.9056e-06 | 1.40263e-05 | 8.42748e-06 | 1.30156e-05 | 1.48979e-05 | 1.99273e-05 | 1.48476e-05 | 5.80372e-06 | 2.85516e-05 | 1.64661e-05 | 9.41267e-06 | 7.09126e-06 | 6.03265e-05 | 1.84559e-05 | 1.03898e-05 | 9.47047e-06 | 1.98134e-05 | 6.9494e-06 | 1.95474e-05 | 1.14099e-05 | 8.45818e-06 | 6.11022e-06 | 1.16512e-05 | 6.45335e-06 | 1.94959e-05 | 1.19571e-05 | 1.06982e-05 | 1.59114e-05 | 5.07468e-06 | 1.12291e-05 | 1.16063e-05 | 1.8022e-05 | ... | 1.17793e-05 | 2.84748e-06 | 6.63369e-06 | 1.02508e-05 | 2.06282e-05 | 1.9187e-05 | 1.37978e-05 | 2.31972e-05 | 8.85095e-06 | 6.37323e-06 | 8.0358e-06 | 4.36143e-06 | 1.06481e-05 | 4.33886e-05 | 1.07353e-05 | 1.89265e-05 | 4.60918e-05 | 1.98535e-05 | 0.000283498 | 7.31817e-06 | 2.47487e-05 | 5.81078e-05 | 6.04587e-06 | 3.03884e-05 | 7.7424e-06 | 9.29282e-06 | 1.2362e-05 | 5.14763e-06 | 6.62684e-06 | 5.73078e-06 | 1.92104e-05 | 1.2833e-05 | 2.00461e-05 | 5.21608e-06 | 8.50777e-06 | 1.3258e-05 | 1.39352e-05 | 1.99791e-05 | 7.8945e-06 | 2.09848e-05 |
param_min_samples_split | 10 | 10 | 4 | 4 | 7 | 9 | 11 | 2 | 8 | 7 | 4 | 2 | 2 | 2 | 4 | 6 | 8 | 4 | 9 | 5 | 5 | 5 | 5 | 13 | 12 | 12 | 12 | 13 | 14 | 16 | 14 | 14 | 24 | 14 | 20 | 16 | 23 | 23 | 15 | 16 | ... | 29 | 39 | 44 | 36 | 27 | 35 | 36 | 31 | 48 | 43 | 31 | 39 | 42 | 30 | 38 | 27 | 37 | 40 | 36 | 40 | 39 | 27 | 27 | 43 | 41 | 27 | 30 | 42 | 27 | 43 | 49 | 31 | 45 | 27 | 43 | 36 | 36 | 47 | 44 | 39 |
param_max_depth | 23 | 65 | 95 | 39 | 15 | 37 | 7 | 87 | 29 | 7 | 9 | 21 | 97 | 89 | 41 | 65 | 25 | 47 | 35 | 59 | 87 | 29 | 13 | 73 | 5 | 31 | 55 | 35 | 11 | 77 | 15 | 49 | 7 | 53 | 91 | 45 | 91 | 95 | 69 | 61 | ... | 89 | 27 | 61 | 39 | 81 | 89 | 17 | 73 | 15 | 67 | 27 | 37 | 71 | 9 | 9 | 45 | 63 | 95 | 59 | 11 | 25 | 27 | 37 | 73 | 55 | 19 | 79 | 93 | 35 | 49 | 87 | 23 | 19 | 99 | 27 | 27 | 47 | 75 | 95 | 87 |
params | {'min_samples_split': 10, 'max_depth': 23} | {'min_samples_split': 10, 'max_depth': 65} | {'min_samples_split': 4, 'max_depth': 95} | {'min_samples_split': 4, 'max_depth': 39} | {'min_samples_split': 7, 'max_depth': 15} | {'min_samples_split': 9, 'max_depth': 37} | {'min_samples_split': 11, 'max_depth': 7} | {'min_samples_split': 2, 'max_depth': 87} | {'min_samples_split': 8, 'max_depth': 29} | {'min_samples_split': 7, 'max_depth': 7} | {'min_samples_split': 4, 'max_depth': 9} | {'min_samples_split': 2, 'max_depth': 21} | {'min_samples_split': 2, 'max_depth': 97} | {'min_samples_split': 2, 'max_depth': 89} | {'min_samples_split': 4, 'max_depth': 41} | {'min_samples_split': 6, 'max_depth': 65} | {'min_samples_split': 8, 'max_depth': 25} | {'min_samples_split': 4, 'max_depth': 47} | {'min_samples_split': 9, 'max_depth': 35} | {'min_samples_split': 5, 'max_depth': 59} | {'min_samples_split': 5, 'max_depth': 87} | {'min_samples_split': 5, 'max_depth': 29} | {'min_samples_split': 5, 'max_depth': 13} | {'min_samples_split': 13, 'max_depth': 73} | {'min_samples_split': 12, 'max_depth': 5} | {'min_samples_split': 12, 'max_depth': 31} | {'min_samples_split': 12, 'max_depth': 55} | {'min_samples_split': 13, 'max_depth': 35} | {'min_samples_split': 14, 'max_depth': 11} | {'min_samples_split': 16, 'max_depth': 77} | {'min_samples_split': 14, 'max_depth': 15} | {'min_samples_split': 14, 'max_depth': 49} | {'min_samples_split': 24, 'max_depth': 7} | {'min_samples_split': 14, 'max_depth': 53} | {'min_samples_split': 20, 'max_depth': 91} | {'min_samples_split': 16, 'max_depth': 45} | {'min_samples_split': 23, 'max_depth': 91} | {'min_samples_split': 23, 'max_depth': 95} | {'min_samples_split': 15, 'max_depth': 69} | {'min_samples_split': 16, 'max_depth': 61} | ... | {'min_samples_split': 29, 'max_depth': 89} | {'min_samples_split': 39, 'max_depth': 27} | {'min_samples_split': 44, 'max_depth': 61} | {'min_samples_split': 36, 'max_depth': 39} | {'min_samples_split': 27, 'max_depth': 81} | {'min_samples_split': 35, 'max_depth': 89} | {'min_samples_split': 36, 'max_depth': 17} | {'min_samples_split': 31, 'max_depth': 73} | {'min_samples_split': 48, 'max_depth': 15} | {'min_samples_split': 43, 'max_depth': 67} | {'min_samples_split': 31, 'max_depth': 27} | {'min_samples_split': 39, 'max_depth': 37} | {'min_samples_split': 42, 'max_depth': 71} | {'min_samples_split': 30, 'max_depth': 9} | {'min_samples_split': 38, 'max_depth': 9} | {'min_samples_split': 27, 'max_depth': 45} | {'min_samples_split': 37, 'max_depth': 63} | {'min_samples_split': 40, 'max_depth': 95} | {'min_samples_split': 36, 'max_depth': 59} | {'min_samples_split': 40, 'max_depth': 11} | {'min_samples_split': 39, 'max_depth': 25} | {'min_samples_split': 27, 'max_depth': 27} | {'min_samples_split': 27, 'max_depth': 37} | {'min_samples_split': 43, 'max_depth': 73} | {'min_samples_split': 41, 'max_depth': 55} | {'min_samples_split': 27, 'max_depth': 19} | {'min_samples_split': 30, 'max_depth': 79} | {'min_samples_split': 42, 'max_depth': 93} | {'min_samples_split': 27, 'max_depth': 35} | {'min_samples_split': 43, 'max_depth': 49} | {'min_samples_split': 49, 'max_depth': 87} | {'min_samples_split': 31, 'max_depth': 23} | {'min_samples_split': 45, 'max_depth': 19} | {'min_samples_split': 27, 'max_depth': 99} | {'min_samples_split': 43, 'max_depth': 27} | {'min_samples_split': 36, 'max_depth': 27} | {'min_samples_split': 36, 'max_depth': 47} | {'min_samples_split': 47, 'max_depth': 75} | {'min_samples_split': 44, 'max_depth': 95} | {'min_samples_split': 39, 'max_depth': 87} |
split0_test_score | 0.967033 | 0.967033 | 0.967033 | 0.967033 | 0.967033 | 0.967033 | 0.967033 | 0.956044 | 0.967033 | 0.967033 | 0.967033 | 0.956044 | 0.956044 | 0.956044 | 0.967033 | 0.967033 | 0.967033 | 0.967033 | 0.967033 | 0.967033 | 0.967033 | 0.967033 | 0.967033 | 0.923077 | 0.923077 | 0.923077 | 0.923077 | 0.923077 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | ... | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 |
split1_test_score | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.901099 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | ... | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 |
split2_test_score | 0.923077 | 0.923077 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.923077 | 0.923077 | 0.912088 | 0.912088 | 0.912088 | 0.923077 | 0.923077 | 0.923077 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.912088 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | ... | 0.934066 | 0.945055 | 0.945055 | 0.945055 | 0.934066 | 0.945055 | 0.945055 | 0.934066 | 0.945055 | 0.945055 | 0.934066 | 0.945055 | 0.945055 | 0.934066 | 0.945055 | 0.934066 | 0.945055 | 0.945055 | 0.945055 | 0.945055 | 0.945055 | 0.934066 | 0.934066 | 0.945055 | 0.945055 | 0.934066 | 0.934066 | 0.945055 | 0.934066 | 0.945055 | 0.945055 | 0.934066 | 0.945055 | 0.934066 | 0.945055 | 0.945055 | 0.945055 | 0.945055 | 0.945055 | 0.945055 |
split3_test_score | 0.956044 | 0.956044 | 0.967033 | 0.967033 | 0.956044 | 0.956044 | 0.956044 | 0.967033 | 0.956044 | 0.956044 | 0.967033 | 0.967033 | 0.967033 | 0.967033 | 0.967033 | 0.956044 | 0.956044 | 0.967033 | 0.956044 | 0.956044 | 0.956044 | 0.956044 | 0.956044 | 0.956044 | 0.956044 | 0.956044 | 0.956044 | 0.956044 | 0.945055 | 0.945055 | 0.945055 | 0.945055 | 0.945055 | 0.945055 | 0.945055 | 0.945055 | 0.945055 | 0.945055 | 0.945055 | 0.945055 | ... | 0.945055 | 0.934066 | 0.934066 | 0.934066 | 0.945055 | 0.934066 | 0.934066 | 0.945055 | 0.934066 | 0.934066 | 0.945055 | 0.934066 | 0.934066 | 0.945055 | 0.934066 | 0.945055 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.945055 | 0.945055 | 0.934066 | 0.934066 | 0.945055 | 0.945055 | 0.934066 | 0.945055 | 0.934066 | 0.934066 | 0.945055 | 0.934066 | 0.945055 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 |
split4_test_score | 0.967033 | 0.967033 | 0.956044 | 0.956044 | 0.967033 | 0.967033 | 0.967033 | 0.956044 | 0.967033 | 0.967033 | 0.956044 | 0.956044 | 0.956044 | 0.956044 | 0.956044 | 0.967033 | 0.967033 | 0.956044 | 0.967033 | 0.956044 | 0.956044 | 0.956044 | 0.956044 | 0.967033 | 0.967033 | 0.967033 | 0.967033 | 0.967033 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | 0.934066 | ... | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 | 0.901099 |
mean_test_score | 0.945055 | 0.945055 | 0.942857 | 0.942857 | 0.942857 | 0.942857 | 0.942857 | 0.942857 | 0.942857 | 0.942857 | 0.942857 | 0.942857 | 0.942857 | 0.942857 | 0.942857 | 0.942857 | 0.942857 | 0.942857 | 0.942857 | 0.940659 | 0.940659 | 0.940659 | 0.940659 | 0.936264 | 0.936264 | 0.936264 | 0.936264 | 0.936264 | 0.925275 | 0.925275 | 0.925275 | 0.925275 | 0.925275 | 0.925275 | 0.925275 | 0.925275 | 0.925275 | 0.925275 | 0.925275 | 0.925275 | ... | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 | 0.918681 |
std_test_score | 0.0230507 | 0.0230507 | 0.0254414 | 0.0254414 | 0.0254414 | 0.0254414 | 0.0263736 | 0.0213085 | 0.0254414 | 0.0254414 | 0.0254414 | 0.0213085 | 0.0213085 | 0.0213085 | 0.0254414 | 0.0254414 | 0.0254414 | 0.0254414 | 0.0254414 | 0.0236711 | 0.0236711 | 0.0236711 | 0.0236711 | 0.0234661 | 0.0234661 | 0.0234661 | 0.0234661 | 0.0234661 | 0.0161505 | 0.0161505 | 0.0161505 | 0.0161505 | 0.0161505 | 0.0161505 | 0.0161505 | 0.0161505 | 0.0161505 | 0.0161505 | 0.0161505 | 0.0161505 | ... | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 | 0.017855 |
rank_test_score | 1 | 1 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 20 | 20 | 20 | 20 | 24 | 24 | 24 | 24 | 24 | 29 | 29 | 29 | 29 | 29 | 29 | 29 | 29 | 29 | 29 | 29 | 29 | ... | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 | 60 |
15 rows × 100 columns
params
の値を確認しましょう。それぞれのハイパーパラメータがランダムに組み合わせられていることが確認できます。検証用データセットに対しての最も予測精度が高かったモデルを取得し、テスト用データセットに対しての予測精度を確認しましょう。
# 最も予測精度の高かったハイパーパラメータの確認
tuned_model.best_params_
# 最も予測精度の高かったモデルの引き継ぎ
best_model = tuned_model.best_estimator_
# モデルの検証
print(best_model.score(x_train_val, t_train_val))
print(best_model.score(x_test, t_test))
ランダムサーチは前述の通り、指定したハイパーパラメータを網羅していないので完全とは言えないですが、どこに予測精度が高くなるハイパーパラメータがあるのかあたりをつける目的では非常に有用です。
ランダムサーチで大体のいい予測精度に繋がるハイパーパラメータのあたりをつけ、グリッドサーチを用いてより詳細な探索を行うという方法もよく用いられる方法の 1 つになります。それぞれのハイパーパラメータの調整方法には長所と短所があることを理解しておきましょう。
ベイズ最適化
最後に紹介するのはベイズ最適化です。この手法はこれまで紹介してきた手法より数学的背景の理解が難しいため、厳密な説明は省略します。 詳細は英語になってしまいますが、こちらを参照してください。
ベイズ最適化では、事前分布と事後分布と呼ばれる確率統計の理論を使用してハイパーパラメータの探索を行います。その際、探索と活用と呼ばれる試行錯誤を繰り返します。イメージとしては人間が行う試行錯誤に近いものがあります。
もしもハイパーパラメータ探索を手動で行う際、まず初めに適当な値(初期値)を入れて結果を確認します。そして、もう一度適当な値を入れて 1 度目の予測精度と比較し、次の探索する場所を決めていくと思います。このように今まであまり調べていない未知の領域に対して適当に値を当てはめることを探索、探索により得た情報を元に設定した指標が最小値(最大値)になると期待できるハイパーパラメータを選択することを活用と呼びます。
探索と活用をまとめると下記のように表現することができます。
- 探索:まだ試していない値の範囲でハイパーパラメータを更新して、予測精度がどう変化するか情報を得る
- 活用:探索で得られた情報をもとに、予測精度が高まる可能性が高い範囲にハイパーパラメータを更新する
ランダムサーチでは、ランダムにハイパーパラメータの値を抽出し学習を行いましたが、ベイズ最適化では探索や活用で得られた情報を元にハイパーパラメータを調整していくため、より効率的に予測精度が高くなるハイパーパラメータを見つけることができると言われています。
本チュートリアルでは、ベイズ最適化を実装するためには日本の Prefferd Networks 社が開発している Optuna というフレームワークを使用します。 Optuna に関しての詳細は公式ページを参照してください。実装時のオプションの詳細などに関しては公式ドキュメントを確認してもらえるとより理解が深まると思います。
Colab には Optuna はインストールされていないため、下記のコマンドを実行してインストールを行います。その他のパッケージも基本的には下記のように pip install パッケージ名
でインストールできることも覚えておきましょう。
# optuna のインストール
!pip install optuna
import optuna
Optuna では最初に関数 objective
を定義して内部に以下の要素を関数として順に定義します。
- ハイパーパラメータごとに探索範囲を指定
- 学習に使用するアルゴリズムを指定
- 学習の実行、検証結果の表示
探索範囲の指定にはデフォルトで準備されている trial
クラスを使用します。
3 では学習・検証を繰り返してハイパーパラメータの調整を行うのですが、その際に return
で取得した検証結果を最小化(最大化)するように調整が進みます。
また、3 で K-分割交差検証を使用するには cross_val_score
が必要である点も認識しておきましょう。
from sklearn.model_selection import cross_val_score
def objective(trial, x, t, cv):
# ① ハイパーパラメータごとに探索範囲を指定
max_depth = trial.suggest_int('max_depth', 2, 100)
min_samples_split = trial.suggest_int('min_samples_split', 2, 100)
# ② 学習に使用するアルゴリズムを指定
estimator = DecisionTreeClassifier(
max_depth = max_depth,
min_samples_split = min_samples_split
)
# ③ 学習の実行、検証結果の表示
print('Current_params : ', trial.params)
accuracy = cross_val_score(estimator, x, t, cv=cv).mean()
return accuracy
準備が整ったらハイパーパラメータの調整を行います。デフォルトでは最小化を行うようになっていますのが、今回は正解率の最大化を目的としています。最大化を目的とする場合には、direction='maximize'
を指定しましょう。
# study オブジェクトの作成(最大化)
study = optuna.create_study(direction='maximize')
# K 分割交差検証の K
cv = 5
# 目的関数の最適化
study.optimize(lambda trial: objective(trial, x_train_val, t_train_val, cv), n_trials=10)
print(study.best_trial)
print
で出力している値はハイパーパラメータの値になります。学習が完了するたびに、現在の正解率を表す resulted in value
と現在までの最も良かった正解率を表示しています。学習が終了したので、最も予測精度の高かったハイパーパラメータを確認するために study.best_params
を実行します。
# 最も予測精度の高かったハイパーパラメータの確認
study.best_params
Optuna でのハイパーパラメータ調整は先ほどと異なり、最も予測精度の高かったハイパーパラメータのみが取得でき、学習済みモデルは取得することができないため、再度学習を行う必要があります。
下記のように **
のようにアスタリスクを 2 つ付け、先程のハイパーパラメータをモデルのインスタンス化を行う際に引数に渡すことで、ハイパーパラメータを設定することができます。
# 最適なハイパーパラメータを設定したモデルの定義
best_model = DecisionTreeClassifier(**study.best_params)
# モデルの学習
best_model.fit(x_train_val, t_train_val)
# モデルの検証
print(best_model.score(x_train_val, t_train_val))
print(best_model.score(x_test, t_test))
ベイズ最適化を用いてハイパーパラメータ調整が行うことができました。それぞれの手法を引き出しとしてもち、それぞれの長所・短所を踏まえた上で手法を選択できるようにしましょう。