Python データ分析 プログラミング

【Python】Scikit-learnで作成した決定木のモデルを可視化する

この記事は約4分で読めます。

 

こんにちは、ミナピピン(@python_mllover)です。

 

先日クラウド環境でデータ分析を行って決定木の結果を可視化したいと思ったのですが、graphvizがインストールできないという自体に遭遇しました。その回避法としてgraphviz以外で作成した決定木を可視化方法を探していたところ、plot_tree()という関数で可視化できたので手順や条件をメモ的にまとめておこうと思います。

 

Scikit-learnのバージョン確認

 

plot_tree()を使用するためにはバージョン0.21以上である必要があります。0.20とかだと以下のようなエラーが発生します。

 

ImportError: cannot import name 'plot_tree'

 

scikit-learnのバージョンは以下の手順で確認できます.

 

import sklearn
print(sklearn.__version__)

 

0.21以上が必要なので、0.21未満の場合はpipでアップデートする必要があります。

$ pip install scikit-learn==0.21.2

 

参照:https://github.com/scikit-learn/scikit-learn/issues/13890

 

前準備

 

 

csvの中身は欠損値や特徴量を加工したタイタニックの乗客データです。特徴量エンジニアリングに関しては今回の話題ではないのでサクッと飛ばします。ダウンロードしたcsvは任意の場所に配置して、下記のpd.read_csvのパスを弄って読み込めるようにしておいてください。

 

# ライブラリの読み込み
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import xgboost as xgb


#データの読み込み(パスは自分の環境に合わせて変えてください)
train = pd.read_csv('~/titanic_train_mod.csv.csv')
test= pd.read_csv('~/titanic_test_mod.csv.csv')


#学習データを説明変数と目的変数に分ける

data = train[["Pclass", "Sex", "Age", "Fare"]].values
target = train['Survived'].values

# データを訓練データと正解データに7:3で分ける

X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.3, shuffle=True)

 

参照:【Python】元最強アルゴリズム「XGBoost」で機械学習をやってみた

 

スポンサーリンク
スポンサーリンク

plot_tree()で決定木を可視化する

 

# 決定木分析
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier 
from sklearn.tree import plot_tree

# グラフの大きさを調整する
plt.figure(figsize=(12,12))

clf = DecisionTreeClassifier(random_state=0, max_depth=3)
clf = clf.fit(train_X, train_y)
plot_tree(
    clf,
    feature_names=train_X.columns,
    filled=True, rounded=True
)

plt.show()
plt.savefig("tree1.png")

 

 

今回入れたのはkaggleで超有名なタイタニックの生存者予測のデータで、生存した要因をこういう感じに可視化することができます。分岐の↓はtrue,falseの順です。

 

分岐次第ではグラフが大きすぎたり小さすぎたりするのでplt.figure()で都度調整してください

とりあえずgraphviz無しでも可視化はできます。

実務で使うのは怪しいデザインですが、まあ趣味で結果を確認したい分にはこれで十分かなと思います。

では~

 

 

関連記事:【Python】Window10でGaraphVizがインポートできないエラーの対処法

関連記事:【Python】CHAIDの決定木を実装して結果を可視化する

 

 

コメント

タイトルとURLをコピーしました