Repository where I mostly put random python scripts.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

342 lines
12 KiB

  1. """
  2. :Author: James Sherratt
  3. :Date: 21/10/2019
  4. :License: MIT
  5. :name: DecisionTree.py
  6. Basic implementation of a binary decision tree algorithm, with one
  7. discriminant per node.
  8. Useful links:
  9. https://scikit-learn.org/stable/modules/tree.html
  10. https://en.wikipedia.org/wiki/Decision_tree
  11. """
  12. import numpy as np
  13. from sklearn import datasets
  14. def proportion_k(ym):
  15. """
  16. Get the proportions of each class in the current set of values.
  17. :param ym: y values (class) of the data at a given node.
  18. :return: list containing the classes and the fraction of those classes present.
  19. """
  20. counts = list(np.unique(ym, return_counts=True))
  21. counts[1] = counts[1]/(ym.shape[0])
  22. return counts
  23. def gini(k_proportions):
  24. """
  25. Gini impurity function. This is used to determine the impurity of a given
  26. set of data, given the proportions of the classes in the dataset.
  27. This is equivalent to:
  28. H = pk(1-pk) for all k classes.
  29. k_proportions, in this case, is an array of pk's
  30. :param k_proportions: array containing proportions of different classes. Proportions sum to 1.
  31. :return: the impurity of the dataset.
  32. """
  33. return (k_proportions*(1-k_proportions)).sum()
  34. def node_impurity(ym):
  35. """
  36. Calculate the impurity of data on one side of node after split.
  37. :param ym: Actual y data for the selected dataset.
  38. :return: dict containing the impurity value of the side and the most common class on that side.
  39. """
  40. if ym.shape[0] == 0:
  41. return {"impurity": 0, "max_class": 0}
  42. k_prop = proportion_k(ym)
  43. return {"impurity": gini(k_prop[1]), "max_class": k_prop[0][np.argmax(k_prop[1])]}
  44. def disc_val_impurity(yleft, yright):
  45. """
  46. Calculate the level of impurity left in the given data after splitting. This returns
  47. a dict which contains:
  48. - The impurity of the data after being split.
  49. - The class of the largest proportion on the left and right side of the split.
  50. The aim is to find a split which minimises impurity.
  51. The impurity calculated is:
  52. G = (nleft/ntot)*Hleft + (nright/ntot)*Hright
  53. This gives the impurity of the split data.
  54. :param yleft: Real/ training y values for the data on the left.
  55. :param yright: Real/ training y values for the data on the right.
  56. :return: Dict containing the data impurity after split and the most common class on the left and right of the split.
  57. """
  58. nleft = yleft.shape[0]
  59. nright = yright.shape[0]
  60. ntot = nleft + nright
  61. left_imp = node_impurity(yleft)
  62. right_imp = node_impurity(yright)
  63. return {
  64. "impurity": ((nleft/ntot)*left_imp["impurity"])+((nright/ntot)*right_imp["impurity"]),
  65. "lmax_class": left_imp["max_class"],
  66. "rmax_class": right_imp["max_class"]
  67. }
  68. def niave_min_impurity(xm, ym):
  69. """
  70. Find a discriminator which minimises the impurity of the data. The discriminator
  71. is used to split data at a node.
  72. This works by:
  73. 1. Selecting a data column as a discriminator.
  74. 2. Splitting the possible values of the discriminator into 1000 even spaced values
  75. (between the minimum and maximum value in the dataset).
  76. 3. Selecting the discriminator column + value which minimises the impurity.
  77. :param xm: Data on the left.
  78. :param ym: Data on the right.
  79. :return: dict containing the current niave minimum impurity.
  80. """
  81. minxs = xm.min(axis=0)
  82. maxxs = xm.max(axis=0)
  83. # discriminator with the smallest impurity.
  84. cur_min_disc = None
  85. # Choose a column to discriminate by.
  86. for x_idx, (dmin, dmax) in enumerate(zip(minxs, maxxs)):
  87. # Create a set of possibly values to use as the discriminator for that column.
  88. disc_vals = np.linspace(dmin, dmax, 1000)
  89. for disc_val in disc_vals:
  90. selection = xm[:, x_idx] < disc_val
  91. yleft = ym[selection]
  92. yright = ym[selection==False]
  93. # Calculate impurity.
  94. imp = disc_val_impurity(yleft, yright)
  95. # Choose a column with the smallest impurity.
  96. try:
  97. if cur_min_disc["impurity"] > imp["impurity"]:
  98. imp["discriminator"] = x_idx
  99. imp["val"] = disc_val
  100. cur_min_disc = imp
  101. except TypeError:
  102. imp["discriminator"] = x_idx
  103. imp["val"] = disc_val
  104. cur_min_disc = imp
  105. return cur_min_disc
  106. class BinaryTreeClassifier:
  107. def __init__(self, max_depth=4, min_data=5):
  108. """
  109. Initialise the binary decision tree classifier. This classifier works by:
  110. - Splitting the data into 2 sets at every node.
  111. - These 2 sets are then split into 2 more sets at their nodes etc. until they reach a leaf.
  112. - At the leaves, the data is classified into whatever class was "most common" in that leaf during training.
  113. :param max_depth: The maximum depth the binary tree classifier goes to.
  114. :param min_data: The minimum sample size of the training data before the tree stops splitting.
  115. """
  116. tree = dict()
  117. self.depth = max_depth
  118. self.min_data = min_data
  119. def _node_mask(X, node):
  120. """
  121. Get the discriminator mask for the node. This splits the data into left and right components.
  122. :param X: dataset input data.
  123. :param node: the current node of the tree, with its discriminator value.
  124. :return: truth array, which splits data left and right.
  125. """
  126. return X[:, node["discriminator"]] < node["val"]
  127. def _apply_disc(X, y, node):
  128. """
  129. Apply the discriminator to the data at a given node.
  130. :param X: dataset input.
  131. :param y: dataset (observed) output.
  132. :param node: The node to split data by.
  133. :return: The x and y data, split left and right.
  134. """
  135. left_cond = BinaryTreeClassifier._node_mask(X, node)
  136. right_cond = left_cond == False
  137. left_X, left_y = X[left_cond], y[left_cond]
  138. right_X, right_y = X[right_cond], y[right_cond]
  139. return left_X, left_y, right_X, right_y
  140. def _tree_node(X, y, max_depth, min_data):
  141. """
  142. Create a tree node. This also creates child nodes of this node recursively.
  143. :param X: input data for the dataset at a node.
  144. :param y: output (observed) data for the dataset at a node.
  145. :param max_depth: The maximum depth of the tree from this node.
  146. :param min_data: The minimum amount of data which can be discriminated.
  147. :return: The node + its children, as a dict.
  148. """
  149. # Get the new node, as a dict.
  150. node = niave_min_impurity(X, y)
  151. # Split the data using the discriminator.
  152. left_X, left_y, right_X, right_y = BinaryTreeClassifier._apply_disc(X, y, node)
  153. if max_depth > 1:
  154. if left_X.shape[0] >= min_data:
  155. # Create a new node on the left (recursively) if max depth
  156. # and min data have not been reached.
  157. node["left"] = BinaryTreeClassifier._tree_node(left_X, left_y, max_depth-1, min_data)
  158. if right_X.shape[0] >= min_data:
  159. # Create a new node on the right (recursively) if max depth
  160. # and min data have not been reached.
  161. node["right"] = BinaryTreeClassifier._tree_node(right_X, right_y, max_depth-1, min_data)
  162. return node
  163. def _run_tree(X, node):
  164. """
  165. Run a node of the classifier, recurisively.
  166. :param node: The node to run on the data.
  167. :return: The classified y (expected) data.
  168. """
  169. # Setup y array.
  170. y = np.zeros(X.shape[0])
  171. # Get the discriminator left conditional.
  172. left_cond = BinaryTreeClassifier._node_mask(X, node)
  173. # Right conditional
  174. right_cond = left_cond == False
  175. try:
  176. # Try to split the data further on the left side.
  177. y[left_cond] = BinaryTreeClassifier._run_tree(X[left_cond], node["left"])
  178. except KeyError:
  179. # If we cannot split any further, get the class of the data on the left (as this is a leaf).
  180. y[left_cond] = node["lmax_class"]
  181. try:
  182. # Try to split the data further on the right side.
  183. y[right_cond] = BinaryTreeClassifier._run_tree(X[right_cond], node["right"])
  184. except KeyError:
  185. # If we cannot split any further, get the class of the data on the right (as this is a leaf).
  186. y[right_cond] = node["rmax_class"]
  187. return y
  188. def _node_dict(node, idx=0):
  189. """
  190. Get a dict of all the nodes, recursively. The keys are the index of an array,
  191. as if the array is a heap.
  192. :param node: The current node to add to the dict and to get children of recursively.
  193. :param idx: current index (key) of the node.
  194. :return: dict containing all the nodes retrieved.
  195. """
  196. # Current nodes.
  197. nodes = {}
  198. node_data = {"lmax_class": node["lmax_class"],
  199. "rmax_class": node["rmax_class"],
  200. "discriminator": node["discriminator"],
  201. "val": node["val"]}
  202. nodes[idx] = node_data
  203. # Try to get the left nodes.
  204. try:
  205. left_idx = 2 * idx + 1
  206. nodes.update(BinaryTreeClassifier._node_dict(node["left"], left_idx))
  207. except KeyError:
  208. pass
  209. # Try to get the right nodes.
  210. try:
  211. right_idx = 2 * idx + 2
  212. nodes.update(BinaryTreeClassifier._node_dict(node["right"], right_idx))
  213. except KeyError:
  214. pass
  215. # return the dict of nodes retrieved.
  216. return nodes
  217. def build_tree(self, X, y):
  218. """
  219. Build (train) the decision tree classifier.
  220. :param X: input training data.
  221. :param y: output training (observed) data.
  222. :return: None
  223. """
  224. self.tree = BinaryTreeClassifier._tree_node(X, y, self.depth, self.min_data)
  225. def classify(self, X):
  226. """
  227. Classify some data using the tree.
  228. :param X: Input data.
  229. :return: output (expected) classes of the data, or y values, for the given input.
  230. """
  231. return BinaryTreeClassifier._run_tree(X, self.tree)
  232. def tree_to_heap_array(self):
  233. """
  234. Convert the tree to a binary heap, stored in an array with standard indexing.
  235. i.e. a node at index i has children at 2i*1 and 2i+2 and a parent at (i-1)//2.
  236. :return: list containing the tree nodes.
  237. """
  238. tree_dict = BinaryTreeClassifier._node_dict(self.tree)
  239. return [tree_dict[key] for key in sorted(tree_dict.keys())]
  240. def shuffle_split(x, y, frac=0.6):
  241. """
  242. Shuffle and split X and y data. "frac" is the ratio of the split.
  243. e.g. 0.6 means 60% of the data goes into the left fraction, 40% into the right.
  244. Note X and y are shuffled the same, so row i in X data is still matched with row i in y after shuffle.
  245. :param x: X values of the data (predictor).
  246. :param y: y values of the data (observation).
  247. :param frac: fraction to split data by.
  248. :return: x1, y1, x2, y2 data where x1, y1 is the left fraction and x2, y2 is the right.
  249. """
  250. data_idx = np.arange(x.shape[0])
  251. sample1 = data_idx < (data_idx.max()*frac)
  252. np.random.shuffle(data_idx)
  253. np.random.shuffle(sample1)
  254. sample2 = sample1 == False
  255. x1, y1 = x[data_idx[sample1]], y[data_idx[sample1]]
  256. x2, y2 = x[data_idx[sample2]], y[data_idx[sample2]]
  257. return x1, y1, x2, y2
  258. if __name__ == "__main__":
  259. # Set the seed for expected test results.
  260. np.random.seed(10)
  261. # Test decision tree with iris data.
  262. iris_data = datasets.load_iris()
  263. X = iris_data["data"]
  264. y = iris_data["target"]
  265. # Split iris data into test and train.
  266. X_train, y_train, X_test, y_test = shuffle_split(X, y)
  267. # create the decision tree classifier.
  268. classifier = BinaryTreeClassifier()
  269. # Train the classifier.
  270. classifier.build_tree(X_train, y_train)
  271. # Get the result when the classifier is applied to to the test data.
  272. result = classifier.classify(X_test)
  273. # Get the accuracy of the classifier.
  274. # accuracy = (number of correct results)/(total number of results)
  275. print("accuracy:", (result == y_test).sum()/(result.shape[0]))
  276. # convert the tree into a heap array.
  277. tree_arr = classifier.tree_to_heap_array()
  278. print("heap:")
  279. for i, node in enumerate(tree_arr):
  280. print(i, node)