内容
- k近邻分类算法
- 从文本文件中解析和导入数据
- 使用Matplotlib创建散点图
- 归一化数据
代码
kNN.py
#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import numpy as np
import operator
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
# Font passed as ``fontproperties`` to the Chinese axis labels below.
# NOTE: the path is a Windows font location (simsun.ttc); it must exist on the
# machine running the script.
font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14)
from os import listdir
def createDataSet():
    """Return a four-point, two-class toy data set.

    Returns:
        A ``(group, labels)`` pair: ``group`` is a 4x2 float array of sample
        coordinates and ``labels`` is the list of class names, one per row.
    """
    samples = [
        [1.0, 1.1],
        [1.0, 1.0],
        [0, 0],
        [0, 0.1],
    ]
    names = list('AABB')
    return np.array(samples), names
def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest samples.

    Distances are Euclidean, computed between ``inX`` and every row of
    ``dataSet``.

    Args:
        inX: Feature vector to classify (1-D sequence/array, or a single-row
            2-D array) with as many features as ``dataSet`` has columns.
        dataSet: 2-D array of training samples, one sample per row.
        labels: Sequence of class labels, ``labels[i]`` belonging to row ``i``.
        k: Number of nearest neighbours that vote.

    Returns:
        The label with the most votes. On a tie, the label that was seen
        first while walking the neighbours from nearest to farthest wins.

    Raises:
        ValueError: If ``k`` is not between 1 and the number of samples.
    """
    dataSetSize = dataSet.shape[0]
    # Fail early with a clear message instead of an IndexError deep in the vote.
    if not 1 <= k <= dataSetSize:
        raise ValueError(
            'k must be between 1 and %d, got %r' % (dataSetSize, k))

    # Broadcasting subtracts inX from every row without materialising a
    # tiled copy of inX.
    diffMat = np.asarray(inX) - dataSet
    distances = np.sqrt((diffMat ** 2).sum(axis=1))
    nearest = distances.argsort()[:k]

    classCount = {}
    for idx in nearest:
        label = labels[idx]
        classCount[label] = classCount.get(label, 0) + 1

    # max() returns the first maximal item, so ties resolve in insertion
    # order (nearest neighbour first).
    return max(classCount.items(), key=operator.itemgetter(1))[0]
def file2matrix(filename):
    """Parse a tab-separated sample file into a feature matrix and labels.

    Each non-blank line supplies one sample: its first three fields are the
    features and its last field is the integer class label.

    Args:
        filename: Path of the text file to read.

    Returns:
        A ``(returnMat, classLabelVector)`` pair: an ``(n, 3)`` float array of
        features and a list of ``n`` integer labels, in file order.

    Raises:
        ValueError: If a field cannot be converted to a number.
    """
    # The context manager closes the file even if parsing fails.
    with open(filename) as fr:
        rows = [line.strip() for line in fr]
    # Blank lines (e.g. a trailing newline at end of file) carry no sample.
    rows = [row for row in rows if row]

    returnMat = np.zeros((len(rows), 3))
    classLabelVector = []
    for index, row in enumerate(rows):
        fields = row.split('\t')
        returnMat[index, :] = [float(value) for value in fields[0:3]]
        classLabelVector.append(int(fields[-1]))
    return returnMat, classLabelVector
# Load the dating samples and draw two scatter plots of them.
# Feature columns are (frequent flier miles, video game time %, ice cream
# liters) -- the same order classifyPerson uses to build its query vector.
datingDataMat, dataLabels = file2matrix('datingTestSet2.txt')
dataLabels = np.array(dataLabels)

fig = plt.figure()

# Left panel: column 1 against column 2, with marker size and colour both
# scaled by the class label.
ax1 = fig.add_subplot(1, 2, 1)
ax1.scatter(datingDataMat[:, 1], datingDataMat[:, 2],
            15.0 * dataLabels, 15.0 * dataLabels)
ax1.set_xlabel('玩视频游戏所耗时间百分比', fontproperties=font)
ax1.set_ylabel('每周消费的冰淇淋公升数', fontproperties=font)

# Right panel: column 0 against column 1, one series per class so the legend
# can tell them apart.
ax2 = fig.add_subplot(1, 2, 2)
ind1 = np.where(dataLabels == 1)
ax2.scatter(datingDataMat[ind1, 0], datingDataMat[ind1, 1],
            s=20, c='r', marker='*', label=1)
ind2 = np.where(dataLabels == 2)
ax2.scatter(datingDataMat[ind2, 0], datingDataMat[ind2, 1],
            s=10, c='b', marker='o', label=2)
ind3 = np.where(dataLabels == 3)
ax2.scatter(datingDataMat[ind3, 0], datingDataMat[ind3, 1],
            s=30, c='g', marker='+', label=3)
# These labels describe the columns this panel actually plots (0 and 1); the
# video-game / ice-cream labels belong to the left panel.
ax2.set_xlabel('每年获得的飞行常客里程数', fontproperties=font)
ax2.set_ylabel('玩视频游戏所耗时间百分比', fontproperties=font)
ax2.legend()
plt.show()
def autoNorm(dataset):
    """Min-max scale every column of ``dataset`` into the range [0, 1].

    Args:
        dataset: 2-D numeric array, one sample per row.

    Returns:
        A ``(normDataSet, ranges, minVals)`` triple: the scaled array, the
        per-column ``max - min`` and the per-column minimum. ``ranges`` is
        returned unmodified, so it contains 0 for a constant column.
    """
    minVals = dataset.min(0)
    maxVals = dataset.max(0)
    ranges = maxVals - minVals
    # A constant column has range 0; divide it by 1 instead so it scales to
    # 0.0 rather than producing 0/0 = NaN.
    safeRanges = np.where(ranges == 0, 1, ranges)
    normDataSet = (dataset - minVals) / safeRanges
    return normDataSet, ranges, minVals
# Normalise the module-level dating matrix loaded above.
normMat, ranges, minVals = autoNorm(datingDataMat)
def datingClassTest(testRatio=0.1, k=3, filename='datingTestSet2.txt'):
    """Print a hold-out error rate for the kNN classifier on the dating data.

    The first ``testRatio`` fraction of the rows is classified against the
    remaining rows, after min-max normalisation of the whole matrix.

    Args:
        testRatio: Fraction of the rows (taken from the top) used as tests.
        k: Number of neighbours passed to ``classify0``.
        filename: Tab-separated sample file read with ``file2matrix``.

    Returns:
        None. Each prediction and the final error rate are printed.
    """
    datingDataMat, dataLabels = file2matrix(filename)
    normMat, ranges, minVals = autoNorm(datingDataMat)
    r = normMat.shape[0]
    numTestVecs = int(r * testRatio)
    if numTestVecs == 0:
        # Too few rows for the requested ratio; avoid dividing by zero below.
        print('no test vectors to evaluate (%d samples, test ratio %s)'
              % (r, testRatio))
        return

    # The training part is the same for every test vector, so slice it once.
    trainingMat = normMat[numTestVecs:, :]
    trainingLabels = dataLabels[numTestVecs:r]
    errorCount = 0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i, :], trainingMat,
                                     trainingLabels, k)
        print('The classifier came back with: %d,the real answer is: %d'
              % (classifierResult, dataLabels[i]))
        if classifierResult != dataLabels[i]:
            errorCount += 1
    print('the total error rate is: %f' % (errorCount / numTestVecs))
# Run the dating hold-out evaluation when the script executes.
datingClassTest()
def classifyPerson():
    """Ask for one person's three features and print the predicted class.

    Reads three numbers from standard input, normalises them with the same
    minimum and range as the training data, classifies the result with the
    three nearest neighbours and prints the matching description.
    """
    verdicts = ('not at all', 'in small doses', 'in large doses')

    gameShare = float(input('percentage of time spent playing video games?'))
    flierMiles = float(input('frequent flier miles earned per year?'))
    iceCreamLiters = float(input('liters of ice cream consumed per year?'))

    samples, sampleLabels = file2matrix('datingTestSet2.txt')
    scaledSamples, spans, lows = autoNorm(samples)

    # Same column order as the data file: miles, game time, ice cream.
    query = np.array([flierMiles, gameShare, iceCreamLiters])
    scaledQuery = (query - lows) / spans
    predicted = classify0(scaledQuery, scaledSamples, sampleLabels, 3)

    # Class labels start at 1, the descriptions at index 0.
    print('You will probably like this person: %s' % verdicts[predicted - 1])
# Run the interactive classifier when the script executes (blocks on input()).
classifyPerson()
def img2vector(filename):
    """Read a 32x32 text image of digit characters into a 1x1024 row vector.

    The first 32 characters of each of the first 32 lines are converted with
    ``int`` and stored row by row.

    Args:
        filename: Path of the text image file.

    Returns:
        A ``(1, 1024)`` float array; element ``32 * i + j`` holds the value of
        character ``j`` on line ``i``.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        IndexError: If a line has fewer than 32 characters.
        ValueError: If a character is not a digit.
    """
    returnVect = np.zeros((1, 1024))
    # ``with`` closes the file on every path and, unlike a try/finally that
    # wraps open(), does not hide a failed open behind an unbound variable.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect
# Smoke test: load one test image and print the first 31 values of its vector.
testVct = img2vector('testDigits/0_13.txt')
print(testVct[0,0:31])
def _digitFromFileName(fileNameStr):
    """Return the integer before the first ``_`` in a ``<digit>_<n>.txt`` name."""
    fileStr = fileNameStr.split('.')[0]
    return int(fileStr.split('_')[0])


def handwritingClassTest():
    """Print the kNN error count and rate for the handwritten digit images.

    Every file in ``trainingDigits`` becomes one training row, labelled with
    the digit encoded at the start of its file name. Every file in
    ``testDigits`` is then classified with the three nearest neighbours and
    compared with the digit in its own file name.

    Returns:
        None. Each prediction, the error count and the error rate are printed.
    """
    trainingFileList = listdir('trainingDigits')
    m = len(trainingFileList)
    trainingMat = np.zeros((m, 1024))
    hwLabels = []
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabels.append(_digitFromFileName(fileNameStr))
        trainingMat[i, :] = img2vector('trainingDigits/%s' % (fileNameStr))

    testFileList = listdir('testDigits')
    mTest = len(testFileList)
    errorCount = 0
    for fileNameStr in testFileList:
        classNumStr = _digitFromFileName(fileNameStr)
        vectorUnderTest = img2vector('testDigits/%s' % (fileNameStr))
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print('the classifier came back with: %d, the real answer is: %d'
              % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1

    print('\nthe total number of errors is: %d' % errorCount)
    if mTest == 0:
        # An empty test directory would otherwise divide by zero.
        print('\nno test files found; error rate is undefined')
        return
    print('\nthe total error rate is: %f' % (errorCount / mTest))
# Run the handwritten-digit evaluation when the script executes.
handwritingClassTest()