決定木の可視化において、とても柔軟性が高いggpartyパッケージをご紹介します。

ggpartyパッケージは、ggplot2の機能をpartykitに拡張し、partyクラスのツリーオブジェクトのために明瞭に構造化され、高度にカスタマイズ可能なビジュアライゼーションを作成するために必要なツールを提供します。

ggpartyパッケージを用いると、ノードやエッジに対して様々な設定をすることで多様な表現が可能になります。
特に強力なのは最下段のノードにおいて、分類木の場合は通常であれば目的変数の100%積み上げ棒グラフとなりますが、個数の積み上げ棒グラフにしたり、ある説明変数でグループ化なども可能になります。
同様に、回帰木の場合は通常であれば目的変数の箱ひげ図となりますが、散布図にしたり、ある説明変数でグループ化なども可能になります。

ここでは簡単のため、rpartパッケージのrpart関数の結果オブジェクトをpartykitパッケージのas.party関数を用いてpartyクラスのツリーオブジェクトに変換したものを用います。

分類木

サンプルデータとして、Rに標準で含まれているTitanicを用います。
最初に必要なパッケージのロードとサンプルデータを扱いやすい形に変更しておきます。


library(rpart)
library(partykit)
library(ggplot2)
library(ggparty)

tmp <- data.frame(Titanic)
df <- data.frame(
  Class = rep(tmp$Class, tmp$Freq),
  Sex = rep(tmp$Sex, tmp$Freq),
  Age = rep(tmp$Age, tmp$Freq),
  Survived = rep(tmp$Survived, tmp$Freq)
)

head関数を用いてdfを確認しておきます。


head(df)

結果は次になります。


  Class  Sex   Age Survived
1   3rd Male Child       No
2   3rd Male Child       No
3   3rd Male Child       No
4   3rd Male Child       No
5   3rd Male Child       No
6   3rd Male Child       No

次に、Survivedを目的変数、ClassとSex、Ageを説明変数として分類木をrpart関数を用いて実行し、結果をctに格納します。
また、この結果をas.party関数を用いてpartyクラスのツリーオブジェクトに変換します。


ct <- rpart(Survived ~ Class + Sex + Age , data = df)
pct <- as.party(ct)

ggpartyパッケージのggparty関数やgeom_edge関数などを用いて可視化します。


g <- ggparty(pct, terminal_space = 0.5)
g <- g + geom_edge(size = 1.5)
g <- g + geom_edge_label(colour = "grey", size = 6)
g <- g + geom_node_plot(
  gglist = list(geom_bar(aes(x = "", fill = Survived), position = "fill"), theme_bw(base_size = 15)),
  scales = "fixed",
  id = "terminal",
  shared_axis_labels = TRUE,
  shared_legend = TRUE,
  legend_separator = TRUE,
)
g <- g + geom_node_label(
  aes(col = splitvar),
  line_list = list(aes(label = paste("Node", id)),
                   aes(label = splitvar)),
  line_gpar = list(list(
    size = 12,
    col = "black",
    fontface = "bold"
  ),
  list(size = 20)),
  ids = "inner"
)
g <- g + geom_node_label(
  aes(label = paste0("Node ", id, ", N = ", nodesize)),
  fontface = "bold",
  ids = "terminal",
  size = 5,
  nudge_y = 0.01
)
g <- g + theme(legend.position = "none")
plot(g)

可視化されたグラフは次になります。

回帰木

サンプルデータとして、rpartパッケージに含まれているcu.summaryを用います。
最初に必要なパッケージのロードしておきます。


library(rpart)
library(partykit)
library(ggplot2)
library(ggparty)

head関数を用いてcu.summaryを確認しておきます。


head(cu.summary)

結果は次になります。


> head(cu.summary)
                Price Country Reliability Mileage  Type
Acura Integra 4 11950   Japan Much better      NA Small
Dodge Colt 4     6851   Japan              NA Small
Dodge Omni 4     6995     USA  Much worse      NA Small
Eagle Summit 4   8895     USA      better      33 Small
Ford Escort   4  7402     USA       worse      33 Small
Ford Festiva 4   6319   Korea      better      37 Small

次に、Priceを目的変数、MileageとCountry、Typeを説明変数として回帰木をrpart関数を用いて実行し、結果をrtに格納します。
また、この結果をas.party関数を用いてpartyクラスのツリーオブジェクトに変換します。


rt <- rpart(Price ~ Mileage + Country + Type, data = cu.summary)
prt <- as.party(rt)

ggpartyパッケージのggparty関数やgeom_edge関数などを用いて可視化します。
ここでは、表現力の自由度を確認するために、最下段のノードにおいて、Type別に箱ひげ図を表示するように設定しました。


g <- ggparty(prt, terminal_space = 0.5)
g <- g + geom_edge(size = 1.5)
g <- g + geom_edge_label(colour = "grey", size = 3)
g <- g + geom_node_plot(
  gglist = list(geom_boxplot(aes(x = "", y = Price, fill = Type)), theme_bw(base_size = 12)),
  scales = "fixed",
  id = "terminal",
  shared_axis_labels = TRUE,
  shared_legend = TRUE,
  legend_separator = TRUE,
)
g <- g + geom_node_label(
  aes(col = splitvar),
  line_list = list(aes(label = paste("Node", id)),
                   aes(label = splitvar)),
  line_gpar = list(list(
    size = 10,
    col = "black",
    fontface = "bold"
  ),
  list(size = 12)),
  ids = "inner"
)
g <- g + geom_node_label(
  aes(label = paste0("Node ", id, ", N = ", nodesize)),
  fontface = "bold",
  ids = "terminal",
  size = 3,
  nudge_y = 0.01
)
g <- g + theme(legend.position = "none")
plot(g)

可視化されたグラフは次になります。

最後に

ggpartyパッケージを簡単にご紹介させていただきました。
ここでは、棒グラフgeom_barや箱ひげ図geom_boxplotを用いましたが、散布図やヒストグラムなども用いることが可能です。
これらはggplot2パッケージの関数ですので、ggplot2の扱いに慣れていれば、ここでご紹介させていただいたグラフよりも強力な可視化が可能になります。
決定木の可視化にお役に立てたならば幸いです。

参考

関連する記事

  • 決定木 – 分類木決定木 – 分類木 決定木とは、分類ルールを木構造で表したものである。分類したいデータを目的変数(従属変数)、分類するために用いるデータを説明変数(独立変数)という。目的変数がカテゴリデータなどの場合は「分類木」、連続値などの量的データの場合は「回帰木」と呼ばれる。 決定木の最大のメリットは、結果にグラフを用いることができるため、視覚的に確認できることである。 ここでは、R言語の「r […]
  • R ggplot2を用いて散布図と周辺分布をプロットする方法R ggplot2を用いて散布図と周辺分布をプロットする方法 ggplot2を用いて散布図と周辺分布をプロットする2つの方法をお伝えします。 最初の方法は、ggExtraパッケージのggMarginal関数を用いる方法で、周辺分布を簡単にプロットすることができます。 二番目の方法は、散布図と周辺分布を作成した上で、一つにまとめる方法です。 それぞれ一長一短があります。 最初の方法は、コード量が少ないですがグラフとして […]
  • R dplyrパッケージで複数の列を文字列として指定し結合された列を追加する方法R dplyrパッケージで複数の列を文字列として指定し結合された列を追加する方法 Rのdplyrパッケージのmutate関数は新たに列を追加する関数です。 ここでは、mutate関数に文字列として与えた列に対して、paste関数で統合した結果を新たに追加する方法をお伝えします。 サンプルデータとして、統計的な学生の髪と目の色が収められているHairEyeColorを用います。 ただし、このサンプルデータはtableとなっておりますので、実際にはd […]
  • R スティール・ドゥワス(Steel-Dwass)法R スティール・ドゥワス(Steel-Dwass)法 スティール・ドゥワス(Steel-Dwass)法とは、テューキー(Tukey)法の多重比較に対応するノンパラメトリックな多重比較である。 スティール・ドゥワス法を簡単に言うと、正規分布を仮定しない各群間を順位を用いて多重比較で調べる方法である。 Rで、スティール・ドゥワス法を使う場合は、「スティール・ドゥワス(Steel-Dwass)の方法による多重比較」のページにソ […]
  • Ubuntu,R h2oパッケージのインストールの方法Ubuntu,R h2oパッケージのインストールの方法 Rのパッケージh2oは、さまざまなクラスタ環境内のニューラルネットワーク(ディープラーニング)、ランダムフォレスト、勾配ブースティングマシン、一般化線形モデルなどの並列分散機械学習アルゴリズムを計算するビッグデータのためのオープンソースの数学エンジンH2O用のRスクリプト機能である。 ここでは、ubuntu14.04環境下でh2oパッケージのインストールの仕方についてお […]
R ggpartyパッケージを用いた決定木の可視化