Browse Source

Python simple decision tree classifier

Code for a simple decision tree classifier has been added. This
is demonstrated to work with the classic iris dataset.
pull/32/head
James Sherratt 5 years ago
parent
commit
66bb46261e
1 changed files with 206 additions and 0 deletions
  1. +206
    -0
      ML Cookbook/DecisionTree.py

+ 206
- 0
ML Cookbook/DecisionTree.py View File

@ -0,0 +1,206 @@
"""
:Author: james
:Date: 21/10/2019
:License: MIT
:name: DecisionTree.py
Basic implementation of a binary decision tree algorithm, with one
discriminant per node.
"""
import numpy as np
from sklearn import datasets
def proportion_k(ym):
"""
Get the proportions of each class in the current set of values.
:param ym: y values (class) of the data at a given node.
:return:
"""
counts = list(np.unique(ym, return_counts=True))
counts[1] = counts[1]/(ym.shape[0])
return counts
def gini(k_proportions):
"""
Gini impurity function.
:param k_proportions:
:return:
"""
return (k_proportions*(1-k_proportions)).sum()
def node_impurity(ym):
"""
Calculate the impurity of data at a given node of the tree.
:param ym:
:return:
"""
if ym.shape[0] == 0:
return {"impurity": 0, "max_group": 0}
k_prop = proportion_k(ym)
return {"impurity": gini(k_prop[1]), "max_group": k_prop[0][np.argmax(k_prop[1])]}
def disc_val_impurity(yleft, yright):
"""
Calculate the level of impurity left in the given data split.
:param yleft:
:param yright:
:return:
"""
nleft = yleft.shape[0]
nright = yright.shape[0]
ntot = nleft + nright
left_imp = node_impurity(yleft)
right_imp = node_impurity(yright)
return {
"impurity": ((nleft/ntot)*left_imp["impurity"])+((nright/ntot)*right_imp["impurity"]),
"lmax_group": left_imp["max_group"],
"rmax_group": right_imp["max_group"]
}
def niave_min_impurity(xm, ym):
minxs = xm.min(axis=0)
maxxs = xm.max(axis=0)
# discriminator with the smallest impurity.
cur_min_disc = None
for x_idx, (dmin, dmax) in enumerate(zip(minxs, maxxs)):
disc_vals = np.linspace(dmin, dmax, 10)
for disc_val in disc_vals:
selection = xm[:, x_idx] < disc_val
yleft = ym[selection]
yright = ym[selection==False]
imp = disc_val_impurity(yleft, yright)
try:
if cur_min_disc["impurity"] > imp["impurity"]:
imp["discriminator"] = x_idx
imp["val"] = disc_val
cur_min_disc = imp
except TypeError:
imp["discriminator"] = x_idx
imp["val"] = disc_val
cur_min_disc = imp
return cur_min_disc
class BinaryTreeClassifier:
def __init__(self, max_depth=4, min_data=5):
tree = dict()
self.depth = max_depth
self.min_data = min_data
def _node_mask(X, node):
return X[:, node["discriminator"]] < node["val"]
def _apply_disc(X, y, node):
left_cond = BinaryTreeClassifier._node_mask(X, node)
right_cond = left_cond == False
left_X, left_y = X[left_cond], y[left_cond]
right_X, right_y = X[right_cond], y[right_cond]
return left_X, left_y, right_X, right_y
def _tree_node(X, y, max_depth, min_data):
node = niave_min_impurity(X, y)
left_X, left_y, right_X, right_y = BinaryTreeClassifier._apply_disc(X, y, node)
if max_depth > 0:
if left_X.shape[0] >= min_data:
node["left"] = BinaryTreeClassifier._tree_node(left_X, left_y, max_depth-1, min_data)
if right_X.shape[0] >= min_data:
node["right"] = BinaryTreeClassifier._tree_node(right_X, right_y, max_depth-1, min_data)
return node
def _run_tree(X, node):
y = np.zeros(X.shape[0])
left_cond = BinaryTreeClassifier._node_mask(X, node)
right_cond = left_cond == False
try:
y[left_cond] = BinaryTreeClassifier._run_tree(X[left_cond], node["left"])
except KeyError:
y[left_cond] = node["lmax_group"]
try:
y[right_cond] = BinaryTreeClassifier._run_tree(X[right_cond], node["right"])
except KeyError:
y[right_cond] = node["rmax_group"]
return y
def _node_dict(node, idx=0):
nodes = {}
node_data = {"lmax_group": node["lmax_group"],
"rmax_group": node["rmax_group"],
"discriminator": node["discriminator"],
"val": node["val"]}
nodes[idx] = node_data
try:
left_idx = 2 * idx + 1
nodes.update(BinaryTreeClassifier._node_dict(node["left"], left_idx))
except KeyError:
pass
try:
right_idx = 2 * idx + 2
nodes.update(BinaryTreeClassifier._node_dict(node["right"], right_idx))
except KeyError:
pass
return nodes
def build_tree(self, X, y):
self.tree = BinaryTreeClassifier._tree_node(X, y, self.depth, self.min_data)
def classify(self, X):
return BinaryTreeClassifier._run_tree(X, self.tree)
def tree_to_heap_array(self):
tree_dict = BinaryTreeClassifier._node_dict(self.tree)
return [tree_dict[key] for key in sorted(tree_dict.keys())]
def shuffle_split(x, y, frac=0.6):
"""
Shuffle and split X and y data.
:param x:
:param y:
:param frac:
:return:
"""
data_idx = np.arange(x.shape[0])
sample1 = data_idx < (data_idx.max()*frac)
np.random.shuffle(data_idx)
np.random.shuffle(sample1)
sample2 = sample1 == False
x1, y1 = x[data_idx[sample1]], y[data_idx[sample1]]
x2, y2 = x[data_idx[sample2]], y[data_idx[sample2]]
return x1, y1, x2, y2
if __name__ == "__main__":
np.random.seed(10)
iris_data = datasets.load_iris()
X = iris_data["data"]
y = iris_data["target"]
X_train, y_train, X_test, y_test = shuffle_split(X, y)
classifier = BinaryTreeClassifier()
classifier.build_tree(X_train, y_train)
result = classifier.classify(X_test)
print("accuracy:", (result == y_test).sum()/(result.shape[0]))
tree_arr = classifier.tree_to_heap_array()
pass

Loading…
Cancel
Save