import numpy as np
from IPython.display import HTML, display
import tabulate
import matplotlib.pyplot as plt
# Toy dataset: is a day sunny or rainy, given its humidity and pressure?
feature_names = ["Humidity (%)", "Pressure (kPa)"]
data = [
    [29, 101.7],
    [60, 98.6],
    [40, 101.1],
    [62, 99.9],
    [39, 103.2],
    [51, 97.6],
    [46, 102.1],
    [55, 100.2],
]
labels = ["Sun", "Rain", "Sun", "Rain", "Sun", "Rain", "Sun", "Rain"]

# Header row: the class column followed by one column per feature.
header_cells = ['class'] + feature_names
table_labels = np.array(header_cells).reshape((1, len(header_cells)))

# Body rows: the label of each sample placed in front of its feature values.
label_column = np.array(labels).reshape(len(data), 1)
table_data = np.concatenate([label_column, data], axis=1)

# Stack the header on top of the body and render the result as an HTML table.
table_full = np.concatenate([table_labels, table_data], axis=0)
html_table = tabulate.tabulate(table_full, tablefmt='html')
display(HTML(html_table))
我们将介绍一个用于分类问题的简单算法——K最近邻分类法(KNN)。首先,我们导入一个更贴近实际的数据集Iris来扩展我们的问题。Iris数据集包含150个鸢尾花样本,分属3个不同的品种:山鸢尾(Iris setosa)、维吉尼亚鸢尾(Iris virginica)和杂色鸢尾(Iris versicolor)。每个样本包含4个属性:花萼长度、花萼宽度、花瓣长度和花瓣宽度。我们可以利用这4个属性来预测一朵鸢尾花属于三个品种中的哪一个。
import numpy as np
from sklearn.datasets import load_iris
# Load the Iris dataset and pull out the feature matrix and the integer targets.
iris = load_iris()
data, labels = iris.data, iris.target
num_samples = len(labels)  # number of rows in the dataset
num_features = len(iris.feature_names)  # number of feature columns

# Shuffle rows and labels with the same random permutation so they stay paired.
shuffle_order = np.random.permutation(num_samples)
data = data[shuffle_order, :]
labels = labels[shuffle_order]

# Translate every integer label into its class name with a single array lookup.
class_names = iris.target_names
label_names = np.array(class_names)[labels]

# Header row: the class column followed by one column per feature.
header_cells = ['class'] + iris.feature_names
table_labels = np.array(header_cells).reshape((1, 1 + num_features))

# Body rows: class name in front of the feature values; keep only the first 20
# rows so the displayed table stays short.
name_column = label_names.reshape(num_samples, 1)
table_data = np.concatenate([name_column, data], axis=1)[:20]

# Stack the header on top of the body and render the result as an HTML table.
table_full = np.concatenate([table_labels, table_data], axis=0)
html_table = tabulate.tabulate(table_full, tablefmt='html')
display(HTML(html_table))
# Scatter the samples using the first two feature columns (labelled below as
# sepal length and sepal width), colouring each point by its integer label.
lab = labels
x = data[:, 0]
y = data[:, 1]

plt.figure(figsize=(8, 6))
plt.scatter(x, y, c=lab)
plt.title('Iris dataset')
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
# Coordinates of a new, unlabeled sample that we want to classify.
new_x, new_y = 6.5, 3.7

# Redraw the labeled data on a fresh figure, coloured by integer label.
lab = labels
x = data[:, 0]
y = data[:, 1]

plt.figure(figsize=(8, 6))
plt.scatter(x, y, c=lab)
plt.title('Iris dataset')
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')

# Overlay the unlabeled point in grey with a black outline.
plt.scatter(new_x, new_y, c='grey', cmap=None, edgecolor='k')

# Mark it with a "?" placed up and to the right of the point...
question_mark_pos = (new_x + 0.45, new_y + 0.25)
plt.annotate(
    '?',
    question_mark_pos,
    fontsize=20,
    horizontalalignment='center',
    verticalalignment='center',
)

# ...and draw an arrow from near the "?" to just beside the point.
arrow_tail = (new_x + 0.4, new_y + 0.2)
arrow_head = (new_x + 0.05, new_y)
plt.annotate("", xytext=arrow_tail, xy=arrow_head, arrowprops={'arrowstyle': "->"})
# Squared Euclidean distance from the new point to every labeled sample, using
# only the first two feature columns (the ones shown in the plots). The square
# root is skipped because it does not change which distance is smallest.
offsets = data[:, 0:2] - [new_x, new_y]
distances = (offsets ** 2).sum(axis=1)

# Index of the labeled sample nearest to the new point.
closest_point = distances.argmin()

# The new point takes on the label of that nearest sample.
new_label = labels[closest_point]
print('Predicted label: %d' % new_label)
Predicted label: 2
# Add the newly classified point to the plotted coordinates and labels so it is
# drawn with the colour of its predicted class.
x, y = np.append(x, new_x), np.append(y, new_y)
lab = np.append(lab, new_label)

# Same scatter plot as before, now including the new point.
plt.figure(figsize=(8, 6))
plt.scatter(x, y, c=lab)
plt.title('Iris dataset')
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')

# Arrow from the nearest labeled sample towards the new point; both ends are
# nudged by 0.02 so the arrow does not sit exactly on the markers.
arrow_tail = (x[closest_point] + 0.02, y[closest_point] + 0.02)
arrow_head = (new_x - 0.02, new_y - 0.02)
plt.annotate("", xytext=arrow_tail, xy=arrow_head, arrowprops={'arrowstyle': "->"})