Python データ分析 プログラミング

【Python】Scilit-learnで作成した決定木を可視化する

この記事は約3分で読めます。

 

こんにちは、ミナピピン(@python_mllover)です。

 

先日クラウド環境でデータ分析を行って決定木の結果を可視化したいと思ったのですが、graphvizがインストールできないという自体に遭遇しました。その回避法としてgraphviz以外で作成した決定木を可視化方法を探していたところ、plot_tree()という関数で可視化できたので手順や条件をメモ的にまとめておこうと思います。

 

バージョン確認

 

plot_tree()を使用するためにはバージョン0.21以上である必要があります。0.20とかだと以下のようなエラーが発生します。

 

ImportError: cannot import name 'plot_tree'

 

scikit-learnのバージョンは以下の手順で確認できます.

 

import sklearn
print(sklearn.__version__)

 

0.21以上が必要なので、0.21未満の場合はpipでアップデートする必要があります。

$ pip install scikit-learn==0.21.2

 

参照:https://github.com/scikit-learn/scikit-learn/issues/13890

 

 

スポンサーリンク
スポンサーリンク

plot_treeで可視化

 

# 決定木分析
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeClassifier 
from sklearn.tree import plot_tree

# グラフの大きさを調整する
plt.figure(figsize=(12,12))

clf = DecisionTreeClassifier(random_state=0, max_depth=3)
clf = clf.fit(train_X, train_y)
plot_tree(
    clf,
    feature_names=train_X.columns,
    filled=True, rounded=True
)

plt.show()
plt.savefig("tree1.png")

 

 

今回入れたのはタイタニックのデータでこういう感じに可視化することができます。

分岐の↓はtrue,falseの順です。

分岐次第ではグラフが大きすぎたり小さすぎたりするのでplt.figure()で都度調整してください

とりあえずgraphviz無しでも可視化はできます。

実務で使うのは怪しいデザインですが、まあ趣味で結果を確認したい分にはこれで十分かなと思います。

では~

 

 

関連記事:【Python】Window10でGaraphVizがインポートできないエラーの対処法

関連記事:【Python】CHAIDの決定木を実装して結果を可視化する

 

コメント

タイトルとURLをコピーしました