姓名:刘强
【嵌牛导读】
手写识别是计算机视觉的一个研究方向,可以看成是一个分类问题。机器学习的任务,便是解决分类(有监督学习)、聚类(无监督学习)和回归(强化学习)问题。k-近邻算法(简称kNN)是最简单的有监督学习算法,本文介绍了如何用k-近邻算法构建一个手写识别系统,并附上其python实现。
【嵌牛鼻子】
k-近邻算法 机器学习 分类 手写识别
【嵌牛提问】
k-近邻算法是什么? 如何构建一个手写识别系统?
【嵌牛正文】
k近邻算法基本思想
存在一个样本数据集,称为训练集,训练集中每个数据都存在标签(标签即数据所属的类别,从这一点可以看出,k近邻算法属于有监督学习)。对于不知道标签的新数据,将新数据的每个特征与训练集中数据对应的特征相比较,选出训练集中前k个最相似的数据(这就是k-近邻算法名称中k的出处),然后对这k个数据做统计,选择出现次数最多的标签作为新数据的标签(即k-近邻算法的输出)。
从其基本思想可以看出,k-近邻算法用于解决分类问题。所谓近邻,其实是用数据之间的欧氏距离来衡量它们的相似程度,距离越短,表示两个数据越相似。
构建手写识别系统
需求分析
很多输入法都支持手写输入,实现手写输入通常的做法是把手写的结果生成图片,进行图像识别。我们知道,图片可以用矩阵表示,对于单通道的灰度图像,假如分辨率为32X32,则可以用一个32X32的矩阵表示,矩阵中的每个元素表示图片中该位置的像素,元素的值为0~255之间的灰度值。
而对于手写图片,表示方法则更加简单,因为手写图片是只有黑白两色的二值图像,利用图像处理软件,黑色的位置写1,白色背景写0,将其转成文本文件,如下图所示:
虽然这样表示不能有效利用内存空间(本来0/1只需占据1bit的空间,但是变成字符“0”,“1”之后需要用char类型所占的字节数),但是对于图像到矩阵的转换这一过程非常直观,方便演示。
我们的目标是:将这样的一幅“图像”输入我们的系统,我们能够输出“图像”中所显示的数字(只做数字0~9的识别)。
系统组成
我们的手写识别系统由以下部分组成:
- 已知标签的训练集
- 文件输入输出模块
- kNN算法模块
已知标签的训练集
点此下载:用到的数据及源代码
其中,trainingDigits文件夹中存放的是用作训练集的的图片,其中包含了1934个训练样本,testDigits文件夹中存放的是用作测试集的图片,其中包含了946个测试样本。每个文件的文件名中含有它的标签。
文件输入输出模块
python读文本文件相当简单,为了迎合后续的kNN算法,我们不把图像表示成32X32的矩阵形式,而是将其转化成1X1024的向量,为此我们定义一个img2vector函数:
def img2vector(filename):
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect
kNN算法模块
根据上述对kNN算法的描述,kNN算法有如下步骤:
- 测试数据与训练集中的每个数据进行比较,以这两个数据间的欧氏距离作为测试数据和训练数据间的相似性度量
- 将算出的欧式距离列表从小到大排序,取前k名所对应的训练集中的数据
- 取出这k个数据的标签,对数目进行统计,出现次数最多的标签作为算法的输出,即分类的结果
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize,1)) - dataSet
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort()
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
系统整体代码
'''
kNN: k Nearest Neighbors
Input: inX: vector to compare to existing dataset (1xN)
dataSet: size m data set of known vectors (NxM)
labels: data set labels (1xM vector)
k: number of neighbors to use for comparison (should be an odd number)
Output: the most popular class label
'''
from numpy import *
import operator
from os import listdir
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize,1)) - dataSet
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort()
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def img2vector(filename):
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect
def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('trainingDigits') #load the training set
m = len(trainingFileList)
trainingMat = zeros((m,1024))
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0] #take off .txt
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
testFileList = listdir('testDigits') #iterate through the test set
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0] #take off .txt
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
if (classifierResult != classNumStr): errorCount += 1.0
print("\nthe total number of errors is: %d" % errorCount)
print("\nthe total error rate is: %f" % (errorCount/float(mTest)))
系统测试
测试环境
- win10 64位
- python3.6.2
测试步骤
- 打开cmd,进入kNN.py所在的文件夹
- 输入python进入python shell
- 输入from kNN import *导入kNN模块中所有函数
- 输入handwritingClassTest(),回车
测试结果
从测试结果来看,1.0571%的错误率,准确度还是蛮高的……
增加训练集的样本容量能有效提高系统的准确度,但是同时增加了运算量,使计算耗时增加。