決定木の可視化において、とても柔軟性が高いggpartyパッケージをご紹介します。

ggpartyパッケージは、ggplot2の機能をpartykitに拡張し、partyクラスのツリーオブジェクトのために明瞭に構造化され、高度にカスタマイズ可能なビジュアライゼーションを作成するために必要なツールを提供します。

ggpartyパッケージを用いると、ノードやエッジに対して様々な設定をすることで多様な表現が可能になります。
特に強力なのは最下段のノードにおいて、分類木の場合は通常であれば目的変数の100%積み上げ棒グラフとなりますが、個数の積み上げ棒グラフにしたり、ある説明変数でグループ化なども可能になります。
同様に、回帰木の場合は通常であれば目的変数の箱ひげ図となりますが、散布図にしたり、ある説明変数でグループ化なども可能になります。

ここでは簡単のため、rpartパッケージのrpart関数の結果オブジェクトをpartykitパッケージのas.party関数を用いてpartyクラスのツリーオブジェクトに変換したものを用います。

分類木

サンプルデータとして、Rに標準で含まれているTitanicを用います。
最初に必要なパッケージのロードとサンプルデータを扱いやすい形に変更しておきます。


library(rpart)
library(partykit)
library(ggplot2)
library(ggparty)

tmp <- data.frame(Titanic)
df <- data.frame(
  Class = rep(tmp$Class, tmp$Freq),
  Sex = rep(tmp$Sex, tmp$Freq),
  Age = rep(tmp$Age, tmp$Freq),
  Survived = rep(tmp$Survived, tmp$Freq)
)

head関数を用いてdfを確認しておきます。


head(df)

結果は次になります。


  Class  Sex   Age Survived
1   3rd Male Child       No
2   3rd Male Child       No
3   3rd Male Child       No
4   3rd Male Child       No
5   3rd Male Child       No
6   3rd Male Child       No

次に、Survivedを目的変数、ClassとSex、Ageを説明変数として分類木をrpart関数を用いて実行し、結果をctに格納します。
また、この結果をas.party関数を用いてpartyクラスのツリーオブジェクトに変換します。


ct <- rpart(Survived ~ Class + Sex + Age , data = df)
pct <- as.party(ct)

ggpartyパッケージのggparty関数やgeom_edge関数などを用いて可視化します。


g <- ggparty(pct, terminal_space = 0.5)
g <- g + geom_edge(size = 1.5)
g <- g + geom_edge_label(colour = "grey", size = 6)
g <- g + geom_node_plot(
  gglist = list(geom_bar(aes(x = "", fill = Survived), position = "fill"), theme_bw(base_size = 15)),
  scales = "fixed",
  id = "terminal",
  shared_axis_labels = TRUE,
  shared_legend = TRUE,
  legend_separator = TRUE,
)
g <- g + geom_node_label(
  aes(col = splitvar),
  line_list = list(aes(label = paste("Node", id)),
                   aes(label = splitvar)),
  line_gpar = list(list(
    size = 12,
    col = "black",
    fontface = "bold"
  ),
  list(size = 20)),
  ids = "inner"
)
g <- g + geom_node_label(
  aes(label = paste0("Node ", id, ", N = ", nodesize)),
  fontface = "bold",
  ids = "terminal",
  size = 5,
  nudge_y = 0.01
)
g <- g + theme(legend.position = "none")
plot(g)

可視化されたグラフは次になります。

回帰木

サンプルデータとして、rpartパッケージに含まれているcu.summaryを用います。
最初に必要なパッケージのロードしておきます。


library(rpart)
library(partykit)
library(ggplot2)
library(ggparty)

head関数を用いてcu.summaryを確認しておきます。


head(cu.summary)

結果は次になります。


> head(cu.summary)
                Price Country Reliability Mileage  Type
Acura Integra 4 11950   Japan Much better      NA Small
Dodge Colt 4     6851   Japan              NA Small
Dodge Omni 4     6995     USA  Much worse      NA Small
Eagle Summit 4   8895     USA      better      33 Small
Ford Escort   4  7402     USA       worse      33 Small
Ford Festiva 4   6319   Korea      better      37 Small

次に、Priceを目的変数、MileageとCountry、Typeを説明変数として回帰木をrpart関数を用いて実行し、結果をrtに格納します。
また、この結果をas.party関数を用いてpartyクラスのツリーオブジェクトに変換します。


rt <- rpart(Price ~ Mileage + Country + Type, data = cu.summary)
prt <- as.party(rt)

ggpartyパッケージのggparty関数やgeom_edge関数などを用いて可視化します。
ここでは、表現力の自由度を確認するために、最下段のノードにおいて、Type別に箱ひげ図を表示するように設定しました。


g <- ggparty(prt, terminal_space = 0.5)
g <- g + geom_edge(size = 1.5)
g <- g + geom_edge_label(colour = "grey", size = 3)
g <- g + geom_node_plot(
  gglist = list(geom_boxplot(aes(x = "", y = Price, fill = Type)), theme_bw(base_size = 12)),
  scales = "fixed",
  id = "terminal",
  shared_axis_labels = TRUE,
  shared_legend = TRUE,
  legend_separator = TRUE,
)
g <- g + geom_node_label(
  aes(col = splitvar),
  line_list = list(aes(label = paste("Node", id)),
                   aes(label = splitvar)),
  line_gpar = list(list(
    size = 10,
    col = "black",
    fontface = "bold"
  ),
  list(size = 12)),
  ids = "inner"
)
g <- g + geom_node_label(
  aes(label = paste0("Node ", id, ", N = ", nodesize)),
  fontface = "bold",
  ids = "terminal",
  size = 3,
  nudge_y = 0.01
)
g <- g + theme(legend.position = "none")
plot(g)

可視化されたグラフは次になります。

補足

ここでは、いくつかの補足事項をお伝えさせていただきます。

エッジラベルは回転させることはできない

ggpartyのVersion 1.0.0では、エッジのラベルを回転させることはできません。
これは、ggpartyの該当ソースコードの259行目にあるgeom_edge_label関数を見るとgeom_label関数が使用されており、ggplot2のgeom_label関数のリファレンスを見ると、geom_label関数はangleをサポートしていないことが確認できます。

エッジのラベルを個別に移動することはできない

ggpartyのVersion 1.0.0では、エッジのラベルをgeom_edge_label関数の引数nudge_xとnudge_yを用いて移動させることができます。
しかし、この方法はエッジのラベルに対して個別に指定できず、全体が対象となります。

分類木の終点ノードのy軸をパーセント表示にする

ソースコードは、次になります。
上記の分類木のコードとの相違点は、geom_node_plot関数の引数gglistのlist内です。


g <- ggparty(pct, terminal_space = 0.5)
g <- g + geom_edge(size = 1.5)
g <- g + geom_edge_label(colour = "grey", size = 6)
g <- g + geom_node_plot(
  gglist = list(geom_bar(aes(x = "", fill = Survived), position = "fill"),
                theme_bw(base_size = 15),
                ylab("Percent"),
                scale_y_continuous(labels = scales::percent)),
  scales = "fixed",
  id = "terminal",
  shared_axis_labels = TRUE,
  shared_legend = TRUE,
  legend_separator = TRUE,
)
g <- g + geom_node_label(
  aes(col = splitvar),
  line_list = list(aes(label = paste("Node", id)),
                   aes(label = splitvar)),
  line_gpar = list(list(
    size = 12,
    col = "black",
    fontface = "bold"
  ),
  list(size = 20)),
  ids = "inner"
)
g <- g + geom_node_label(
  aes(label = paste0("Node ", id, ", N = ", nodesize)),
  fontface = "bold",
  ids = "terminal",
  size = 5,
  nudge_y = 0.01
)
g <- g + theme(legend.position = "none")
plot(g)

分類木の終点ノードの積み上げ棒グラフ内にパーセント表示することはできない

ggpartyのVersion 1.0.0では、ソースコードを確認すると分類木の終点ノードの積み上げ棒グラフ内にパーセント表示することは、おそらくできません。

最後に

ggpartyパッケージを簡単にご紹介させていただきました。
ここでは、棒グラフgeom_barや箱ひげ図geom_boxplotを用いましたが、散布図やヒストグラムなども用いることが可能です。
これらはggplot2パッケージの関数ですので、ggplot2の扱いに慣れていれば、ここでご紹介させていただいたグラフよりも強力な可視化が可能になります。
決定木の可視化にお役に立てたならば幸いです。

参考

関連する記事

  • Ubuntu OpenCVをインストールする手順Ubuntu OpenCVをインストールする手順 Ubuntu16.04にOpenCV3.3をインストールする手順をお伝えいたします。 OpenCVは、JPEGとPNGの画像だけを扱うように、できるだけ最小の構成を目指します。 そのため、動画に関する設定はできるだけ除いた形でインストールを行います。 環境 今回の作業環境を確認しておきます。また、以下の作業はすべてターミナルにて行っております。 Ubuntuのバージ […]
  • 相関係数相関係数 相関係数とは2変量のデータ間の関係性の強弱を計る統計学的指標である。相関係数rがとる値の範囲は-1≦r≦1である。相関係数rの値により以下のように呼ばれる。 -1≦r<0ならば負の相関 r=0ならば無相関 0<r≦1ならば正の相関 一般的に、強弱も合わせて以下のように呼ばれる。 […]
  • R実装と解説 母平均の検定(母分散未知) [latexpage] 母分散が未知の場合の母平均の検定とは、母集団が正規分布に従い、母分散が未知のときに母平均が標本平均と「異なる」または「大きい」、「小さい」かどうかを、検定統計量がt分布に従うことを利用して検定します。 統計的検定の流れ 検定の大まかな流れを確認しておきます。 帰無仮説H0と対立仮設H1をたてます […]
  • R言語 CRAN Task View:水文データとモデリングR言語 CRAN Task View:水文データとモデリング CRAN Task View: Hydrological Data and Modelingの英語での説明文をGoogle翻訳を使用させていただき機械的に翻訳したものを掲載しました。 Maintainer: Sam Albers, Sam Zipper, Ilaria Prosdocimi Contact: sam.albers at […]
  • CakePHP:プラグイン・パッケージ一覧CakePHP:プラグイン・パッケージ一覧 CakePHPのプラグイン・パッケージのサイトで公開されているプラグイン・パッケージの一覧をGoogle翻訳を使用させていただき機械的に翻訳したものとあわせてご紹介する。プラグイン・パッケージの情報は2017年04月01日時点のものであることに注意していただきたい。何かのお役に立てれば幸いだ。 1.3 2.x 3.x 3.x 2.x API […]
R ggpartyパッケージを用いた決定木の可視化