
K-Nearest Neighbour Classifier




The Nearest Neighbour Classifier is one of the most straightforward classifiers in the arsenal of machine learning techniques. It classifies a query pattern by identifying its nearest neighbours and using them to determine the label of the query. The idea behind the algorithm is simple: assign the query pattern to the class that occurs most often among its k nearest neighbours. In this post we'll use the function knn_search(...), which we saw in the last post, to implement a K-Nearest Neighbour Classifier. The implementation of the classifier is as follows:
from numpy import random,argsort,argmax,bincount,int_,array,vstack,round
from pylab import scatter,plot,show

def knn_classifier(x, D, labels, K):
    """ Classify the vector x
        D - data matrix (each column is a pattern).
        labels - class of each pattern.
        K - number of neighbours to use.
        Returns the class label and the neighbours' indexes.
    """
    neig_idx = knn_search(x,D,K)
    counts = bincount(labels[neig_idx]) # voting
    return argmax(counts),neig_idx
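The function knn_search(...) comes from the previous post and is not reproduced here. As a rough stand-in, the sketch below assumes Euclidean distance, a query x of shape (d, 1), and a data matrix D of shape (d, n) with one pattern per column, which is the layout used by the test data in this post. It may differ from the original function in its details.

```python
import numpy as np

def knn_search(x, D, K):
    """Return the indexes of the K columns of D closest to x.

    Assumption: x has shape (d, 1) and D has shape (d, n),
    with one pattern per column.
    """
    n = D.shape[1]
    K = min(K, n)  # cannot return more neighbours than there are patterns
    # Euclidean distance from x to every column of D
    # (x is broadcast across the n columns)
    dist = np.sqrt(((D - x) ** 2).sum(axis=0))
    # indexes of the K smallest distances, nearest first
    return np.argsort(dist)[:K]

# tiny check: four 2-D patterns, query at the origin
D = np.array([[0.0, 1.0, 5.0, 0.2],
              [0.0, 1.0, 5.0, 0.1]])
x = np.array([[0.0], [0.0]])
nearest = knn_search(x, D, 2)
print(nearest)  # [0 3]
```

Sorting all n distances is the simplest approach and is fine for a few hundred points; for larger datasets a partial sort or a spatial index would be a more efficient choice.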
Let's test the classifier on some random data:
# generating a random dataset with random labels
data = random.rand(2,150) # random points
labels = int_(round(random.rand(150)*1)) # random labels 0 or 1
x = random.rand(2,1) # random test point

# label assignment using k=5
result,neig_idx = knn_classifier(x,data,labels,5)
print('Label assignment:', result)

# plotting the data and the input pattern
# class 1, red points, class 0 blue points
scatter(data[0,:],data[1,:], c=labels,cmap='bwr',alpha=0.8)
scatter(x[0],x[1],marker='o',c='g',s=40)
# highlighting the neighbours
plot(data[0,neig_idx],data[1,neig_idx],'o',
  markerfacecolor='None',markersize=15,markeredgewidth=1)
show()
The script will show the following graph:



The query vector is represented by a green point, and we can see that 3 of the 5 nearest neighbours are red points (label 1) while the remaining 2 are blue (label 0).
The result of the classification will be printed on the console:

Label assignment: 1

As expected, the green point has been assigned to the class with red markers.
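The voting step can also be traced in isolation. The snippet below is a small illustration with hand-written neighbour labels that mirror the 3-versus-2 split described above; the label values are made up for the example and are not taken from the random dataset.

```python
import numpy as np

# labels of the 5 nearest neighbours: three votes for class 1, two for class 0
neighbour_labels = np.array([1, 0, 1, 1, 0])

# counts[c] is the number of neighbours with label c
counts = np.bincount(neighbour_labels)

# the predicted class is the one with the most votes
winner = int(np.argmax(counts))

print(counts, winner)  # [2 3] 1
```

Note that argmax returns the first maximum, so a tie is resolved in favour of the lowest label. With two classes, choosing an odd K avoids ties altogether.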


