How does this code for the BernoulliNB classifier work?

Can someone please explain what this code does? It's from the book "Introduction to Machine Learning with Python", in the section on the Bernoulli Naive Bayes classifier:

counts = {}
for label in np.unique(y):
    # iterate over each class
    # count (sum) entries of 1 per feature
    counts[label] = X[y == label].sum(axis=0)
print("Feature counts:\n", counts)

I don't understand what happens on line 5.

chrslg On BEST ANSWER

Let's use an example.

import numpy as np
X = np.array([[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60]])
y = np.array([11, 22, 33, 11, 22, 11])

np.unique(y) is [11,22,33]

So label will successively be those.

When label is 11
y==label is [11,22,33,11,22,11]==11 which is [True,False,False,True,False,True]
so X[y==label] is X[[True,False,False,True,False,True]], which selects rows 0, 3 and 5 of X: [[1,10],[4,40],[6,60]]
sum(axis=0) sums those rows along axis 0 (column by column), so X[y==label].sum(axis=0) is [1+4+6, 10+40+60] = [11,110]
so counts[11]=[11,110]
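
The steps for label 11 can be checked one at a time. This is only a minimal sketch that reuses the example X and y; the names mask, selected and column_sums are illustrative and do not appear in the book's code.

```python
import numpy as np

X = np.array([[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60]])
y = np.array([11, 22, 33, 11, 22, 11])

label = 11
mask = y == label                    # boolean array, one entry per row of X
selected = X[mask]                   # keeps only the rows where mask is True
column_sums = selected.sum(axis=0)   # adds the selected rows, column by column

print(mask)
print(selected)
print(column_sums)
```

Splitting the one-liner into three named steps makes it easier to print each intermediate value and compare it with the walkthrough.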

Likewise, when label is 22, y==label is [False,True,False,False,True,False], so X[y==label] is [[2,20],[5,50]] and X[y==label].sum(axis=0) is [7,70], which is assigned to counts[22].

And when label is 33, y==label is just [False,False,True,False,False,False], so X[y==label] is [[3,30]] and X[y==label].sum(axis=0) is [3,30], which is assigned to counts[33].

So in the end, if X holds k rows of values and y holds the k matching class labels, chosen among n possible classes, then counts contains, for each of the n classes, the column-wise sums of the rows of X that belong to that class (one sum per column of X).
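
Putting it together, the loop from the question can be run on the same example data. This is a sketch for checking the values worked out by hand, not a reproduction of the book's dataset.

```python
import numpy as np

X = np.array([[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60]])
y = np.array([11, 22, 33, 11, 22, 11])

counts = {}
for label in np.unique(y):
    # select the rows of X whose label matches, then sum them column by column
    counts[label] = X[y == label].sum(axis=0)

# one entry per distinct label, each holding one sum per column of X
for label, sums in counts.items():
    print(label, sums)
```

The dictionary ends up with three keys (11, 22 and 33), and each value is an array with as many entries as X has columns.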

UdonN00dle On

Based on the short code snippet, it seems like the variable X is a data frame (or an array), and X[y == label] filters it, keeping only the rows for which the corresponding entry of y matches the label.

The .sum(axis=0) call then sums the rows that passed the filter. Because the filtered data is two-dimensional, the sum is taken along the first axis, i.e. vertically, which gives one total per column.

Finally, on the left-hand side of the equals sign, this sum is stored in the counts dictionary under the key label.
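
To see why the comment in the question talks about counting entries of 1, it helps to try this on binary data and compare the axis choices. The small X and y below are made up for illustration (assumed to be a NumPy array of 0s and 1s, not the book's data), and the variable names are hypothetical.

```python
import numpy as np

# Made-up binary data: 4 samples (rows), 3 features (columns)
X = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 1],
              [1, 0, 1]])
y = np.array([0, 1, 0, 1])

rows = X[y == 1]                 # the samples labelled 1 (rows 1 and 3)
per_feature = rows.sum(axis=0)   # one total per column: how many 1s per feature
per_sample = rows.sum(axis=1)    # one total per row: how many 1s per sample
total = rows.sum()               # no axis given: a single overall total

print(per_feature, per_sample, total)
```

With 0/1 features, summing along axis 0 is the same as counting, for each feature, how many samples of that class have the feature set.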