決定木とは、分類ルールを木構造で表したものである。分類したいデータを目的変数(従属変数)、分類するために用いるデータを説明変数(独立変数)という。目的変数がカテゴリデータなどの場合は「分類木」、連続値などの量的データの場合は「回帰木」と呼ばれる。

決定木の最大のメリットは、結果にグラフを用いることができるため、視覚的に確認できることである。

ここでは、R言語の「rpart」パッケージを用いて決定木について見ていこう。サンプルデータとして、Rに標準で含まれている「Titanic」を使わせていただいた。このサンプルデータはタイタニック号の乗客の属性情報と生死の情報が含まれている。生死を分けた要因を属性情報から分類するとどのようになるのかを見ていく。

まずは必要となるパッケージのインストールとロードを行う。「rpart」パッケージは決定木を行うためのものだが、「rpart.plot」と「partykit」パッケージは結果を視覚的に表示するために使うので、あらかじめインストールとロードをしておく。


> install.packages("rpart")
> install.packages("rpart.plot")
> install.packages("partykit")
> library(rpart)
> library(rpart.plot)
> library(partykit)

次に、サンプルデータを扱いやすい形に変更しておく。


> tmp <- data.frame(Titanic)
> df <- data.frame(
      Class = rep(tmp$Class, tmp$Freq),
      Sex = rep(tmp$Sex, tmp$Freq),
      Age = rep(tmp$Age, tmp$Freq),
      Survived = rep(tmp$Survived, tmp$Freq)
  )
> head(df)
  Class  Sex   Age Survived
1   3rd Male Child       No
2   3rd Male Child       No
3   3rd Male Child       No
4   3rd Male Child       No
5   3rd Male Child       No
6   3rd Male Child       No

決定木を実行するにはrpart関数を用いる。下の意味は、Survivedを目的変数、ClassとSexとAgeを説明変数として分類木を用いて、結果をctに格納している。そして結果をprint関数で表示している。


> ct <- rpart(Survived ~ Class + Sex + Age, data = df, method = "class")
> print(ct)
n= 2201 

node), split, n, loss, yval, (yprob)
  * denotes terminal node

   1) root 2201 711 No (0.6769650 0.3230350)  
     2) Sex=Male 1731 367 No (0.7879838 0.2120162)  
       4) Age=Adult 1667 338 No (0.7972406 0.2027594) *
       5) Age=Child 64  29 No (0.5468750 0.4531250)  
        10) Class=3rd 48  13 No (0.7291667 0.2708333) *
        11) Class=1st,2nd 16   0 Yes (0.0000000 1.0000000) *
     3) Sex=Female 470 126 Yes (0.2680851 0.7319149)  
       6) Class=3rd 196  90 No (0.5408163 0.4591837) *
       7) Class=1st,2nd,Crew 274  20 Yes (0.0729927 0.9270073) *

この結果をもっと視覚的に分かりやすいグラフとして表示してみる。まずは、標準のplot関数を用いてみる。


> par(xpd = NA)
> plot(ct, branch = 0.8, margin = 0.05)
> text(ct, use.n = TRUE, all = TRUE)

decision-tree-classification-tree-rpart

次に、「rpart.plot」パッケージのrpart.plot関数を用いてみる。


> rpart.plot(ct, type = 1, uniform = TRUE, extra = 1, under = 1, faclen = 0)

decision-tree-classification-tree-rpart.plot

最後に、「partykit」パッケージのas.party関数用いてデータを変換したものをplot関数に用いてみる。


> plot(as.party(ct))

decision-tree-classification-tree-rpart-party

これらのグラフはそれぞれ見栄えが異なるので、気に入ったものを使えばよいと思うが、「partkit」パッケージを用いたものが、比較的誰にでもわかりやすいように感じる。

さて、このグラフからまず分かることは、生死を決定づけた主な要因は性別(Sex)であることが分かる。大人であれ子供であれ、良い部屋に泊まっていようとなかろうと、女性(Female)の乗客は生存率が高い。

また、男性(Male)であっても、子供(Child)で良い部屋に泊まっていた乗客の生存率は高い。

つまり、この決定木からはタイタニック号が今まさに沈没しようとしているとき、真っ先に女性や子供を優先的に避難させようとしたことが読み取れるのである。

このように、決定木を用いると、視覚的に様々なものが読み取れるため非常に便利であるが、データによっては、木構造が深く複雑になる場合がある。そのようなときに、あまり重要でない分類ルールを失くしてシンプルにする必要がある。このような方法は剪定と呼ばれる。

剪定とは、構築された木が深くなるほど、きちんと分類できているといえるが、過学習の可能性もある。そこで、あらかじめ定めたパラメータによって複雑さと制御する方法である。

どのように剪定を行うのが良いかを判断するためには「rpart」パッケージのprintcp関数とグラフで表示できるplotcp関数を用いる。printcp関数は、分岐の数と複雑度を対応させて、plotcp関数は木のサイズと対応させている。どちらを用いてもよいが、基本的には、errorが収束し始めているところを剪定の基準にする場合が多い。


> printcp(ct)

Classification tree:
rpart(formula = Survived ~ Class + Sex + Age, data = df, method = "class")

Variables actually used in tree construction:
[1] Age   Class Sex  

Root node error: 711/2201 = 0.32303

n= 2201 

        CP nsplit rel error  xerror     xstd
1 0.306610      0   1.00000 1.00000 0.030857
2 0.022504      1   0.69339 0.69339 0.027510
3 0.011252      2   0.67089 0.69058 0.027470
4 0.010000      4   0.64838 0.66245 0.027062

> plotcp(ct)

decision-tree-classification-tree-rpart-cp

printcp関数の結果から剪定の基準をcp=0.022504として、再度決定木を行うと以下のようになる。


> ct2 <- rpart(Survived ~ Class + Sex + Age, data = df, method = "class", cp = 0.022504)
> print(ct2)
n= 2201 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

  1) root 2201 711 No (0.6769650 0.3230350)  
    2) Sex=Male 1731 367 No (0.7879838 0.2120162) *
    3) Sex=Female 470 126 Yes (0.2680851 0.7319149) *

> plot(as.party(ct2))

decision-tree-classification-tree-rpart-cp-prune

predict関数を用いると、予測が可能となる。ここでは、簡単のため、Titanicデータを二分割したものを用いる。
以下を見ると、2001番目のデータは生存した(Yes)確率が高いということが分かる。


> train <- df[1 : 2000,]
> test <- df[2001 : 2201,]
> ctp <- rpart(Survived ~ Class + Sex + Age, data = train, method = "class")
> p <- predict(ctp, newdata = test)
> head(p)
             No       Yes
2001 0.03333333 0.9666667
2002 0.03333333 0.9666667
2003 0.03333333 0.9666667
2004 0.03333333 0.9666667
2005 0.03333333 0.9666667
2006 0.03333333 0.9666667

関連する記事

  • Googleアナリティクスとコレスポンデンス分析を用いた年齢別のユーザー像の捉え方Googleアナリティクスとコレスポンデンス分析を用いた年齢別のユーザー像の捉え方 ページビュー数やコンバージョン率を上げるためには、良質なコンテンツが大切であるとよく言われる。そして、良質なコンテンツを作成するためには、ユーザー像を具体的に思い描き、そのユーザーに向けてコンテンツを作成しなくてはならない。 ここでは、ページビュー数から年齢とページの関係性を視覚的に確認し、年齢別にユーザーがどのコンテンツに興味を抱くか、その傾向を探っていく。この傾向が […]
  • Ubuntu,R h2oパッケージのインストールの方法Ubuntu,R h2oパッケージのインストールの方法 Rのパッケージh2oは、さまざまなクラスタ環境内のニューラルネットワーク(ディープラーニング)、ランダムフォレスト、勾配ブースティングマシン、一般化線形モデルなどの並列分散機械学習アルゴリズムを計算するビッグデータのためのオープンソースの数学エンジンH2O用のRスクリプト機能である。 ここでは、ubuntu14.04環境下でh2oパッケージのインストールの仕方についてお […]
  • R 文字列ベクトルで文字列を指定して要素を削除する方法R 文字列ベクトルで文字列を指定して要素を削除する方法 Rの文字列ベクトルで、文字列を指定して要素を削除する方法をお伝えする。 通常、ベクトルの要素を削除する場合は、次のように添字にマイナスを付加して削除する。 > s # 1番目の要素を削除 > s[-1] [1] "猫である。" "名前は" "まだ無い。" > # 1番目から2番目の要素を削除 > s[-1:-2] [1] "名前は" […]
  • これだけは抑えておきたい成長性分析の基本これだけは抑えておきたい成長性分析の基本 成長性分析とは、様々な観点から成長性・拡大性・発展性を測定する分析である。 規模拡大などの経営戦略がいつも経営者の思い描く通りに進むとは限らないため、様々な観点から自社および自社を取り巻く環境の状況を把握する必要がある。 ここでは、成長性分析の代表的な指標をいくつか紹介する。 売上高伸び率 売上高伸び率とは、前期売上高より当期売上高がどの程度上昇または下降したかを […]
  • R言語 CRAN Task View:ケモメトリックスと計算物理学R言語 CRAN Task View:ケモメトリックスと計算物理学 CRAN Task View: Chemometrics and Computational Physicsの英語での説明文をGoogle翻訳を使用させていただき機械的に翻訳したものを掲載しました。 Maintainer: Katharine Mullen Contact: katharine.mullen at […]
決定木 – 分類木

決定木 – 分類木」への3件のフィードバック

コメントは受け付けていません。