Tags: ,

K-nearest neighbors algorithm in Python

Dr. Huidae Cho
import random
import numpy as np
import matplotlib.pyplot as plt

# define a distance function
def calc_dist(loc1, loc2):
    """Return the straight-line (Euclidean) distance between two 2-D points.

    Args:
        loc1: A pair ``[x, y]`` for the first point.
        loc2: A pair ``[x, y]`` for the second point.

    Returns:
        The distance between the two points as a float.
    """
    ax, ay = loc1
    bx, by = loc2

    # per-axis offsets between the two points
    dx = ax - bx
    dy = ay - by

    return (dx**2 + dy**2)**0.5

def gen_animals(num_dogs, num_cats):
    """Generate a random sample dataset of dog and cat locations for training.

    Every coordinate is drawn uniformly from [-100, 100] on both axes.

    Args:
        num_dogs: Number of dog locations to generate.
        num_cats: Number of cat locations to generate.

    Returns:
        A tuple ``(loc_dogs, loc_cats)``; each element is a list of
        ``[x, y]`` pairs.
    """
    upper = 100
    lower = -upper

    def scatter(count):
        # For each point x is drawn first, then y.
        return [
            [random.uniform(lower, upper), random.uniform(lower, upper)]
            for _ in range(count)
        ]

    # Dogs are generated before cats.
    dogs = scatter(num_dogs)
    cats = scatter(num_cats)
    return dogs, cats

def gen_unknown(x_unk, y_unk):
    """Package the unknown animal's coordinates as an ``[x, y]`` list.

    Args:
        x_unk: x coordinate of the unknown animal.
        y_unk: y coordinate of the unknown animal.

    Returns:
        A new two-element list ``[x_unk, y_unk]``.
    """
    return [x_unk, y_unk]

def plot_animals(loc_dogs, loc_cats, loc_unk):
    """Scatter-plot the dogs, the cats, and the unknown animal, then show it.

    Dogs are drawn as blue circles, cats as red triangles, and the unknown
    animal as a green "X".

    Args:
        loc_dogs: List of ``[x, y]`` dog locations.
        loc_cats: List of ``[x, y]`` cat locations.
        loc_unk: ``[x, y]`` location of the unknown animal.
    """
    x_dogs = [xy[0] for xy in loc_dogs]
    y_dogs = [xy[1] for xy in loc_dogs]

    x_cats = [xy[0] for xy in loc_cats]
    y_cats = [xy[1] for xy in loc_cats]

    # Take the unknown's coordinates from the argument instead of relying on
    # module-level variables, so the function plots the point it was given.
    x_unk, y_unk = loc_unk

    plt.scatter(x_dogs, y_dogs, marker="o", color="blue")
    plt.scatter(x_cats, y_cats, marker="^", color="red")
    plt.scatter(x_unk, y_unk, marker="X", color="green")
    plt.show()

# implement KNN
def knn(loc_dogs, loc_cats, loc_unk, k):
    """Classify the unknown animal by a vote among its k nearest neighbors.

    Args:
        loc_dogs: List of ``[x, y]`` dog locations.
        loc_cats: List of ``[x, y]`` cat locations.
        loc_unk: ``[x, y]`` location of the animal to classify.
        k: Number of nearest neighbors that get a vote.

    Returns:
        ``"dog"`` or ``"cat"`` when one species wins the vote; on a tie,
        ``"dog (random)"`` or ``"cat (random)"`` chosen by a coin flip.

    Raises:
        IndexError: If ``k`` is larger than the total number of animals.
    """
    # Take the dataset size from the arguments themselves rather than from
    # module-level variables, so any pair of lists can be classified against.
    num_dogs = len(loc_dogs)

    # Merge the distances into a single list: dogs occupy indices
    # 0 .. num_dogs - 1 and cats follow, so an index below num_dogs
    # identifies a dog.
    dists = [calc_dist(loc_unk, loc) for loc in loc_dogs]
    dists.extend(calc_dist(loc_unk, loc) for loc in loc_cats)

    # find the sorted indices of dists (nearest first)
    sorted_idx = np.argsort(dists)

    vote_dogs = 0
    vote_cats = 0

    for i in range(k):
        if sorted_idx[i] < num_dogs:
            vote_dogs += 1
        else:
            vote_cats += 1

    if vote_dogs > vote_cats:
        return "dog"
    elif vote_dogs < vote_cats:
        return "cat"
    # tie: break it with a fair coin flip
    elif random.random() < 0.5:
        return "dog (random)"
    else:
        return "cat (random)"

# sizes of the training set
num_dogs, num_cats = 20, 25

# coordinates of the animal to classify
x_unk, y_unk = 10, 0

# number of nearest neighbors that get a vote
k = 3

# build the dataset, show it, then classify the unknown animal
loc_dogs, loc_cats = gen_animals(num_dogs=num_dogs, num_cats=num_cats)
loc_unk = gen_unknown(x_unk=x_unk, y_unk=y_unk)
plot_animals(loc_dogs=loc_dogs, loc_cats=loc_cats, loc_unk=loc_unk)
print(knn(loc_dogs=loc_dogs, loc_cats=loc_cats, loc_unk=loc_unk, k=k))