決定木の可視化において、とても柔軟性が高いggpartyパッケージをご紹介します。

ggpartyパッケージは、ggplot2の機能をpartykitに拡張し、partyクラスのツリーオブジェクトのために明瞭に構造化され、高度にカスタマイズ可能なビジュアライゼーションを作成するために必要なツールを提供します。

ggpartyパッケージを用いると、ノードやエッジに対して様々な設定をすることで多様な表現が可能になります。
特に強力なのは最下段のノードにおいて、分類木の場合は通常であれば目的変数の100%積み上げ棒グラフとなりますが、個数の積み上げ棒グラフにしたり、ある説明変数でグループ化なども可能になります。
同様に、回帰木の場合は通常であれば目的変数の箱ひげ図となりますが、散布図にしたり、ある説明変数でグループ化なども可能になります。

ここでは簡単のため、rpartパッケージのrpart関数の結果オブジェクトをpartykitパッケージのas.party関数を用いてpartyクラスのツリーオブジェクトに変換したものを用います。

分類木

サンプルデータとして、Rに標準で含まれているTitanicを用います。
最初に必要なパッケージのロードとサンプルデータを扱いやすい形に変更しておきます。


library(rpart)
library(partykit)
library(ggplot2)
library(ggparty)

tmp <- data.frame(Titanic)
df <- data.frame(
  Class = rep(tmp$Class, tmp$Freq),
  Sex = rep(tmp$Sex, tmp$Freq),
  Age = rep(tmp$Age, tmp$Freq),
  Survived = rep(tmp$Survived, tmp$Freq)
)

head関数を用いてdfを確認しておきます。


head(df)

結果は次になります。


  Class  Sex   Age Survived
1   3rd Male Child       No
2   3rd Male Child       No
3   3rd Male Child       No
4   3rd Male Child       No
5   3rd Male Child       No
6   3rd Male Child       No

次に、Survivedを目的変数、ClassとSex、Ageを説明変数として分類木をrpart関数を用いて実行し、結果をctに格納します。
また、この結果をas.party関数を用いてpartyクラスのツリーオブジェクトに変換します。


ct <- rpart(Survived ~ Class + Sex + Age , data = df)
pct <- as.party(ct)

ggpartyパッケージのggparty関数やgeom_edge関数などを用いて可視化します。


g <- ggparty(pct, terminal_space = 0.5)
g <- g + geom_edge(size = 1.5)
g <- g + geom_edge_label(colour = "grey", size = 6)
g <- g + geom_node_plot(
  gglist = list(geom_bar(aes(x = "", fill = Survived), position = "fill"), theme_bw(base_size = 15)),
  scales = "fixed",
  id = "terminal",
  shared_axis_labels = TRUE,
  shared_legend = TRUE,
  legend_separator = TRUE,
)
g <- g + geom_node_label(
  aes(col = splitvar),
  line_list = list(aes(label = paste("Node", id)),
                   aes(label = splitvar)),
  line_gpar = list(list(
    size = 12,
    col = "black",
    fontface = "bold"
  ),
  list(size = 20)),
  ids = "inner"
)
g <- g + geom_node_label(
  aes(label = paste0("Node ", id, ", N = ", nodesize)),
  fontface = "bold",
  ids = "terminal",
  size = 5,
  nudge_y = 0.01
)
g <- g + theme(legend.position = "none")
plot(g)

可視化されたグラフは次になります。

回帰木

サンプルデータとして、rpartパッケージに含まれているcu.summaryを用います。
最初に必要なパッケージのロードしておきます。


library(rpart)
library(partykit)
library(ggplot2)
library(ggparty)

head関数を用いてcu.summaryを確認しておきます。


head(cu.summary)

結果は次になります。


> head(cu.summary)
                Price Country Reliability Mileage  Type
Acura Integra 4 11950   Japan Much better      NA Small
Dodge Colt 4     6851   Japan              NA Small
Dodge Omni 4     6995     USA  Much worse      NA Small
Eagle Summit 4   8895     USA      better      33 Small
Ford Escort   4  7402     USA       worse      33 Small
Ford Festiva 4   6319   Korea      better      37 Small

次に、Priceを目的変数、MileageとCountry、Typeを説明変数として回帰木をrpart関数を用いて実行し、結果をrtに格納します。
また、この結果をas.party関数を用いてpartyクラスのツリーオブジェクトに変換します。


rt <- rpart(Price ~ Mileage + Country + Type, data = cu.summary)
prt <- as.party(rt)

ggpartyパッケージのggparty関数やgeom_edge関数などを用いて可視化します。
ここでは、表現力の自由度を確認するために、最下段のノードにおいて、Type別に箱ひげ図を表示するように設定しました。


g <- ggparty(prt, terminal_space = 0.5)
g <- g + geom_edge(size = 1.5)
g <- g + geom_edge_label(colour = "grey", size = 3)
g <- g + geom_node_plot(
  gglist = list(geom_boxplot(aes(x = "", y = Price, fill = Type)), theme_bw(base_size = 12)),
  scales = "fixed",
  id = "terminal",
  shared_axis_labels = TRUE,
  shared_legend = TRUE,
  legend_separator = TRUE,
)
g <- g + geom_node_label(
  aes(col = splitvar),
  line_list = list(aes(label = paste("Node", id)),
                   aes(label = splitvar)),
  line_gpar = list(list(
    size = 10,
    col = "black",
    fontface = "bold"
  ),
  list(size = 12)),
  ids = "inner"
)
g <- g + geom_node_label(
  aes(label = paste0("Node ", id, ", N = ", nodesize)),
  fontface = "bold",
  ids = "terminal",
  size = 3,
  nudge_y = 0.01
)
g <- g + theme(legend.position = "none")
plot(g)

可視化されたグラフは次になります。

補足

ここでは、いくつかの補足事項をお伝えさせていただきます。

エッジラベルは回転させることはできない

ggpartyのVersion 1.0.0では、エッジのラベルを回転させることはできません。
これは、ggpartyの該当ソースコードの259行目にあるgeom_edge_label関数を見るとgeom_label関数が使用されており、ggplot2のgeom_label関数のリファレンスを見ると、geom_label関数はangleをサポートしていないことが確認できます。

エッジのラベルを個別に移動することはできない

ggpartyのVersion 1.0.0では、エッジのラベルをgeom_edge_label関数の引数nudge_xとnudge_yを用いて移動させることができます。
しかし、この方法はエッジのラベルに対して個別に指定できず、全体が対象となります。

分類木の終点ノードのy軸をパーセント表示にする

ソースコードは、次になります。
上記の分類木のコードとの相違点は、geom_node_plot関数の引数gglistのlist内です。


g <- ggparty(pct, terminal_space = 0.5)
g <- g + geom_edge(size = 1.5)
g <- g + geom_edge_label(colour = "grey", size = 6)
g <- g + geom_node_plot(
  gglist = list(geom_bar(aes(x = "", fill = Survived), position = "fill"),
                theme_bw(base_size = 15),
                ylab("Percent"),
                scale_y_continuous(labels = scales::percent)),
  scales = "fixed",
  id = "terminal",
  shared_axis_labels = TRUE,
  shared_legend = TRUE,
  legend_separator = TRUE,
)
g <- g + geom_node_label(
  aes(col = splitvar),
  line_list = list(aes(label = paste("Node", id)),
                   aes(label = splitvar)),
  line_gpar = list(list(
    size = 12,
    col = "black",
    fontface = "bold"
  ),
  list(size = 20)),
  ids = "inner"
)
g <- g + geom_node_label(
  aes(label = paste0("Node ", id, ", N = ", nodesize)),
  fontface = "bold",
  ids = "terminal",
  size = 5,
  nudge_y = 0.01
)
g <- g + theme(legend.position = "none")
plot(g)

分類木の終点ノードの積み上げ棒グラフ内にパーセント表示することはできない

ggpartyのVersion 1.0.0では、ソースコードを確認すると分類木の終点ノードの積み上げ棒グラフ内にパーセント表示することは、おそらくできません。

最後に

ggpartyパッケージを簡単にご紹介させていただきました。
ここでは、棒グラフgeom_barや箱ひげ図geom_boxplotを用いましたが、散布図やヒストグラムなども用いることが可能です。
これらはggplot2パッケージの関数ですので、ggplot2の扱いに慣れていれば、ここでご紹介させていただいたグラフよりも強力な可視化が可能になります。
決定木の可視化にお役に立てたならば幸いです。

参考

関連する記事

  • R言語 CRAN Task View:時空間データの処理と分析R言語 CRAN Task View:時空間データの処理と分析 CRAN Task View: Handling and Analyzing Spatio-Temporal Dataの英語での説明文をGoogle翻訳を使用させていただき機械的に翻訳したものを掲載しました。 Maintainer: Edzer Pebesma Contact: edzer.pebesma at […]
  • QGIS プラグイン一覧QGIS プラグイン一覧 オープンソースソフトウェアの地図情報システムの一つであるQGIS(Quantum GIS)のプラグインの一覧をご紹介します。英語での説明文をgoogle翻訳を使用させていただき機械的に翻訳したものを掲載しました。プラグインを探す参考にしていただければ幸いです。 パッケージ確認日:2021/04/01 パッケージ数:1424 3D City […]
  • 基本統計量基本統計量 [latexpage] 基本統計量とは、データの基本的な特徴を表す値のことで、代表値と散布度に区分できる。代表値とは、データを代表するような値のことで、例えば、平均値、最大値、最小値などがある。散布度とは、データの散らばり度合いを表すような値のことで、例えば、分散、標準偏差などがある。 平均値 […]
  • R実装と解説 対応のない2標本の母平均の差の検定(母分散が等しい) [latexpage] 母分散が等しい場合の対応のない2標本の母平均の差の検定とは、2つの母集団が正規分布に従い、ともに母分散が等しいと仮定できるとき、一方の母平均が他方の母平均と「異なる」または「大きい」、「小さい」かどうかを、検定統計量がt分布に従うことを利用して検定します。 統計的検定の流れ 検定の大まかな流れを確認しておきます。 […]
  • R言語 CRAN Task View:統計遺伝学R言語 CRAN Task View:統計遺伝学 CRAN Task View: Statistical Geneticsの英語での説明文をGoogle翻訳を使用させていただき機械的に翻訳したものを掲載した。 Maintainer: Giovanni Montana Contact: g.montana at […]
R ggpartyパッケージを用いた決定木の可視化