Repository where I mostly put random Python scripts.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

206 lines
5.8 KiB

  1. """
  2. :Author: james
  3. :Date: 21/10/2019
  4. :License: MIT
  5. :name: DecisionTree.py
  6. Basic implementation of a binary decision tree algorithm, with one
  7. discriminant per node.
  8. """
  9. import numpy as np
  10. from sklearn import datasets
  11. def proportion_k(ym):
  12. """
  13. Get the proportions of each class in the current set of values.
  14. :param ym: y values (class) of the data at a given node.
  15. :return:
  16. """
  17. counts = list(np.unique(ym, return_counts=True))
  18. counts[1] = counts[1]/(ym.shape[0])
  19. return counts
  20. def gini(k_proportions):
  21. """
  22. Gini impurity function.
  23. :param k_proportions:
  24. :return:
  25. """
  26. return (k_proportions*(1-k_proportions)).sum()
  27. def node_impurity(ym):
  28. """
  29. Calculate the impurity of data at a given node of the tree.
  30. :param ym:
  31. :return:
  32. """
  33. if ym.shape[0] == 0:
  34. return {"impurity": 0, "max_group": 0}
  35. k_prop = proportion_k(ym)
  36. return {"impurity": gini(k_prop[1]), "max_group": k_prop[0][np.argmax(k_prop[1])]}
  37. def disc_val_impurity(yleft, yright):
  38. """
  39. Calculate the level of impurity left in the given data split.
  40. :param yleft:
  41. :param yright:
  42. :return:
  43. """
  44. nleft = yleft.shape[0]
  45. nright = yright.shape[0]
  46. ntot = nleft + nright
  47. left_imp = node_impurity(yleft)
  48. right_imp = node_impurity(yright)
  49. return {
  50. "impurity": ((nleft/ntot)*left_imp["impurity"])+((nright/ntot)*right_imp["impurity"]),
  51. "lmax_group": left_imp["max_group"],
  52. "rmax_group": right_imp["max_group"]
  53. }
  54. def niave_min_impurity(xm, ym):
  55. minxs = xm.min(axis=0)
  56. maxxs = xm.max(axis=0)
  57. # discriminator with the smallest impurity.
  58. cur_min_disc = None
  59. for x_idx, (dmin, dmax) in enumerate(zip(minxs, maxxs)):
  60. disc_vals = np.linspace(dmin, dmax, 10)
  61. for disc_val in disc_vals:
  62. selection = xm[:, x_idx] < disc_val
  63. yleft = ym[selection]
  64. yright = ym[selection==False]
  65. imp = disc_val_impurity(yleft, yright)
  66. try:
  67. if cur_min_disc["impurity"] > imp["impurity"]:
  68. imp["discriminator"] = x_idx
  69. imp["val"] = disc_val
  70. cur_min_disc = imp
  71. except TypeError:
  72. imp["discriminator"] = x_idx
  73. imp["val"] = disc_val
  74. cur_min_disc = imp
  75. return cur_min_disc
  76. class BinaryTreeClassifier:
  77. def __init__(self, max_depth=4, min_data=5):
  78. tree = dict()
  79. self.depth = max_depth
  80. self.min_data = min_data
  81. def _node_mask(X, node):
  82. return X[:, node["discriminator"]] < node["val"]
  83. def _apply_disc(X, y, node):
  84. left_cond = BinaryTreeClassifier._node_mask(X, node)
  85. right_cond = left_cond == False
  86. left_X, left_y = X[left_cond], y[left_cond]
  87. right_X, right_y = X[right_cond], y[right_cond]
  88. return left_X, left_y, right_X, right_y
  89. def _tree_node(X, y, max_depth, min_data):
  90. node = niave_min_impurity(X, y)
  91. left_X, left_y, right_X, right_y = BinaryTreeClassifier._apply_disc(X, y, node)
  92. if max_depth > 0:
  93. if left_X.shape[0] >= min_data:
  94. node["left"] = BinaryTreeClassifier._tree_node(left_X, left_y, max_depth-1, min_data)
  95. if right_X.shape[0] >= min_data:
  96. node["right"] = BinaryTreeClassifier._tree_node(right_X, right_y, max_depth-1, min_data)
  97. return node
  98. def _run_tree(X, node):
  99. y = np.zeros(X.shape[0])
  100. left_cond = BinaryTreeClassifier._node_mask(X, node)
  101. right_cond = left_cond == False
  102. try:
  103. y[left_cond] = BinaryTreeClassifier._run_tree(X[left_cond], node["left"])
  104. except KeyError:
  105. y[left_cond] = node["lmax_group"]
  106. try:
  107. y[right_cond] = BinaryTreeClassifier._run_tree(X[right_cond], node["right"])
  108. except KeyError:
  109. y[right_cond] = node["rmax_group"]
  110. return y
  111. def _node_dict(node, idx=0):
  112. nodes = {}
  113. node_data = {"lmax_group": node["lmax_group"],
  114. "rmax_group": node["rmax_group"],
  115. "discriminator": node["discriminator"],
  116. "val": node["val"]}
  117. nodes[idx] = node_data
  118. try:
  119. left_idx = 2 * idx + 1
  120. nodes.update(BinaryTreeClassifier._node_dict(node["left"], left_idx))
  121. except KeyError:
  122. pass
  123. try:
  124. right_idx = 2 * idx + 2
  125. nodes.update(BinaryTreeClassifier._node_dict(node["right"], right_idx))
  126. except KeyError:
  127. pass
  128. return nodes
  129. def build_tree(self, X, y):
  130. self.tree = BinaryTreeClassifier._tree_node(X, y, self.depth, self.min_data)
  131. def classify(self, X):
  132. return BinaryTreeClassifier._run_tree(X, self.tree)
  133. def tree_to_heap_array(self):
  134. tree_dict = BinaryTreeClassifier._node_dict(self.tree)
  135. return [tree_dict[key] for key in sorted(tree_dict.keys())]
  136. def shuffle_split(x, y, frac=0.6):
  137. """
  138. Shuffle and split X and y data.
  139. :param x:
  140. :param y:
  141. :param frac:
  142. :return:
  143. """
  144. data_idx = np.arange(x.shape[0])
  145. sample1 = data_idx < (data_idx.max()*frac)
  146. np.random.shuffle(data_idx)
  147. np.random.shuffle(sample1)
  148. sample2 = sample1 == False
  149. x1, y1 = x[data_idx[sample1]], y[data_idx[sample1]]
  150. x2, y2 = x[data_idx[sample2]], y[data_idx[sample2]]
  151. return x1, y1, x2, y2
  152. if __name__ == "__main__":
  153. np.random.seed(10)
  154. iris_data = datasets.load_iris()
  155. X = iris_data["data"]
  156. y = iris_data["target"]
  157. X_train, y_train, X_test, y_test = shuffle_split(X, y)
  158. classifier = BinaryTreeClassifier()
  159. classifier.build_tree(X_train, y_train)
  160. result = classifier.classify(X_test)
  161. print("accuracy:", (result == y_test).sum()/(result.shape[0]))
  162. tree_arr = classifier.tree_to_heap_array()
  163. pass