決定木とは、分類ルールを木構造で表したものである。分類したいデータを目的変数(従属変数)、分類するために用いるデータを説明変数(独立変数)という。目的変数がカテゴリデータなどの場合は「分類木」、連続値などの量的データの場合は「回帰木」と呼ばれる。

決定木の最大のメリットは、結果にグラフを用いることができるため、視覚的に確認できることである。

ここでは、R言語の「rpart」パッケージを用いて決定木について見ていこう。サンプルデータとして、Rに標準で含まれている「Titanic」を使わせていただいた。このサンプルデータはタイタニック号の乗客の属性情報と生死の情報が含まれている。生死を分けた要因を属性情報から分類するとどのようになるのかを見ていく。

まずは必要となるパッケージのインストールとロードを行う。「rpart」パッケージは決定木を行うためのものだが、「rpart.plot」と「partykit」パッケージは結果を視覚的に表示するために使うので、あらかじめインストールとロードをしておく。


> install.packages("rpart")
> install.packages("rpart.plot")
> install.packages("partykit")
> library(rpart)
> library(rpart.plot)
> library(partykit)

次に、サンプルデータを扱いやすい形に変更しておく。


> tmp <- data.frame(Titanic)
> df <- data.frame(
      Class = rep(tmp$Class, tmp$Freq),
      Sex = rep(tmp$Sex, tmp$Freq),
      Age = rep(tmp$Age, tmp$Freq),
      Survived = rep(tmp$Survived, tmp$Freq)
  )
> head(df)
  Class  Sex   Age Survived
1   3rd Male Child       No
2   3rd Male Child       No
3   3rd Male Child       No
4   3rd Male Child       No
5   3rd Male Child       No
6   3rd Male Child       No

決定木を実行するにはrpart関数を用いる。下の意味は、Survivedを目的変数、ClassとSexとAgeを説明変数として分類木を用いて、結果をctに格納している。そして結果をprint関数で表示している。


> ct <- rpart(Survived ~ Class + Sex + Age, data = df, method = "class")
> print(ct)
n= 2201 

node), split, n, loss, yval, (yprob)
  * denotes terminal node

   1) root 2201 711 No (0.6769650 0.3230350)  
     2) Sex=Male 1731 367 No (0.7879838 0.2120162)  
       4) Age=Adult 1667 338 No (0.7972406 0.2027594) *
       5) Age=Child 64  29 No (0.5468750 0.4531250)  
        10) Class=3rd 48  13 No (0.7291667 0.2708333) *
        11) Class=1st,2nd 16   0 Yes (0.0000000 1.0000000) *
     3) Sex=Female 470 126 Yes (0.2680851 0.7319149)  
       6) Class=3rd 196  90 No (0.5408163 0.4591837) *
       7) Class=1st,2nd,Crew 274  20 Yes (0.0729927 0.9270073) *

この結果をもっと視覚的に分かりやすいグラフとして表示してみる。まずは、標準のplot関数を用いてみる。


> par(xpd = NA)
> plot(ct, branch = 0.8, margin = 0.05)
> text(ct, use.n = TRUE, all = TRUE)

decision-tree-classification-tree-rpart

次に、「rpart.plot」パッケージのrpart.plot関数を用いてみる。


> rpart.plot(ct, type = 1, uniform = TRUE, extra = 1, under = 1, faclen = 0)

decision-tree-classification-tree-rpart.plot

最後に、「partykit」パッケージのas.party関数用いてデータを変換したものをplot関数に用いてみる。


> plot(as.party(ct))

decision-tree-classification-tree-rpart-party

これらのグラフはそれぞれ見栄えが異なるので、気に入ったものを使えばよいと思うが、「partkit」パッケージを用いたものが、比較的誰にでもわかりやすいように感じる。

さて、このグラフからまず分かることは、生死を決定づけた主な要因は性別(Sex)であることが分かる。大人であれ子供であれ、良い部屋に泊まっていようとなかろうと、女性(Female)の乗客は生存率が高い。

また、男性(Male)であっても、子供(Child)で良い部屋に泊まっていた乗客の生存率は高い。

つまり、この決定木からはタイタニック号が今まさに沈没しようとしているとき、真っ先に女性や子供を優先的に避難させようとしたことが読み取れるのである。

このように、決定木を用いると、視覚的に様々なものが読み取れるため非常に便利であるが、データによっては、木構造が深く複雑になる場合がある。そのようなときに、あまり重要でない分類ルールを失くしてシンプルにする必要がある。このような方法は剪定と呼ばれる。

剪定とは、構築された木が深くなるほど、きちんと分類できているといえるが、過学習の可能性もある。そこで、あらかじめ定めたパラメータによって複雑さと制御する方法である。

どのように剪定を行うのが良いかを判断するためには「rpart」パッケージのprintcp関数とグラフで表示できるplotcp関数を用いる。printcp関数は、分岐の数と複雑度を対応させて、plotcp関数は木のサイズと対応させている。どちらを用いてもよいが、基本的には、errorが収束し始めているところを剪定の基準にする場合が多い。


> printcp(ct)

Classification tree:
rpart(formula = Survived ~ Class + Sex + Age, data = df, method = "class")

Variables actually used in tree construction:
[1] Age   Class Sex  

Root node error: 711/2201 = 0.32303

n= 2201 

        CP nsplit rel error  xerror     xstd
1 0.306610      0   1.00000 1.00000 0.030857
2 0.022504      1   0.69339 0.69339 0.027510
3 0.011252      2   0.67089 0.69058 0.027470
4 0.010000      4   0.64838 0.66245 0.027062

> plotcp(ct)

decision-tree-classification-tree-rpart-cp

printcp関数の結果から剪定の基準をcp=0.022504として、再度決定木を行うと以下のようになる。


> ct2 <- rpart(Survived ~ Class + Sex + Age, data = df, method = "class", cp = 0.022504)
> print(ct2)
n= 2201 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

  1) root 2201 711 No (0.6769650 0.3230350)  
    2) Sex=Male 1731 367 No (0.7879838 0.2120162) *
    3) Sex=Female 470 126 Yes (0.2680851 0.7319149) *

> plot(as.party(ct2))

decision-tree-classification-tree-rpart-cp-prune

predict関数を用いると、予測が可能となる。ここでは、簡単のため、Titanicデータを二分割したものを用いる。
以下を見ると、2001番目のデータは生存した(Yes)確率が高いということが分かる。


> train <- df[1 : 2000,]
> test <- df[2001 : 2201,]
> ctp <- rpart(Survived ~ Class + Sex + Age, data = train, method = "class")
> p <- predict(ctp, newdata = test)
> head(p)
             No       Yes
2001 0.03333333 0.9666667
2002 0.03333333 0.9666667
2003 0.03333333 0.9666667
2004 0.03333333 0.9666667
2005 0.03333333 0.9666667
2006 0.03333333 0.9666667

関連する記事

  • 決定木 – 回帰木決定木 – 回帰木 ここでは、決定木の目的変数が連続値である場合の回帰木について、R言語の「rpart」パッケージを用いて簡単に見ていく。 まずは必要となるパッケージのインストールとロードを行う。「rpart」パッケージは決定木を行うためのものだが、「rpart.plot」と「partykit」パッケージは結果を視覚的に表示するために使うので、あらかじめインストールとロードをしておく。 […]
  • カイ二乗検定 – 適合度検定カイ二乗検定 – 適合度検定 適合度検定とは、観測度数分布が期待度数分布と同じかどうかを統計的に確かめる方法である。 適合度検定を行う手順は次の通りである。 仮説を立てる。 帰無仮説 H0:観測度数分布と期待度数分布が同じ。 対立仮説 […]
  • Journal of Statistical Software: 記事一覧 Journal of Statistical Software の記事一覧をご紹介する。英語での説明文をgoogle翻訳を使用させていただき機械的に翻訳したものを掲載した。 確認日:2017/03/24 論文数:1089 Introduction to stream: An Extensible Framework for Data Stream […]
  • カイ二乗検定 – 独立性検定カイ二乗検定 – 独立性検定 独立性検定とは、クロス集計表を作成したとき、2つの属性が独立であるかどうかを統計的に判定する方法である。 独立性検定を行う手順は次の通りである。 仮説を立てる。 帰無仮説H0:属性Ai(i=1,...,m)とBj(j=1,...,n)は独立である。 対立仮説H1:属性Ai(i=1,...,m)とBj(j=1,...,n)は少なくとも一つ以上は独立でない。 […]
  • Ubuntu,R h2oパッケージのインストールの方法Ubuntu,R h2oパッケージのインストールの方法 Rのパッケージh2oは、さまざまなクラスタ環境内のニューラルネットワーク(ディープラーニング)、ランダムフォレスト、勾配ブースティングマシン、一般化線形モデルなどの並列分散機械学習アルゴリズムを計算するビッグデータのためのオープンソースの数学エンジンH2O用のRスクリプト機能である。 ここでは、ubuntu14.04環境下でh2oパッケージのインストールの仕方についてお […]
決定木 – 分類木