# K-nearest neighbors algorithm in Python
# Dr. Huidae Cho
import random
import numpy as np
import matplotlib.pyplot as plt
def calc_dist(loc1, loc2):
    """Return the Euclidean distance between two 2-D points.

    Each argument is a two-element sequence of the form [x, y].
    """
    (xa, ya), (xb, yb) = loc1, loc2
    dx = xa - xb
    dy = ya - yb
    return (dx ** 2 + dy ** 2) ** 0.5
def gen_animals(num_dogs, num_cats):
    """Generate a random training set of dog and cat locations.

    Every coordinate is drawn with random.uniform(-100, 100). All dog points
    are drawn before any cat point, and within each point x is drawn before y.

    Returns a tuple (loc_dogs, loc_cats), each a list of [x, y] lists.
    """
    half_width = 100
    lo, hi = -half_width, half_width

    def _random_points(count):
        # List displays evaluate left to right, so x is drawn before y.
        return [[random.uniform(lo, hi), random.uniform(lo, hi)]
                for _ in range(count)]

    # Tuple elements evaluate left to right, so dogs are drawn before cats.
    return _random_points(num_dogs), _random_points(num_cats)
def gen_unknown(x_unk, y_unk):
    """Package the unknown animal's coordinates as an [x, y] list."""
    return [x_unk, y_unk]
def plot_animals(loc_dogs, loc_cats, loc_unk):
    """Scatter-plot the training animals and the unknown location, then show it.

    Parameters
    ----------
    loc_dogs : list of [x, y]
        Dog locations, drawn as blue circles.
    loc_cats : list of [x, y]
        Cat locations, drawn as red triangles.
    loc_unk : [x, y]
        Location of the unknown animal, drawn as a green X.
    """
    x_dogs = [xy[0] for xy in loc_dogs]
    y_dogs = [xy[1] for xy in loc_dogs]
    x_cats = [xy[0] for xy in loc_cats]
    y_cats = [xy[1] for xy in loc_cats]
    plt.scatter(x_dogs, y_dogs, marker="o", color="blue")
    plt.scatter(x_cats, y_cats, marker="^", color="red")
    # Plot the location passed in, not module-level x_unk/y_unk globals,
    # which may be missing or refer to a different point.
    plt.scatter(loc_unk[0], loc_unk[1], marker="X", color="green")
    plt.show()
def knn(loc_dogs, loc_cats, loc_unk, k):
    """Classify loc_unk as a dog or a cat by a vote of its k nearest neighbors.

    Parameters
    ----------
    loc_dogs : list of [x, y]
        Training locations labelled "dog".
    loc_cats : list of [x, y]
        Training locations labelled "cat".
    loc_unk : [x, y]
        Location to classify.
    k : int
        Number of nearest training points that vote. Must satisfy
        1 <= k <= len(loc_dogs) + len(loc_cats).

    Returns
    -------
    str
        "dog" or "cat" when one class has more votes. On a tie (possible
        when k is even), "dog (random)" or "cat (random)", each chosen with
        probability 0.5 via random.random().

    Raises
    ------
    ValueError
        If k is outside the range above.
    """
    # Take the class sizes from the inputs, not from the module-level
    # num_dogs/num_cats globals, so knn works on any dataset.
    num_dogs = len(loc_dogs)
    num_cats = len(loc_cats)
    if not 1 <= k <= num_dogs + num_cats:
        raise ValueError(
            f"k must be between 1 and {num_dogs + num_cats}, got {k}"
        )

    # Dogs take indices 0 .. num_dogs - 1 and cats take
    # num_dogs .. num_dogs + num_cats - 1, so an index identifies the class.
    dists = [calc_dist(loc_unk, loc) for loc in loc_dogs]
    dists += [calc_dist(loc_unk, loc) for loc in loc_cats]

    # Indices of the k closest training points.
    nearest = np.argsort(dists)[:k]
    vote_dogs = int(np.count_nonzero(nearest < num_dogs))
    vote_cats = k - vote_dogs

    if vote_dogs > vote_cats:
        return "dog"
    if vote_dogs < vote_cats:
        return "cat"
    # Tie: break it with a fair coin flip.
    return "dog (random)" if random.random() < 0.5 else "cat (random)"
if __name__ == "__main__":
    # Demo: 20 dogs and 25 cats at random locations, plus one unknown animal
    # at (10, 0) classified by its 3 nearest neighbors. Guarded so that
    # importing this module does not open a plot window or print anything.
    num_dogs = 20
    num_cats = 25
    x_unk = 10
    y_unk = 0
    k = 3

    loc_dogs, loc_cats = gen_animals(num_dogs, num_cats)
    loc_unk = gen_unknown(x_unk, y_unk)
    plot_animals(loc_dogs, loc_cats, loc_unk)
    print(knn(loc_dogs, loc_cats, loc_unk, k))