決定木とは、分類ルールを木構造で表したものである。分類したいデータを目的変数(従属変数)、分類するために用いるデータを説明変数(独立変数)という。目的変数がカテゴリデータなどの場合は「分類木」、連続値などの量的データの場合は「回帰木」と呼ばれる。

決定木の最大のメリットは、結果にグラフを用いることができるため、視覚的に確認できることである。

ここでは、R言語の「rpart」パッケージを用いて決定木について見ていこう。サンプルデータとして、Rに標準で含まれている「Titanic」を使わせていただいた。このサンプルデータはタイタニック号の乗客の属性情報と生死の情報が含まれている。生死を分けた要因を属性情報から分類するとどのようになるのかを見ていく。

まずは必要となるパッケージのインストールとロードを行う。「rpart」パッケージは決定木を行うためのものだが、「rpart.plot」と「partykit」パッケージは結果を視覚的に表示するために使うので、あらかじめインストールとロードをしておく。


> install.packages("rpart")
> install.packages("rpart.plot")
> install.packages("partykit")
> library(rpart)
> library(rpart.plot)
> library(partykit)

次に、サンプルデータを扱いやすい形に変更しておく。


> tmp <- data.frame(Titanic)
> df <- data.frame(
      Class = rep(tmp$Class, tmp$Freq),
      Sex = rep(tmp$Sex, tmp$Freq),
      Age = rep(tmp$Age, tmp$Freq),
      Survived = rep(tmp$Survived, tmp$Freq)
  )
> head(df)
  Class  Sex   Age Survived
1   3rd Male Child       No
2   3rd Male Child       No
3   3rd Male Child       No
4   3rd Male Child       No
5   3rd Male Child       No
6   3rd Male Child       No

決定木を実行するにはrpart関数を用いる。下の意味は、Survivedを目的変数、ClassとSexとAgeを説明変数として分類木を用いて、結果をctに格納している。そして結果をprint関数で表示している。


> ct <- rpart(Survived ~ Class + Sex + Age, data = df, method = "class")
> print(ct)
n= 2201 

node), split, n, loss, yval, (yprob)
  * denotes terminal node

   1) root 2201 711 No (0.6769650 0.3230350)  
     2) Sex=Male 1731 367 No (0.7879838 0.2120162)  
       4) Age=Adult 1667 338 No (0.7972406 0.2027594) *
       5) Age=Child 64  29 No (0.5468750 0.4531250)  
        10) Class=3rd 48  13 No (0.7291667 0.2708333) *
        11) Class=1st,2nd 16   0 Yes (0.0000000 1.0000000) *
     3) Sex=Female 470 126 Yes (0.2680851 0.7319149)  
       6) Class=3rd 196  90 No (0.5408163 0.4591837) *
       7) Class=1st,2nd,Crew 274  20 Yes (0.0729927 0.9270073) *

この結果をもっと視覚的に分かりやすいグラフとして表示してみる。まずは、標準のplot関数を用いてみる。


> par(xpd = NA)
> plot(ct, branch = 0.8, margin = 0.05)
> text(ct, use.n = TRUE, all = TRUE)

decision-tree-classification-tree-rpart

次に、「rpart.plot」パッケージのrpart.plot関数を用いてみる。


> rpart.plot(ct, type = 1, uniform = TRUE, extra = 1, under = 1, faclen = 0)

decision-tree-classification-tree-rpart.plot

最後に、「partykit」パッケージのas.party関数用いてデータを変換したものをplot関数に用いてみる。


> plot(as.party(ct))

decision-tree-classification-tree-rpart-party

これらのグラフはそれぞれ見栄えが異なるので、気に入ったものを使えばよいと思うが、「partkit」パッケージを用いたものが、比較的誰にでもわかりやすいように感じる。

さて、このグラフからまず分かることは、生死を決定づけた主な要因は性別(Sex)であることが分かる。大人であれ子供であれ、良い部屋に泊まっていようとなかろうと、女性(Female)の乗客は生存率が高い。

また、男性(Male)であっても、子供(Child)で良い部屋に泊まっていた乗客の生存率は高い。

つまり、この決定木からはタイタニック号が今まさに沈没しようとしているとき、真っ先に女性や子供を優先的に避難させようとしたことが読み取れるのである。

このように、決定木を用いると、視覚的に様々なものが読み取れるため非常に便利であるが、データによっては、木構造が深く複雑になる場合がある。そのようなときに、あまり重要でない分類ルールを失くしてシンプルにする必要がある。このような方法は剪定と呼ばれる。

剪定とは、構築された木が深くなるほど、きちんと分類できているといえるが、過学習の可能性もある。そこで、あらかじめ定めたパラメータによって複雑さと制御する方法である。

どのように剪定を行うのが良いかを判断するためには「rpart」パッケージのprintcp関数とグラフで表示できるplotcp関数を用いる。printcp関数は、分岐の数と複雑度を対応させて、plotcp関数は木のサイズと対応させている。どちらを用いてもよいが、基本的には、errorが収束し始めているところを剪定の基準にする場合が多い。


> printcp(ct)

Classification tree:
rpart(formula = Survived ~ Class + Sex + Age, data = df, method = "class")

Variables actually used in tree construction:
[1] Age   Class Sex  

Root node error: 711/2201 = 0.32303

n= 2201 

        CP nsplit rel error  xerror     xstd
1 0.306610      0   1.00000 1.00000 0.030857
2 0.022504      1   0.69339 0.69339 0.027510
3 0.011252      2   0.67089 0.69058 0.027470
4 0.010000      4   0.64838 0.66245 0.027062

> plotcp(ct)

decision-tree-classification-tree-rpart-cp

printcp関数の結果から剪定の基準をcp=0.022504として、再度決定木を行うと以下のようになる。


> ct2 <- rpart(Survived ~ Class + Sex + Age, data = df, method = "class", cp = 0.022504)
> print(ct2)
n= 2201 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

  1) root 2201 711 No (0.6769650 0.3230350)  
    2) Sex=Male 1731 367 No (0.7879838 0.2120162) *
    3) Sex=Female 470 126 Yes (0.2680851 0.7319149) *

> plot(as.party(ct2))

decision-tree-classification-tree-rpart-cp-prune

predict関数を用いると、予測が可能となる。ここでは、簡単のため、Titanicデータを二分割したものを用いる。
以下を見ると、2001番目のデータは生存した(Yes)確率が高いということが分かる。


> train <- df[1 : 2000,]
> test <- df[2001 : 2201,]
> ctp <- rpart(Survived ~ Class + Sex + Age, data = train, method = "class")
> p <- predict(ctp, newdata = test)
> head(p)
             No       Yes
2001 0.03333333 0.9666667
2002 0.03333333 0.9666667
2003 0.03333333 0.9666667
2004 0.03333333 0.9666667
2005 0.03333333 0.9666667
2006 0.03333333 0.9666667
決定木 – 分類木

決定木 – 分類木」への3件のフィードバック

コメントは受け付けていません。