決定木の可視化において、とても柔軟性が高いggpartyパッケージをご紹介します。

ggpartyパッケージは、ggplot2の機能をpartykitに拡張し、partyクラスのツリーオブジェクトのために明瞭に構造化され、高度にカスタマイズ可能なビジュアライゼーションを作成するために必要なツールを提供します。

ggpartyパッケージを用いると、ノードやエッジに対して様々な設定をすることで多様な表現が可能になります。
特に強力なのは最下段のノードにおいて、分類木の場合は通常であれば目的変数の100%積み上げ棒グラフとなりますが、個数の積み上げ棒グラフにしたり、ある説明変数でグループ化なども可能になります。
同様に、回帰木の場合は通常であれば目的変数の箱ひげ図となりますが、散布図にしたり、ある説明変数でグループ化なども可能になります。

ここでは簡単のため、rpartパッケージのrpart関数の結果オブジェクトをpartykitパッケージのas.party関数を用いてpartyクラスのツリーオブジェクトに変換したものを用います。

分類木

サンプルデータとして、Rに標準で含まれているTitanicを用います。
最初に必要なパッケージのロードとサンプルデータを扱いやすい形に変更しておきます。


library(rpart)
library(partykit)
library(ggplot2)
library(ggparty)

tmp <- data.frame(Titanic)
df <- data.frame(
  Class = rep(tmp$Class, tmp$Freq),
  Sex = rep(tmp$Sex, tmp$Freq),
  Age = rep(tmp$Age, tmp$Freq),
  Survived = rep(tmp$Survived, tmp$Freq)
)

head関数を用いてdfを確認しておきます。


head(df)

結果は次になります。


  Class  Sex   Age Survived
1   3rd Male Child       No
2   3rd Male Child       No
3   3rd Male Child       No
4   3rd Male Child       No
5   3rd Male Child       No
6   3rd Male Child       No

次に、Survivedを目的変数、ClassとSex、Ageを説明変数として分類木をrpart関数を用いて実行し、結果をctに格納します。
また、この結果をas.party関数を用いてpartyクラスのツリーオブジェクトに変換します。


ct <- rpart(Survived ~ Class + Sex + Age , data = df)
pct <- as.party(ct)

ggpartyパッケージのggparty関数やgeom_edge関数などを用いて可視化します。


g <- ggparty(pct, terminal_space = 0.5)
g <- g + geom_edge(size = 1.5)
g <- g + geom_edge_label(colour = "grey", size = 6)
g <- g + geom_node_plot(
  gglist = list(geom_bar(aes(x = "", fill = Survived), position = "fill"), theme_bw(base_size = 15)),
  scales = "fixed",
  id = "terminal",
  shared_axis_labels = TRUE,
  shared_legend = TRUE,
  legend_separator = TRUE,
)
g <- g + geom_node_label(
  aes(col = splitvar),
  line_list = list(aes(label = paste("Node", id)),
                   aes(label = splitvar)),
  line_gpar = list(list(
    size = 12,
    col = "black",
    fontface = "bold"
  ),
  list(size = 20)),
  ids = "inner"
)
g <- g + geom_node_label(
  aes(label = paste0("Node ", id, ", N = ", nodesize)),
  fontface = "bold",
  ids = "terminal",
  size = 5,
  nudge_y = 0.01
)
g <- g + theme(legend.position = "none")
plot(g)

可視化されたグラフは次になります。

回帰木

サンプルデータとして、rpartパッケージに含まれているcu.summaryを用います。
最初に必要なパッケージのロードしておきます。


library(rpart)
library(partykit)
library(ggplot2)
library(ggparty)

head関数を用いてcu.summaryを確認しておきます。


head(cu.summary)

結果は次になります。


> head(cu.summary)
                Price Country Reliability Mileage  Type
Acura Integra 4 11950   Japan Much better      NA Small
Dodge Colt 4     6851   Japan              NA Small
Dodge Omni 4     6995     USA  Much worse      NA Small
Eagle Summit 4   8895     USA      better      33 Small
Ford Escort   4  7402     USA       worse      33 Small
Ford Festiva 4   6319   Korea      better      37 Small

次に、Priceを目的変数、MileageとCountry、Typeを説明変数として回帰木をrpart関数を用いて実行し、結果をrtに格納します。
また、この結果をas.party関数を用いてpartyクラスのツリーオブジェクトに変換します。


rt <- rpart(Price ~ Mileage + Country + Type, data = cu.summary)
prt <- as.party(rt)

ggpartyパッケージのggparty関数やgeom_edge関数などを用いて可視化します。
ここでは、表現力の自由度を確認するために、最下段のノードにおいて、Type別に箱ひげ図を表示するように設定しました。


g <- ggparty(prt, terminal_space = 0.5)
g <- g + geom_edge(size = 1.5)
g <- g + geom_edge_label(colour = "grey", size = 3)
g <- g + geom_node_plot(
  gglist = list(geom_boxplot(aes(x = "", y = Price, fill = Type)), theme_bw(base_size = 12)),
  scales = "fixed",
  id = "terminal",
  shared_axis_labels = TRUE,
  shared_legend = TRUE,
  legend_separator = TRUE,
)
g <- g + geom_node_label(
  aes(col = splitvar),
  line_list = list(aes(label = paste("Node", id)),
                   aes(label = splitvar)),
  line_gpar = list(list(
    size = 10,
    col = "black",
    fontface = "bold"
  ),
  list(size = 12)),
  ids = "inner"
)
g <- g + geom_node_label(
  aes(label = paste0("Node ", id, ", N = ", nodesize)),
  fontface = "bold",
  ids = "terminal",
  size = 3,
  nudge_y = 0.01
)
g <- g + theme(legend.position = "none")
plot(g)

可視化されたグラフは次になります。

補足

ここでは、いくつかの補足事項をお伝えさせていただきます。

エッジラベルは回転させることはできない

ggpartyのVersion 1.0.0では、エッジのラベルを回転させることはできません。
これは、ggpartyの該当ソースコードの259行目にあるgeom_edge_label関数を見るとgeom_label関数が使用されており、ggplot2のgeom_label関数のリファレンスを見ると、geom_label関数はangleをサポートしていないことが確認できます。

エッジのラベルを個別に移動することはできない

ggpartyのVersion 1.0.0では、エッジのラベルをgeom_edge_label関数の引数nudge_xとnudge_yを用いて移動させることができます。
しかし、この方法はエッジのラベルに対して個別に指定できず、全体が対象となります。

分類木の終点ノードのy軸をパーセント表示にする

ソースコードは、次になります。
上記の分類木のコードとの相違点は、geom_node_plot関数の引数gglistのlist内です。


g <- ggparty(pct, terminal_space = 0.5)
g <- g + geom_edge(size = 1.5)
g <- g + geom_edge_label(colour = "grey", size = 6)
g <- g + geom_node_plot(
  gglist = list(geom_bar(aes(x = "", fill = Survived), position = "fill"),
                theme_bw(base_size = 15),
                ylab("Percent"),
                scale_y_continuous(labels = scales::percent)),
  scales = "fixed",
  id = "terminal",
  shared_axis_labels = TRUE,
  shared_legend = TRUE,
  legend_separator = TRUE,
)
g <- g + geom_node_label(
  aes(col = splitvar),
  line_list = list(aes(label = paste("Node", id)),
                   aes(label = splitvar)),
  line_gpar = list(list(
    size = 12,
    col = "black",
    fontface = "bold"
  ),
  list(size = 20)),
  ids = "inner"
)
g <- g + geom_node_label(
  aes(label = paste0("Node ", id, ", N = ", nodesize)),
  fontface = "bold",
  ids = "terminal",
  size = 5,
  nudge_y = 0.01
)
g <- g + theme(legend.position = "none")
plot(g)

分類木の終点ノードの積み上げ棒グラフ内にパーセント表示することはできない

ggpartyのVersion 1.0.0では、ソースコードを確認すると分類木の終点ノードの積み上げ棒グラフ内にパーセント表示することは、おそらくできません。

最後に

ggpartyパッケージを簡単にご紹介させていただきました。
ここでは、棒グラフgeom_barや箱ひげ図geom_boxplotを用いましたが、散布図やヒストグラムなども用いることが可能です。
これらはggplot2パッケージの関数ですので、ggplot2の扱いに慣れていれば、ここでご紹介させていただいたグラフよりも強力な可視化が可能になります。
決定木の可視化にお役に立てたならば幸いです。

参考

関連する記事

  • Ubuntu,R h2oパッケージのインストールの方法Ubuntu,R h2oパッケージのインストールの方法 Rのパッケージh2oは、さまざまなクラスタ環境内のニューラルネットワーク(ディープラーニング)、ランダムフォレスト、勾配ブースティングマシン、一般化線形モデルなどの並列分散機械学習アルゴリズムを計算するビッグデータのためのオープンソースの数学エンジンH2O用のRスクリプト機能である。 ここでは、ubuntu14.04環境下でh2oパッケージのインストールの仕方についてお […]
  • MySQL 月の差分を計算する方法MySQL 月の差分を計算する方法 MySQLで、月の差分を計算する方法をお伝えする。 計算は、PERIOD_DIFF関数を用いれば簡単に求めることができる。 これは、二つの期間の差の月数を返す関数である。 PERIOD_DIFF(P1, […]
  • 地図で見る石川県白山市の人口 2013年12月版地図で見る石川県白山市の人口 2013年12月版 白山市役所が公開している平成25年12月末日の住民基本台帳人口と総務省統計局が公開している地図データを基に人口、人口密度、世帯数などの数値および前年同月からの増減率を地図上に色分けして視覚化したものと上位・下位のランキングをご紹介する。 人口の上位・下位ランキング […]
  • R オブジェクトを保存・読み込みする方法R オブジェクトを保存・読み込みする方法 Rでオブジェクトをファイルに保存または読み込みする方法を記載します。 長時間の計算による解析結果をファイルに保存しておくことは、解析手続きの分割が行えるため、とても役に立ちます。 解析手続きの分割について、解析Aの結果を解析Bで用いる場合という例でご説明します。 同じスクリプトで解析Aと解析Bを記載すると、解析Bを変更した際に再度解析Aを実行しないといけません。解析 […]
  • Bioconductor Workflowパッケージ一覧Bioconductor Workflowパッケージ一覧 BioconductorのWorkflowパッケージの一覧をご紹介します。英語での説明文をgoogle翻訳を使用させていただき機械的に翻訳したものを掲載しました。パッケージを探す参考にしていただければ幸いです。 パッケージ確認日:2022/01/01 パッケージ数:29 1. rnaseqGene RNA-seq workflow: gene-level […]
R ggpartyパッケージを用いた決定木の可視化