How Could Classification Trees Be So Fast on Categorical Variables?

A tutorial on why classification trees can split categorical variables so fast, with the computations done step by step in R.

I think that over the past months I have been saying incorrect things about classification with categorical covariates, because I never took the time to look at it carefully. Consider a simulated dataset, generated from a logistic regression:

> n=1e3
> set.seed(1)
> X1=runif(n)
> q=quantile(X1,(0:26)/26)
> q[1]=0
> X2=cut(X1,q,labels=LETTERS[1:26])
> p=exp(-.1+qnorm(2*(abs(.5-X1))))/(1+exp(-.1+qnorm(2*(abs(.5-X1)))))
> Y=rbinom(n,size=1,p)
> df=data.frame(X1=X1,X2=X2,p=p,Y=Y)

Here, we simulate a continuous covariate, X1, but treat it as unobserved. Instead, we observe a categorical covariate, X2, with 26 categories, obtained by cutting X1 at its empirical quantiles. The (theoretical) relationship between the covariate and the probability is given below:

> vx1=seq(0,1,by=.001)
> vp=plogis(-.1+qnorm(2*(abs(.5-vx1))))  # = exp(.)/(1+exp(.)), but returns 1 instead of NaN at vx1=0 and vx1=1
> plot(vx1,vp,type="l")

We can also compute the empirical probability of Y = 1 for each modality (category) of X2.
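The figure for this step is not reproduced here. As a minimal sketch, the empirical probabilities can be computed with tapply() and drawn against the theoretical curve. The simulation is repeated so that the block runs on its own; placing each category at the centre of its quantile bin on the X1 scale is an assumption made only for the picture (X1 is uniform, so bin k sits roughly at (k − 0.5)/26).

```r
# Re-create the simulated data from above (same seed, same formulas)
n <- 1e3
set.seed(1)
X1 <- runif(n)
q <- quantile(X1, (0:26) / 26)
q[1] <- 0
X2 <- cut(X1, q, labels = LETTERS[1:26])
p <- exp(-.1 + qnorm(2 * (abs(.5 - X1)))) / (1 + exp(-.1 + qnorm(2 * (abs(.5 - X1)))))
Y <- rbinom(n, size = 1, p)

# Empirical P(Y = 1) within each of the 26 categories
emp_prob <- tapply(Y, X2, mean)
round(emp_prob, 2)

# Assumption: plot each category at the centre of its bin on the X1 scale
bin_centre <- ((1:26) - 0.5) / 26
plot(bin_centre, emp_prob, pch = 19, xlab = "X1", ylab = "P(Y = 1)")
vx1 <- seq(0, 1, by = .001)
lines(vx1, plogis(-.1 + qnorm(2 * abs(.5 - vx1))))   # theoretical curve
```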

If we run a classification tree, we get:

> library(rpart)
> tree=rpart(Y~X2,data=df)
> library(rpart.plot)
> prp(tree, type=2, extra=1)

More precisely, here is the printed output:

> tree
1) root 1000 249.90000 0.4900000  
  2) X2=F,G,H,I,J,K,L,M,N,O,P,Q,R 499 105.3 0.302
    4) X2=J,K,L,M,N,O,P,Q,R 346  65.12 0.25144  *
    5) X2=F,G,H,I 153  37.22876 0.4183007       *
  3) X2=A,B,C,D,E,S,T,U,V,W,X,Y,Z 501 109.61 0.67
    6) X2=B,C,D,E,S,T,U,V,W,X 385  90.38 0.623  *
    7) X2=A,Y,Z 116  14.50862 0.8534483         *
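A note on reading this output: since Y is stored as a numeric 0/1 vector, rpart fits a regression ("anova") tree, so the third column is the within-node sum of squares, n × ȳ × (1 − ȳ), and the last column ȳ is the proportion of ones. The Gini impurity of a binary node is 2ȳ(1 − ȳ), so the deviance is n/2 times the node's Gini impurity and both criteria pick the same splits. The sketch below recomputes three of the printed deviances; node_deviance() is a hypothetical helper, and the counts of ones (490, 64, 99) are inferred from n × yval in the output.

```r
# Hypothetical helper: deviance of an rpart "anova" node for a 0/1 response,
# i.e. the within-node sum of squares n * ybar * (1 - ybar)
node_deviance <- function(n, n_ones) {
  ybar <- n_ones / n
  n * ybar * (1 - ybar)
}
node_deviance(1000, 490)   # root:   249.9
node_deviance(153, 64)     # node 5: 37.22876
node_deviance(116, 99)     # node 7: 14.50862

# Same value as the explicit sum of squares for node 5
y5 <- c(rep(1, 64), rep(0, 153 - 64))
sum((y5 - mean(y5))^2)
```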

Note that it takes less than a second to get that output. So clearly, rpart did not try all possible combinations of modalities. For the first node alone, there are 2^26 possible groups, i.e.

> 2^26
[1] 67108864
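To put these numbers side by side: a subset and its complement define the same split, so the number of distinct non-trivial two-group splits of K categories is 2^(K−1) − 1, against only K − 1 cut points once the categories are sorted. A quick check, with n_splits() as a hypothetical helper:

```r
# Hypothetical helper: how many candidate splits for K categories
n_splits <- function(K) {
  c(subsets  = 2^K,            # all subsets of the K categories
    distinct = 2^(K - 1) - 1,  # distinct non-trivial two-group splits
    ordered  = K - 1)          # cut points along the ordering by mean
}
n_splits(26)
n_splits(3)   # {A | B,C}, {B | A,C}, {C | A,B}: 3 splits, only 2 ordered cuts
```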

It is big… not huge, but too big to try all combinations, especially since that is only the first node, and we have to do it again on the two leaves, etc. Antoine (aka @ly_antoine) told me, while we were having a coffee after lunch today, the trick behind a fast algorithm on categorical variables. And as usual, the idea is very clever…

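First, we need a function to compute the Gini index (more precisely, minus the weighted Gini impurity of a split, so that larger values are better):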

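> gini=function(y,classe){
+    tab=table(y,classe)              # rows: values of y (assumed 0/1), columns: groups
+    nx=apply(tab,2,sum)              # group sizes n_k
+    n=sum(tab)                       # total sample size
+    pxy=tab/matrix(rep(nx,each=2),nrow=2)     # P(y | group k)
+    omega=matrix(rep(nx,each=2),nrow=2)/n     # group weights n_k / n
+    g=-sum(omega*pxy*(1-pxy))        # minus the weighted Gini impurity
+    return(g)}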
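A quick sanity check of gini() on toy vectors (values chosen for illustration): a perfectly pure split gives 0, the best possible value, while two groups that are each half 0s and half 1s give −0.5. For a binary response, −gini() also equals 2/n times the within-group sum of squares, the quantity rpart reports as deviance for a 0/1 numeric response.

```r
gini <- function(y, classe) {
  tab <- table(y, classe)                           # rows: y = 0/1, columns: groups
  nx <- apply(tab, 2, sum)                          # group sizes n_k
  n <- sum(tab)                                     # total sample size
  pxy <- tab / matrix(rep(nx, each = 2), nrow = 2)  # P(y | group k)
  omega <- matrix(rep(nx, each = 2), nrow = 2) / n  # weights n_k / n
  -sum(omega * pxy * (1 - pxy))                     # minus weighted Gini impurity
}

grp <- c(TRUE, TRUE, FALSE, FALSE)
gini(c(0, 0, 1, 1), grp)   # pure groups: 0
gini(c(0, 1, 0, 1), grp)   # each group half 0s, half 1s: -0.5

# Link with the within-group sum of squares
y <- c(0, 0, 0, 1, 1, 0)
cl <- c(TRUE, TRUE, TRUE, TRUE, FALSE, FALSE)
ss <- sum(tapply(y, cl, function(z) sum((z - mean(z))^2)))
c(gini = gini(y, cl), minus_2ss_over_n = -2 * ss / length(y))
```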

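For the first node, the idea is very simple. With a binary response, an optimal split under the Gini criterion can always be found among the splits that respect the ordering of the categories by their average response: the categories with the lowest averages go on one side, those with the highest on the other (a classical result, see Breiman et al., Classification and Regression Trees, 1984). So instead of trying all subsets, we only need one candidate split per position in that ordering: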

Compute the empirical average of Y within each category:

> cond_prob=aggregate(df$Y,by=list(df$X2),mean)

Then sort the categories by increasing empirical average:

> Group_Letters=cond_prob[order(cond_prob$x),1]  # column 1 holds the categories, column 2 (x) the means

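Then consider only 26 possible partitions, where the first group contains the v categories with the lowest averages, for v = 1 to 26 (v = 26 keeps all categories together, i.e. no split):
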
> v_gini=rep(NA,26)
> for(v in 1:26){
+   CLASSE=df$X2 %in% Group_Letters[1:v]
+   v_gini[v]=gini(y=df$Y,classe=CLASSE)
+ }

If we plot them, we get:

> plot(1:26,v_gini,type="b")

As for continuous variables, we look for the value of v that maximizes this criterion, and then we have our two groups:

> sort(Group_Letters[1:which.max(v_gini)])
 [1] F G H I J K L M N O P Q R
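The steps above can be wrapped into a small function. This is a sketch under stated assumptions, not rpart's implementation: ordered_split() is a hypothetical helper, it minimizes the weighted Gini impurity (equivalent to maximizing gini() above), and it skips v = K, which would not split anything. On a toy example where the category averages are 0, 0.5, 1 and 1, it puts the two low-average categories on one side.

```r
# Hypothetical helper (a sketch, not rpart's code): best two-group split of
# a categorical predictor x for a 0/1 response y, via the ordering trick
ordered_split <- function(y, x) {
  x <- droplevels(as.factor(x))
  means <- tapply(y, x, mean)                 # empirical P(Y = 1) per category
  ord <- names(means)[order(means)]           # categories by increasing mean
  K <- length(ord)
  # weighted Gini impurity of the split (left, !left)
  wgini <- function(left) {
    g <- function(z) { p <- mean(z); 2 * p * (1 - p) }
    (sum(left) * g(y[left]) + sum(!left) * g(y[!left])) / length(y)
  }
  # only K - 1 candidates: the v lowest-mean categories vs the rest
  imp <- sapply(1:(K - 1), function(v) wgini(x %in% ord[1:v]))
  sort(ord[1:which.min(imp)])                 # left group, alphabetically
}

x <- factor(c("a", "a", "b", "b", "c", "c", "d", "d"))
y <- c(0, 0, 0, 1, 1, 1, 1, 1)
ordered_split(y, x)   # "a" "b"
```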

That’s exactly the first split found by rpart:

1) root 1000 249.90000 0.4900000  
  2) X2=F,G,H,I,J,K,L,M,N,O,P,Q,R 499 105.30 0.30

Now, consider the leaf on the left (for instance):

> sub_df=df[df$X2 %in% sort(Group_Letters[1:which.max(v_gini)]),]

Then use the same algorithm as before: sort the conditional means:

> cond_prob=aggregate(sub_df$Y,by=
+ list(sub_df$X2),mean)
> s_Group_Letters=cond_prob[order(cond_prob$x),1]
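One detail worth knowing at this step: after subsetting, sub_df$X2 still carries all 26 factor levels, 13 of them now empty. aggregate() only returns the groups that actually occur, so the code above works as is, but tapply() would return NA for the empty levels; droplevels() avoids that. A toy illustration with made-up data:

```r
# Toy factor with an unused level, as after subsetting df (made-up data)
x <- factor(c("a", "a", "b"), levels = c("a", "b", "c"))
y <- c(0, 1, 1)
tapply(y, x, mean)                 # a = 0.5, b = 1, c = NA (empty level)
tapply(y, droplevels(x), mean)     # a = 0.5, b = 1
aggregate(y, by = list(x), mean)   # only the groups that occur: a and b
```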

Then compute Gini indices based on groups obtained from that ordering:

> v_gini=rep(NA,length(s_Group_Letters))
> for(v in 1:length(s_Group_Letters)){
+   CLASSE=sub_df$X2 %in% s_Group_Letters[1:v]
+   v_gini[v]=gini(y=sub_df$Y,classe=CLASSE)
+ }

If we plot it, we get our two groups:

> plot(1:length(s_Group_Letters),v_gini,type="b")

And the first group is here:

> sort(s_Group_Letters[1:which.max(v_gini)])
[1] J K L M N O P Q R

Again, that’s exactly what we got with the R function:

1) root 1000 249.90000 0.4900000  
  2) X2=F,G,H,I,J,K,L,M,N,O,P,Q,R 499 105.30 0.30
    4) X2=J,K,L,M,N,O,P,Q,R 346  65.12 0.25144  *

Clever, isn’t it?
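One way to convince oneself that the shortcut loses nothing is to compare it with an exhaustive search on a problem small enough to enumerate. The sketch below uses hypothetical helpers wgini(), best_brute() and best_ordered(), and simulated data with 8 categories: it computes the lowest weighted Gini impurity over all 2^7 − 1 = 127 distinct splits and over the 7 ordered splits. For a binary response the two values should agree, which is the classical CART result behind the trick.

```r
# Weighted Gini impurity of the split (left, !left) for a 0/1 response y
wgini <- function(y, left) {
  g <- function(z) { p <- mean(z); 2 * p * (1 - p) }
  (sum(left) * g(y[left]) + sum(!left) * g(y[!left])) / length(y)
}

# Exhaustive search: every non-empty subset of the first K - 1 levels as the
# left group (the last level always goes right), i.e. each split counted once
best_brute <- function(y, x) {
  lev <- levels(x); K <- length(lev)
  best <- Inf
  for (m in 1:(2^(K - 1) - 1)) {
    bits <- (m %/% 2^(0:(K - 2))) %% 2 == 1
    best <- min(best, wgini(y, x %in% lev[1:(K - 1)][bits]))
  }
  best
}

# Shortcut: only the K - 1 splits along the ordering by mean response
best_ordered <- function(y, x) {
  means <- tapply(y, x, mean)
  ord <- names(means)[order(means)]
  K <- length(ord)
  min(sapply(1:(K - 1), function(v) wgini(y, x %in% ord[1:v])))
}

set.seed(123)
K <- 8
x <- factor(rep(LETTERS[1:K], length.out = 400))   # every category present
y <- rbinom(400, 1, runif(K)[as.integer(x)])       # one P(Y = 1) per category
c(brute = best_brute(y, x), ordered = best_ordered(y, x))
```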


Published at DZone with permission of Arthur Charpentier, DZone MVB. See the original article here.
