from math import log
import numpy as np
import pandas as pd
# Row-count threshold for splitting: createID3Tree turns any subset with this
# many rows or fewer into a leaf instead of splitting it further.
gMinSubsetSize = 1
class ID3Node:
    """A single node of an ID3 decision tree.

    Attributes (all taken verbatim from the constructor arguments):
        type: node kind; createID3Tree uses "leaf" or "path".
        data: payload given as pData (createID3Tree stores the training
            subset here for leaves and None for internal nodes).
        splitFeature: feature name given as pFeature.
        gain: value given as pGain.
        parent: parent node given as pParent, or None.
        childs: fresh, empty list of child nodes for this instance.
    """

    def __init__(self, pType=None, pFeature=None, pData=None, pGain=None, pParent=None):
        self.type, self.data = pType, pData
        self.splitFeature, self.gain = pFeature, pGain
        self.parent = pParent
        # A new list per instance, so siblings never share children.
        self.childs = list()
def calcEntropy(pDataset, pFeature):
    """Return the Shannon entropy, in bits, of the values in column ``pFeature``.

    Args:
        pDataset: anything accepted by the ``pd.DataFrame`` constructor.
        pFeature: name of the column whose value distribution is measured.

    Returns:
        ``-sum(p * log2(p))`` over the relative frequencies ``p`` of the
        distinct non-missing values of the column. A single-valued column
        yields 0 and an empty column yields 0.
    """
    dataset = pd.DataFrame(pDataset)
    # value_counts drops NaN by default; normalize=True divides by the number
    # of counted (non-missing) values, so the probabilities always sum to 1.
    # Dividing by the total row count would shrink every p whenever NaNs exist.
    probabilities = dataset[pFeature].value_counts(normalize=True)
    return -np.log2(probabilities).dot(probabilities)
def findBestSplitFeature(pDataset):
    """Choose the feature with the highest information gain with respect to "CLASS".

    Args:
        pDataset: anything accepted by the ``pd.DataFrame`` constructor. It must
            contain a "CLASS" column; every other column is a candidate feature.

    Returns:
        ``(feature, gain)``. If no candidate has a strictly positive gain
        (including when "CLASS" is the only column), ``feature`` is ``""`` and
        ``gain`` is ``0.0``. On ties, the first feature in column order wins.
    """
    dataset = pd.DataFrame(pDataset)
    totalRows = dataset.shape[0]
    entropyOfDataset = calcEntropy(dataset, "CLASS")
    maxGain = 0.0
    splitFeature = ""
    for feature in dataset.columns.drop("CLASS"):
        # Conditional entropy: CLASS entropy of each value's subset, weighted by
        # that subset's share of the rows.
        conditionalEntropy = 0.0
        # Series.items() replaces iteritems(), which pandas 2.0 removed.
        for value, count in dataset[feature].value_counts().items():
            subset = dataset.loc[dataset[feature] == value]
            conditionalEntropy += float(count) / totalRows * calcEntropy(subset, "CLASS")
        splitGain = entropyOfDataset - conditionalEntropy
        if splitGain > maxGain:
            maxGain = splitGain
            splitFeature = feature
    return splitFeature, maxGain
def createID3Tree(pDataSet, pParent):
    """Recursively build an ID3 decision tree and return its root node.

    Args:
        pDataSet: anything accepted by the ``pd.DataFrame`` constructor. It must
            contain a "CLASS" column; all other columns are treated as features.
        pParent: node recorded as the returned node's ``parent`` (None for the root).

    Returns:
        An ``ID3Node``.
        - Leaves have type "leaf" and keep their training subset in ``data``.
        - Internal nodes have type "path", the chosen feature in
          ``splitFeature`` and its information gain in ``gain``.
        - Internal nodes have one child per distinct non-missing value of that
          feature, ordered by descending frequency.

    NOTE(review): the feature value leading to each child is not stored on the
    child, so the resulting tree cannot yet classify unseen rows.
    """
    dataSet = pd.DataFrame(pDataSet)
    labels = dataSet["CLASS"].value_counts()
    # Stop on a pure subset, or on one too small to be worth splitting.
    if labels.shape[0] == 1 or dataSet.shape[0] <= gMinSubsetSize:
        return ID3Node("leaf", "leaf", dataSet, None, pParent)
    bestSplitFeature, gain = findBestSplitFeature(dataSet)
    # "" means no feature is left or none has positive gain. Indexing with ""
    # would raise KeyError, so this subset becomes a leaf instead.
    if bestSplitFeature == "":
        return ID3Node("leaf", "leaf", dataSet, None, pParent)
    node = ID3Node("path", bestSplitFeature, None, gain, pParent)
    for value in dataSet[bestSplitFeature].value_counts().index:
        subset = dataSet.loc[dataSet[bestSplitFeature] == value]
        # DataFrame.drop returns a new frame, so the result must be assigned.
        # Removing the used feature means each level consumes one feature,
        # which bounds recursion depth by the number of features.
        subset = subset.drop(bestSplitFeature, axis=1)
        node.childs.append(createID3Tree(subset, node))
    return node
def printID3Tree(pRoot):
    """Print the tree rooted at ``pRoot`` in depth-first pre-order.

    Each node produces one ``print`` call, with no indentation:
    - internal nodes (type other than "leaf") print their ``splitFeature``
      and then recurse into ``childs`` in order;
    - leaves print their ``data``.
    """
    if pRoot.type == "leaf":
        print(pRoot.data)
        return
    # Function-call print: the Python 2 print statement used here before is a
    # SyntaxError on Python 3 and did not match the leaf branch.
    print(pRoot.splitFeature)
    for child in pRoot.childs:
        printID3Tree(child)
if __name__ == "__main__":
    # Load the training table and clean it before building and printing the tree:
    # - drop the "Unnamed: 0" column (presumably an index written by an earlier
    #   to_csv);
    # - treat "?" as missing;
    # - remove every column that contains any missing value.
    frame = (
        pd.read_csv("data3.csv")
        .drop("Unnamed: 0", axis=1)
        .replace("?", np.nan)
        .dropna(how="any", axis=1)
    )
    printID3Tree(createID3Tree(frame, None))