ML decision tree classifier keeps splitting on the same attribute / asking about the same attribute


I am currently building a decision tree classifier using Gini impurity and information gain, splitting the tree on the attribute with the highest gain each time. However, it sticks with the same attribute every time and only adjusts the value in its question. This results in a very low accuracy, usually around 30%, because only the very first attribute is ever taken into account.
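The question calls gini_index and gain without showing them. The sketch below is one plausible version, assuming that each row is a list of strings, that the class label (beer style) sits in column 3 (as the `if c == 3: continue` check suggests), and that gain means the drop in Gini impurity from the parent to the size-weighted children. LABEL_COL is a hypothetical name introduced here; this is not the poster's actual code.

```python
from collections import Counter

# Assumption: the class label lives in column 3, matching the
# "if c == 3: continue" check in split(). LABEL_COL is hypothetical.
LABEL_COL = 3


def gini_index(rows):
    """Gini impurity of a list of rows: 1 - sum over classes of p_k ** 2."""
    if not rows:
        return 0.0
    counts = Counter(row[LABEL_COL] for row in rows)
    total = len(rows)
    return 1.0 - sum((n / total) ** 2 for n in counts.values())


def gain(true, false, curr_gini):
    """Impurity decrease of a split: parent Gini minus size-weighted child Gini."""
    total = len(true) + len(false)
    if total == 0:
        return 0.0
    p = len(true) / total
    return curr_gini - p * gini_index(true) - (1 - p) * gini_index(false)
```

With two rows of each of two styles, the parent Gini is 0.5, and a split that separates the styles perfectly has a gain of 0.5. A split that leaves every row on one side has a gain of 0, which is what makes rec_tree stop and return a Leaf.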

Finding the best split

# Used to find the best split for data among all attributes

def split(r):
    max_ig = 0
    max_att = 0
    max_att_val = 0
    i = 0

    curr_gini = gini_index(r)
    n_att = len(att)

    for c in range(n_att):
        if c == 3:
            continue

        c_vals = get_column(r, c)

        while i < len(c_vals):
            # Value of the current attribute that is being tested
            curr_att_val = r[i][c]
            true, false = fork(r, c, curr_att_val)
            ig = gain(true, false, curr_gini)

            if ig > max_ig:
                max_ig = ig
                max_att = c
                max_att_val = r[i][c]
            i += 1

    return max_ig, max_att, max_att_val

Comparing rows and forking the data into true and false branches

# Used to compare and test if the current row is greater than or equal to the test value
# in order to split up the data

def compare(r, test_c, test_val):
    if r[test_c].isdigit():
        return r[test_c] == test_val

    elif float(r[test_c]) >= float(test_val):
        return True

    else:
        return False


# Splits the data into two lists for the true/false results of the compare test

def fork(r, c, test_val):
    true = []
    false = []

    for row in r:

        if compare(row, c, test_val):
            true.append(row)
        else:
            false.append(row)

    return true, false
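One thing worth double-checking in compare: "4.5".isdigit() and "-3".isdigit() are both False, so only non-negative integer strings take the equality branch, and a non-numeric category such as "IPA" reaches float() and raises ValueError. The sketch below assumes the intended rule is "numeric values use >=, anything else uses equality". Note that this changes behaviour for integer strings (they would use >= instead of ==), so it is an assumption about intent, not a drop-in fix; is_number is a hypothetical helper.

```python
def is_number(value):
    """True if the string parses as a float (covers '5', '4.5', '-3')."""
    try:
        float(value)
        return True
    except ValueError:
        return False


def compare(r, test_c, test_val):
    """Numeric columns: r[test_c] >= test_val. Other columns: equality."""
    if is_number(r[test_c]) and is_number(test_val):
        return float(r[test_c]) >= float(test_val)
    return r[test_c] == test_val


def fork(r, c, test_val):
    """Split rows into (true, false) lists according to compare()."""
    true, false = [], []
    for row in r:
        if compare(row, c, test_val):
            true.append(row)
        else:
            false.append(row)
    return true, false
```

For example, forking [["4.5", "IPA"], ["5", "Stout"], ["3", "IPA"]] on column 0 with "4.5" puts the first two rows in true and the last in false, while forking on column 1 with "IPA" groups the two IPA rows without raising.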

Building the tree recursively

def rec_tree(r):
    ig, att, curr_att_val = split(r)

    if ig == 0:
        return Leaf(r)

    true_rows, false_rows = fork(r, att, curr_att_val)

    true_branch = rec_tree(true_rows)
    false_branch = rec_tree(false_rows)

    return Node(att, curr_att_val, true_branch, false_branch)
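rec_tree returns Leaf and Node objects that the question does not define. To measure accuracy (the roughly 30% mentioned above), the tree has to be walked for each row. The sketch below assumes hypothetical Leaf and Node classes shaped to fit rec_tree's calls, a label in column 3, and the same "numeric >=, otherwise equality" comparison assumed earlier; classify and accuracy are illustrative helpers, not part of the original code.

```python
from collections import Counter

LABEL_COL = 3  # assumption: label column, as skipped in split()


class Leaf:
    """Terminal node: counts how often each label occurs among its rows."""
    def __init__(self, rows):
        self.predictions = Counter(row[LABEL_COL] for row in rows)


class Node:
    """Decision node: asks whether row[att] matches val, holds both subtrees."""
    def __init__(self, att, val, true_branch, false_branch):
        self.att = att
        self.val = val
        self.true_branch = true_branch
        self.false_branch = false_branch

    def matches(self, row):
        # Assumed rule: numeric values use >=, anything else uses equality.
        try:
            return float(row[self.att]) >= float(self.val)
        except ValueError:
            return row[self.att] == self.val


def classify(row, node):
    """Follow the questions down to a leaf and return its most common label."""
    while isinstance(node, Node):
        node = node.true_branch if node.matches(row) else node.false_branch
    return node.predictions.most_common(1)[0][0]


def accuracy(rows, tree):
    """Fraction of rows whose predicted label equals the true label."""
    hits = sum(classify(row, tree) == row[LABEL_COL] for row in rows)
    return hits / len(rows)
```

Printing the att of each Node while walking a trained tree is a quick way to confirm the symptom described in the question: with the bug present, every decision node reports column 0.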
Accepted answer

The working solution I have was to change the split function as follows. To be completely honest, I'm not able to see what's wrong, but it might be obvious. The working function is:

def split(r):
    max_ig = 0
    max_att = 0
    max_att_val = 0

    # calculates gini for the rows provided
    curr_gini = gini_index(r)
    no_att = len(r[0])

    # Goes through the different attributes

    for c in range(no_att):

        # Skip the label column (beer style)

        if c == 3:
            continue
        column_vals = get_column(r, c)

        i = 0
        while i < len(column_vals):
            # value we want to check
            att_val = r[i][c]

            # Use the attribute value to fork the data to true and false streams
            true, false = fork(r, c, att_val)

            # Calculate the information gain
            ig = gain(true, false, curr_gini)

            # If this gain is the highest found then mark this as the best choice
            if ig > max_ig:
                max_ig = ig
                max_att = c
                max_att_val = r[i][c]
            i += 1

    return max_ig, max_att, max_att_val
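The decisive difference is where `i = 0` sits. In the question's split, the counter is initialised once, before the `for c` loop. After the while loop has scanned every value of the first non-label column, i equals len(c_vals), so for every later column the condition `i < len(c_vals)` is already false and no candidate split is evaluated. Only column 0 is ever tested, which is why each node asks about the same attribute with a different threshold. The working version resets `i = 0` inside the `for c` loop. (It also takes the attribute count from len(r[0]) rather than len(att); att is not shown, so whether that mattered here is unclear.) The hypothetical counter below reproduces both loop shapes without computing any gains.

```python
def candidates_per_column(rows, label_col, reset_each_column):
    """Count how many (column, value) splits the scan in split() would test.

    reset_each_column=False mimics the question (i = 0 set once, before
    the for loop); True mimics the answer (i = 0 reset for each column).
    """
    tested = {}
    i = 0
    for c in range(len(rows[0])):
        if c == label_col:
            continue
        if reset_each_column:
            i = 0
        tested[c] = 0
        while i < len(rows):
            tested[c] += 1  # split() would call fork() and gain() here
            i += 1
    return tested


rows = [["1", "a", "x", "IPA"], ["2", "a", "y", "IPA"],
        ["3", "b", "x", "Stout"], ["4", "b", "y", "Stout"]]
print(candidates_per_column(rows, 3, reset_each_column=False))
print(candidates_per_column(rows, 3, reset_each_column=True))
```

The first call reports 4 candidates for column 0 and 0 for columns 1 and 2; the second reports 4 for every column. Writing the inner loop as `for row in r:` (or iterating over set(column_vals), since duplicate values produce identical splits) leaves no counter to carry over, which rules out this kind of bug.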