feature_importances_ について調べてみた
経緯
授業でランダムフォレストを使ってみる課題が出たが,いまいち feature_importances_ の算出方法がわからずモニョったので調べてみた.
想定する読者
- sklearn を使う人
- 決定木がなんとなくわかる人
- ランダムフォレストが主に分類に用いられるアルゴリズムであると知っている人
sklearn.ensemble.RandomForestClassifierのfeature_importances_ について
概要
from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier()
model.fit(train_data, train_label)
importances = model.feature_importances_
すっごく色々はしょるが,こうすると特徴量ごとの重要度が出てくる.
詳しくはこの辺のサイトを参考にしてください.
feature_importances_ の算出方法
では,この feature_importances_ はどのようにして算出されているのだろうか?
これを知るにはまずランダムフォレストのアルゴリズムをおさらいする必要がある.
決定木のアルゴリズム
ランダムフォレストは決定木の集まりでできている.決定木とは以下のようなもののことである.
決定木そのものは訓練データを用いて構築される.
あるノードにおいて分類にもっとも効果的な特徴量を選択し,データ集合の乱雑さが0になるまで,つまりそのノードにおける全てのデータが同じラベルを持つまで,順次ノードを作っていく.ここで言う効果的な特徴量とは,その特徴量によってデータ集合を分割した際にクラスラベルの乱雑さがもっとも減少するものを差す.乱雑さは gini係数やエントロピーなど様々な式で定義でき,sklearn ではデフォルトで gini係数を用いている.
sklearn.tree.DecisionTreeClassifier — scikit-learn 0.19.1 documentation
ランダムフォレストのアルゴリズム
決定木は過学習を起こしやすいとされ,これを避けるのがランダムフォレストである.
ランダムフォレストではまず,訓練データからいくつかの部分データ集合をランダムサンプリングによって生成する.この部分データ集合ごとに決定木を構築することを考える.ただしあるノードにおいて特徴量は,全ての特徴量のうちランダムにサンプリングされたものの中から選択される.こうすることによって全ての決定木が同じような挙動をすることを避ける.
こうして得られた決定木それぞれが分類を行い,その投票によって分類結果が決定される.
feature_importances_ の算出方法
決定木ではある特徴量による分類の前後で乱雑さがどれほど減少するかで特徴量の選定を行っていた.この減少幅を利得と言うことにする.利得は木の構築時に計算されていることになる.
ざっくり言えば,feature_importances_ はこの利得の特徴量ごとの平均である.ただし,決定木の構築に使われたデータのうちいくつのデータがそのノードへ到達したかで重み付けがなされている.たくさんのデータをさばくノードは重要度が高くなると考えれば,この定義は直感的に納得がいく.
詳しい説明はこのサイトを参考にしてください.ベストアンサーの Gilles Louppe 氏は sklearn の開発メンバーの1人です.
あとがき
今回は sklearn.ensemble.RandomForestClassifier の feature_importances_ の算出方法を調べた.ランダムフォレストをちゃんと理解したら自明っちゃ自明な算出だった.今までランダムフォレストをなんとなくのイメージでしか認識していなかったことが浮き彫りなった.この執筆を通してランダムフォレストを分かった気になれたのでスッキリですわ.
ところで開発者本人から回答が来るって羨ましすぎるよ...
ここまで読んでいただきありがとうございました.私の理解が足りていないところなどがあれば,なにとぞ優しくまさかりを投げてください.ちょっとした質問などもいただけると嬉しいです.
今回の執筆にあたって,一緒に調べたり考えたりしてくれた同研究室のT氏にはお世話になりました.
参考にさせていただいたサイト
scikit learn - How are feature_importances in RandomForestClassifier determined? - Stack Overflow
Random Forestで計算できる特徴量の重要度 - なにメモ
http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html
http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html