项目作者: cuteboydot

项目描述 :
implementation of DecisionTree(CART)
高级语言: Java
项目地址: git://github.com/cuteboydot/DecisionTree.git
创建时间: 2017-05-14T13:06:56Z
项目社区:https://github.com/cuteboydot/DecisionTree

开源协议:

下载


DecisionTree

implementation of DecisionTree(CART)

cuteboydot@gmail.com

reference : http://edoc.hu-berlin.de/master/timofeev-roman-2004-12-20/PDF/timofeev.pdf

  • entropy, gini impurity and information gain

i = class0,1,.., Pi = cnt(node_class_i_data)/cnt(node_all_data)
entropy impurity = -( Sum(Pi * log2(Pi)) )

gini impurity = 1 - Sum(Pi * Pi)

information gain(IG) =
impurity(node) -
( cnt(node_left_data)/cnt(node_all_data) * impurity(node_left) +
cnt(node_right_data)/cnt(node_all_data) * impurity(node_right) )

The split chosen at each node is the one maximizing IG.

  • example : data set




  • test result








  • code
    ```java
    import java.util.*;

/**
 * Created by cuteboydot
 */

/**
 * One data record handled by the decision tree: an identifier, a feature
 * vector and a class label. All fields are package-private and filled in
 * directly by the caller.
 */
class DTData {
    /** Class label; the original author annotates it as 0 = male, 1 = female. */
    int label;

    /** Feature values; remains {@code null} until the caller allocates the array. */
    double[] feat;

    /** Record identifier chosen by the caller. */
    int id;
}

/**
 * A binary split rule: a feature index paired with a threshold value.
 */
class DTSplit {
    /** Threshold the selected feature value is compared against. */
    double value;

    /** Index into a record's feature array. */
    int feat;
}

/**
 * A node of the binary decision tree. Fields are package-private and are
 * populated directly by the code that builds the tree.
 */
class DTNode {
    /** Depth of this node. */
    int depth;

    /** Depth limit compared against {@link #depth}. */
    int maxdepth;

    /** Position relative to the parent: 0 = root, 1 = left, 2 = right. */
    int direction;

    /** Human-readable path label for this node. */
    String path;

    /** Node count. */
    int cnt;

    /** Impurity value stored for this node. */
    double impurity;

    /** Split rule attached to this node, or {@code null} when none is set. */
    DTSplit splitrule = null;

    /** Left child, or {@code null} when absent. */
    DTNode childLeft = null;

    /** Right child, or {@code null} when absent. */
    DTNode childRight = null;

    /**
     * Records held by this node, or {@code null} until a list is attached.
     * Declared with its element type so that element access yields
     * {@link DTData} rather than {@code Object}.
     */
    ArrayList<DTData> data = null;
}

public class Main {
static final int SIZE_FEAT = 2;
static final int SIZE_LABEL = 2;
static final int SIZE_RECORD = 12;
static final int SIZE_TEST = 2;

  1. public static void main(String[] args) {
  2. System.out.println("hello world~!");
  3. DTData[] data = new DTData[SIZE_RECORD] ;
  4. DTData[] test = new DTData[SIZE_TEST] ;
  5. data[0] = new DTData();
  6. data[0].id = 0;
  7. data[0].feat = new double[SIZE_FEAT];
  8. data[0].feat[0] = 1;
  9. data[0].feat[1] = 4;
  10. data[0].label = 0;
  11. data[1] = new DTData();
  12. data[1].id = 1;
  13. data[1].feat = new double[SIZE_FEAT];
  14. data[1].feat[0] = 3;
  15. data[1].feat[1] = 3;
  16. data[1].label = 0;
  17. data[2] = new DTData();
  18. data[2].id = 2;
  19. data[2].feat = new double[SIZE_FEAT];
  20. data[2].feat[0] = 4;
  21. data[2].feat[1] = 5;
  22. data[2].label = 0;
  23. data[3] = new DTData();
  24. data[3].id = 3;
  25. data[3].feat = new double[SIZE_FEAT];
  26. data[3].feat[0] = 5;
  27. data[3].feat[1] = 6;
  28. data[3].label = 0;
  29. data[4] = new DTData();
  30. data[4].id = 4;
  31. data[4].feat = new double[SIZE_FEAT];
  32. data[4].feat[0] = 6;
  33. data[4].feat[1] = 2;
  34. data[4].label = 0;
  35. data[5] = new DTData();
  36. data[5].id = 5;
  37. data[5].feat = new double[SIZE_FEAT];
  38. data[5].feat[0] = 3;
  39. data[5].feat[1] = 4;
  40. data[5].label = 1;
  41. data[6] = new DTData();
  42. data[6].id = 6;
  43. data[6].feat = new double[SIZE_FEAT];
  44. data[6].feat[0] = 4;
  45. data[6].feat[1] = 1;
  46. data[6].label = 1;
  47. data[7] = new DTData();
  48. data[7].id = 7;
  49. data[7].feat = new double[SIZE_FEAT];
  50. data[7].feat[0] = 5;
  51. data[7].feat[1] = 3;
  52. data[7].label = 1;
  53. data[8] = new DTData();
  54. data[8].id = 8;
  55. data[8].feat = new double[SIZE_FEAT];
  56. data[8].feat[0] = 7;
  57. data[8].feat[1] = 5;
  58. data[8].label = 1;
  59. data[9] = new DTData();
  60. data[9].id = 9;
  61. data[9].feat = new double[SIZE_FEAT];
  62. data[9].feat[0] = 8;
  63. data[9].feat[1] = 3;
  64. data[9].label = 1;
  65. data[10] = new DTData();
  66. data[10].id = 10;
  67. data[10].feat = new double[SIZE_FEAT];
  68. data[10].feat[0] = 2;
  69. data[10].feat[1] = 1;
  70. data[10].label = 0;
  71. data[11] = new DTData();
  72. data[11].id = 11;
  73. data[11].feat = new double[SIZE_FEAT];
  74. data[11].feat[0] = 6;
  75. data[11].feat[1] = 5;
  76. data[11].label = 1;
  77. test[0] = new DTData();
  78. test[0].id = 0;
  79. test[0].feat = new double[SIZE_FEAT];
  80. test[0].feat[0] = 2;
  81. test[0].feat[1] = 5;
  82. test[0].label = -1; // real answer = 0
  83. test[1] = new DTData();
  84. test[1].id = 1;
  85. test[1].feat = new double[SIZE_FEAT];
  86. test[1].feat[0] = 6;
  87. test[1].feat[1] = 4;
  88. test[1].label = -1; // real answer = 1
  89. DTNode trainTree = new DTNode();
  90. trainTree.data = new ArrayList<DTData>();
  91. for(int a=0; a<SIZE_RECORD; a++) {
  92. trainTree.data.add(data[a]);
  93. }
  94. trainTree.depth = 0;
  95. trainTree.maxdepth = 5;
  96. trainTree.direction = 0;
  97. trainTree.path = "ROOT";
  98. trainTree.cnt = trainTree.data.size();
  99. // train
  100. train(trainTree);
  101. System.out.println();
  102. // test
  103. for (int a=0; a<SIZE_TEST; a++) {
  104. test(trainTree, test[a]);
  105. System.out.println("TEST#" + test[a].id + " RESULT : " + test[a].label);
  106. }
  107. }
  108. static void train(DTNode tree) {
  109. expandTree(tree);
  110. }
  111. static void expandTree(DTNode tree) {
  112. tree.impurity = getImpurity(tree);
  113. String strDir = "";
  114. if(tree.direction == 0) strDir = "ROOT";
  115. else if(tree.direction == 1) strDir = "LEFT";
  116. else if(tree.direction == 2) strDir = "RIGHT";
  117. // terminal node
  118. if((tree.depth >= tree.maxdepth) || (tree.impurity == 0) || (tree.data.size() == 1)) {
  119. //System.out.println("DEPTH: " + tree.depth + ", DIRECTION: " + strDir + ", CNT: " + tree.cnt +
  120. System.out.println("PATH: " + tree.path + ", CNT: " + tree.cnt +
  121. "[TERMINAL], IMPURITY: " + String.format("%.3f",tree.impurity));
  122. return;
  123. }
  124. // make split criteria list
  125. ArrayList<Double>[] splitList = new ArrayList[SIZE_FEAT];
  126. for (int a=0; a<SIZE_FEAT; a++) {
  127. splitList[a] = new ArrayList<Double>();
  128. for (int b=0; b<tree.data.size(); b++) {
  129. splitList[a].add( tree.data.get(b).feat[a] );
  130. }
  131. // sort value list
  132. Collections.sort(splitList[a], new Comparator<Double>() {
  133. @Override
  134. public int compare(Double o1, Double o2) {
  135. return o1.compareTo(o2);
  136. }
  137. });
  138. // remove duplicated value
  139. for (int b=tree.data.size()-2; b>=0; b--) {
  140. Double pre = splitList[a].get(b);
  141. Double suc = splitList[a].get(b+1);
  142. if(Double.compare(pre, suc) == 0) {
  143. splitList[a].remove(b + 1);
  144. }
  145. }
  146. }
  147. for (int a=0; a<SIZE_FEAT; a++) {
  148. // split value = (data[pre] + data[suc]) / 2
  149. for (int b=0; b<splitList[a].size()-1; b++) {
  150. double val = (splitList[a].get(b) + splitList[a].get(b+1)) / 2;
  151. splitList[a].set(b, val);
  152. }
  153. splitList[a].remove(splitList[a].size()-1);
  154. }
  155. // allocate tmp sub tree by criteria rule
  156. // find maximum delta gain of sub trees
  157. int maxSplitFeat = 0;
  158. double maxSplitValue = 0;
  159. double maxSplitGain = 0;
  160. for (int a=0; a<SIZE_FEAT; a++) {
  161. for (int b=0; b<splitList[a].size(); b++) {
  162. double val = splitList[a].get(b);
  163. DTNode tmpLeftTree = new DTNode();
  164. DTNode tmpRightTree = new DTNode();
  165. for( int c=0; c<tree.data.size(); c++) {
  166. if(tree.data.get(c).feat[a] <= val) {
  167. if(tmpLeftTree.data == null) {
  168. tmpLeftTree.data = new ArrayList<DTData>();
  169. }
  170. tmpLeftTree.data.add(tree.data.get(c));
  171. } else {
  172. if(tmpRightTree.data == null) {
  173. tmpRightTree.data = new ArrayList<DTData>();
  174. }
  175. tmpRightTree.data.add(tree.data.get(c));
  176. }
  177. }
  178. double impLeft = getImpurity(tmpLeftTree);
  179. double impRight = getImpurity(tmpRightTree);
  180. double gain = tree.impurity -
  181. ( ((double)tmpLeftTree.data.size()/(double)tree.data.size()) * impLeft +
  182. ((double)tmpRightTree.data.size()/(double)tree.data.size()) * impRight );
  183. if(gain > maxSplitGain) {
  184. maxSplitFeat = a;
  185. maxSplitValue = val;
  186. maxSplitGain = gain;
  187. }
  188. }
  189. }
  190. tree.splitrule = new DTSplit();
  191. tree.splitrule.feat = maxSplitFeat;
  192. tree.splitrule.value = maxSplitValue;
  193. tree.childLeft = new DTNode();
  194. tree.childLeft.data = new ArrayList<DTData>();
  195. tree.childLeft.maxdepth = tree.maxdepth;
  196. tree.childLeft.depth = tree.depth + 1;
  197. tree.childLeft.direction = 1;
  198. tree.childLeft.path = tree.path + " -> " + tree.childLeft.depth + "L";
  199. tree.childRight = new DTNode();
  200. tree.childRight.data = new ArrayList<DTData>();
  201. tree.childRight.maxdepth = tree.maxdepth;
  202. tree.childRight.depth = tree.depth + 1;
  203. tree.childRight.direction = 2;
  204. tree.childRight.path = tree.path + " -> " + tree.childLeft.depth + "R";
  205. // make child tree
  206. for(int a=0; a<tree.data.size(); a++) {
  207. DTData data = tree.data.get(a);
  208. if(data.feat[tree.splitrule.feat] <= tree.splitrule.value) {
  209. tree.childLeft.data.add(data);
  210. } else {
  211. tree.childRight.data.add(data);
  212. }
  213. }
  214. tree.childLeft.cnt = tree.childLeft.data.size();
  215. tree.childRight.cnt = tree.childRight.data.size();
  216. //System.out.println("DEPTH: " + tree.depth + ", DIRECTION: " + strDir + ", CNT: " + tree.cnt +
  217. System.out.println("PATH: " + tree.path + ", CNT: " + tree.cnt +
  218. "[" + tree.childLeft.cnt + "," + tree.childRight.cnt + "]" +
  219. ", IMPURITY: " + String.format("%.3f",tree.impurity) + ", SPLIT FEAT: " + tree.splitrule.feat +
  220. ", SPLIT VAL: " + String.format("%.3f",tree.splitrule.value));
  221. // expand subtree recursively
  222. if(tree.childLeft.cnt >= 1)
  223. expandTree(tree.childLeft);
  224. if(tree.childRight.cnt >= 1)
  225. expandTree(tree.childRight);
  226. }
  227. static void test(DTNode trainTree, DTData test) {
  228. test.label = propagation(trainTree, test);
  229. }
  230. static int propagation(DTNode subTree, DTData test) {
  231. int label = 0;
  232. int cntClas[] = new int[SIZE_LABEL];
  233. int tmpCnt = 0;
  234. if(subTree.data.size() == 1) {
  235. label = subTree.data.get(0).label;
  236. return label;
  237. }
  238. for (int a=0; a<subTree.data.size(); a++) {
  239. DTData data = subTree.data.get(a);
  240. cntClas[data.label]++;
  241. }
  242. for (int a=0; a<SIZE_LABEL; a++) {
  243. tmpCnt = cntClas[a];
  244. if(tmpCnt > label)
  245. label = a;
  246. }
  247. // case of classified cluster or terminal node
  248. if(subTree.impurity == 0.0 || subTree.depth >= subTree.maxdepth) {
  249. return label;
  250. }
  251. if(test.feat[subTree.splitrule.feat] <= subTree.splitrule.value) {
  252. if(subTree.childLeft != null) {
  253. if(subTree.childLeft.data.size() > 1) {
  254. label = propagation(subTree.childLeft, test);
  255. }
  256. }
  257. } else {
  258. if(subTree.childRight != null) {
  259. if(subTree.childRight.data.size() > 1) {
  260. label = propagation(subTree.childRight, test);
  261. }
  262. }
  263. }
  264. cntClas = null;
  265. return label;
  266. }
  267. static double getImpurity(DTNode node) {
  268. double impurity;
  269. if(node.data.size() == 1) {
  270. impurity = 0;
  271. return impurity;
  272. }
  273. impurity = getEntropy(node);
  274. //impurity = getGini(node);
  275. return impurity;
  276. }
  277. static double getEntropy(DTNode node) {
  278. double impurity = 0;
  279. int cnt[] = new int [SIZE_LABEL];
  280. int totalcnt = node.data.size();
  281. for (int a=0; a<totalcnt; a++) {
  282. cnt[node.data.get(a).label]++;
  283. }
  284. for (int a=0; a<SIZE_LABEL; a++) {
  285. double prob = (double)cnt[a]/(double)totalcnt;
  286. double log2 = 0.0;
  287. if(prob == 0.0) {
  288. log2 = 0.0;
  289. } else {
  290. log2 = Math.log(prob)/Math.log(2);
  291. }
  292. impurity += (prob * log2);
  293. }
  294. impurity *= -1;
  295. return impurity;
  296. }
  297. static double getGini(DTNode node) {
  298. double impurity = 0;
  299. int cnt[] = new int [SIZE_LABEL];
  300. int totalcnt = node.data.size();
  301. for (int a=0; a<totalcnt; a++) {
  302. cnt[node.data.get(a).label]++;
  303. }
  304. for (int a=0; a<SIZE_LABEL; a++) {
  305. double prob = (double)cnt[a]/(double)totalcnt;
  306. impurity += (prob * prob);
  307. }
  308. impurity = 1 - impurity;
  309. return impurity;
  310. }

}
```