ここでは、決定木の目的変数が連続値である場合の回帰木について、R言語の「rpart」パッケージを用いて簡単に見ていく。

まずは必要となるパッケージのインストールとロードを行う。「rpart」パッケージは決定木を行うためのものだが、「rpart.plot」と「partykit」パッケージは結果を視覚的に表示するために使うので、あらかじめインストールとロードをしておく。


> install.packages("rpart")
> install.packages("rpart.plot")
> install.packages("partykit")
> library(rpart)
> library(rpart.plot)
> library(partykit)

サンプルデータとして、「rpart」パッケージに含まれる「cu.summary」を使わせていただいた。


> head(cu.summary)
                Price Country Reliability Mileage  Type
Acura Integra 4 11950   Japan Much better      NA Small
Dodge Colt 4     6851   Japan              NA Small
Dodge Omni 4     6995     USA  Much worse      NA Small
Eagle Summit 4   8895     USA      better      33 Small
Ford Escort   4  7402     USA       worse      33 Small
Ford Festiva 4   6319   Korea      better      37 Small

決定木を実行するには、rpart関数を用いる。目的変数が連続値であれば、自動的に回帰木として扱われる。同様に、目的変数がカテゴリ値ならば分類木として扱われる。


> rt <- rpart(Price ~ Mileage + Type + Country, data = cu.summary)
> print(rt)
n= 117 

node), split, n, deviance, yval
      * denotes terminal node

   1) root 117 7407473000 15743.460  
     2) Type=Compact,Small,Sporty,Van 80 3322389000 13035.010  
       4) Country=Brazil,France,Japan,Japan/USA,Korea,Mexico,USA 69 1426421000 11555.160  
         8) Type=Small 21   50309830  7629.048 *
         9) Type=Compact,Sporty,Van 48  910790000 13272.830  
           18) Country=Japan/USA,Mexico,USA 29  482343500 12241.550 *
           19) Country=France,Japan 19  350528000 14846.890 *
       5) Country=Germany,Sweden 11  797004200 22317.730 *
     3) Type=Large,Medium 37 2229351000 21599.570  
       6) Country=France,Korea,USA 25 1021102000 18697.280  
         12) Type=Medium 18  741101600 17607.440 *
	     13) Type=Large 7  203645100 21499.710 *
       7) Country=England,Germany,Japan,Sweden 12  558955000 27646.000 *

この結果をもっと視覚的に分かりやすいグラフとして表示してみる。まずは、標準のplot関数を用いてみる。


> par(xpd = NA)
> plot(rt, branch = 0.8, margin = 0.05)
> text(rt, use.n = TRUE, all = TRUE)

decision-tree-regression-tree-rpart

次に、「rpart.plot」パッケージのrpart.plot関数を用いてみる。


> rpart.plot(rt, type = 1, uniform = TRUE, extra = 1, under = 1, faclen = 0)

decision-tree-regression-tree-rpart.plot

最後に、「partykit」パッケージのas.party関数用いてデータを変換したものをplot関数に用いてみる。


> plot(as.party(rt))

decision-tree-regression-tree-rpart-party

これらのグラフはそれぞれ見栄えが異なるので、気に入ったものを使えばよいと思うが、「partkit」パッケージを用いたものが、比較的誰にでもわかりやすいように感じる。

次に剪定を考えるため、printcp関数と、plotcp関数を実行してみる。


> printcp(rt)

Regression tree:
rpart(formula = Price ~ Mileage + Type + Country, data = cu.summary)

Variables actually used in tree construction:
[1] Country Type   

Root node error: 7407472615/117 = 63311732

n= 117 

        CP nsplit rel error  xerror    xstd
1 0.250522      0   1.00000 1.01365 0.15804
2 0.148359      1   0.74948 0.90282 0.16685
3 0.087654      2   0.60112 0.79992 0.15733
4 0.062818      3   0.51347 0.65730 0.11368
5 0.010519      4   0.45065 0.55595 0.10363
6 0.010308      5   0.44013 0.57370 0.10665
7 0.010000      6   0.42982 0.57370 0.10665

> plotcp(rt)

decision-tree-regression-tree-rpart-cp

plotcp関数の結果から剪定の基準をcp=0.026として、再度決定木を行うと以下のようになる。


> rt2 <- rpart(Price ~ Mileage + Type + Country, data = cu.summary, cp = 0.026)
> print(rt2)
n= 117 

node), split, n, deviance, yval
      * denotes terminal node

  1) root 117 7407473000 15743.460  
    2) Type=Compact,Small,Sporty,Van 80 3322389000 13035.010  
      4) Country=Brazil,France,Japan,Japan/USA,Korea,Mexico,USA 69 1426421000 11555.160  
        8) Type=Small 21   50309830  7629.048 *
        9) Type=Compact,Sporty,Van 48  910790000 13272.830 *
      5) Country=Germany,Sweden 11  797004200 22317.730 *
    3) Type=Large,Medium 37 2229351000 21599.570  
      6) Country=France,Korea,USA 25 1021102000 18697.280 *
      7) Country=England,Germany,Japan,Sweden 12  558955000 27646.000 *

decision-tree-regression-tree-rpart-cp-prune

predict関数を用いると、予測が可能となる。ここでは、簡単のため、cu.summaryデータを二分割したものを用いる。


> train <- cu.summary[1:100,]
> test <- cu.summary[101:112,]
> rtp <- rpart(Price ~ Mileage + Type + Country, data = train)
> p <- predict(rtp, newdata = test)
> print(p)
          Buick Electra V6          Buick Le Sabre V6 
                  11981.79                   11981.79 
      Cadillac Brougham V8       Cadillac De Ville V8 
                  11981.79                   11981.79 
      Chevrolet Caprice V8 Ford LTD Crown Victoria V8 
                  11981.79                   11981.79 
       Lincoln Town Car V8         Chevrolet Astro V6 
                  11981.79                   11981.79 
   Chevrolet Lumina APV V6            Dodge Caravan 4 
                  11981.79                   11981.79 
    Dodge Grand Caravan V6           Ford Aerostar V6 
                  11981.79                   11981.79 

関連する記事

  • 統計的因果推論による傾向スコアとIPW推定量の基本的な考え方統計的因果推論による傾向スコアとIPW推定量の基本的な考え方 [latexpage] 統計的因果推論による因果効果を調べる手段として、傾向スコアとIPW推定量という概念があります。ここでは、なぜ傾向スコアを考えるのか、傾向スコアの逆数の重み付けはどのような意味があるのかを、複雑な数式を用いずに具体例を通してご説明します。 さっそくですが、次の具体例を考えます。 […]
  • R qgraphを用いてデータをネットワークとして可視化するR qgraphを用いてデータをネットワークとして可視化する qgraphは、ネットワークとしてデータを視覚化するために使用することができ、加重グラフィカルモデルを視覚化するためのインタフェースを提供しているパッケージです。 リファレンスマニュアルには、関数のサンプルコードのみで出力されたグラフがありません。そこで、qgraphのサンプルコードと合わせてグラフを並べてみました。 qgraph library(qgraph) […]
  • R スミルノフ・グラブス検定を繰り返し用いて外れ値を除去する方法 スミルノフ・グラブス検定は、正規分布を仮定した標本において、最大値または最小値が外れ値かどうか判定する検定の一つです。 外れ値を除去する際、外れ値を一つずつ検証することよりも、外れ値がすべて除去されたデータだけがほしいときもあると思います。 ここでは、正規分布を仮定したデータからスミルノフ・グラブス検定を繰り返し用いて外れ値を除去するソースコードをご紹介します。 こ […]
  • R knitrできれいな多重クロス集計をPDFで出力する方法R knitrできれいな多重クロス集計をPDFで出力する方法 knitrパッケージのkable関数を使えば、matrixやdata.frameなどの表形式をきれいに出力してくれるが、ftable関数を用いた多重クロス集計の結果は、kable関数を使うことができない。 これは非常に残念なので、他の方法できれいに出力する方法をお伝えする。ちなみにこの方法ではPDF出力のみの対応となるので注意してほしい。 手順を簡単に説明すると、 […]
  • R言語 CRAN Task View:関数データ解析R言語 CRAN Task View:関数データ解析 CRAN Task View: Functional Data Analysisの英語での説明文をGoogle翻訳を使用させていただき機械的に翻訳したものを掲載しました。 Maintainer: Fabian Scheipl Contact: fabian.scheipl at […]
決定木 – 回帰木