在PaddlePaddle中实现MNIST数据集训练:高层API

本文从MNIST数据集下载开始,详细介绍在PaddlePaddle中,基于高层API实现MNIST数据集训练

第一步:将MNIST数据下载到本地,下载链接:http://yann.lecun.com/exdb/mnist/,可以得到四个文件:

  • train-images-idx3-ubyte.gz: 训练集图像数据
  • train-labels-idx1-ubyte.gz: 训练集标签
  • t10k-images-idx3-ubyte.gz: 测试集图像数据
  • t10k-labels-idx1-ubyte.gz: 测试集标签

第二步:将下载的四个文件以Numpy ndarray类型载入内存。解压文件并读取数据的过程非常标准,大家可以直接用下面的范例程序

# train-images-idx3-ubyte 文件格式, 参考:http://yann.lecun.com/exdb/mnist/
'''
[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000803(2051) magic number 
0004     32 bit integer  60000            number of images 
0008     32 bit integer  28               number of rows 
0012     32 bit integer  28               number of columns 
0016     unsigned byte   ??               pixel 
0017     unsigned byte   ??               pixel 
........ 
xxxx     unsigned byte   ??               pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 
0 means background (white), 255 means foreground (black).
'''
def load_images(image_file):
    # 读取*.gz格式文件
    with gzip.open(image_file) as f:
        buf = f.read()

    idx = 0
    # 读取文件信息
    magic, num_images, rows, cols = struct.unpack_from('>IIII', buf, idx)
    idx += struct.calcsize('>IIII')
    length = int(num_images*rows*cols)
    # 读取图像数据
    images = struct.unpack_from('>'+str(length)+'B', buf, idx)
    images = np.array(images).astype('float32')
    images = images.reshape(num_images, rows, cols)
    # 返回np.ndarray类型, N*r*c 图像数据
    return images


# train-labels-idx1-ubyte.gz 文件格式
'''
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000801(2049) magic number (MSB first)
0004     32 bit integer  60000            number of items
0008     unsigned byte   ??               label
0009     unsigned byte   ??               label
........
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.
'''
def load_labels(label_file):
    # 读取*.gz格式文件
    with gzip.open(label_file) as f:
        buf = f.read()
    # 读取文件信息
    idx = 0
    magic, num_labels = struct.unpack_from('>II', buf, idx)
    # 读取标签数据
    idx += struct.calcsize('>II')
    labels = struct.unpack_from('>'+str(num_labels)+'B',buf,idx)
 
    labels = np.array(labels).astype('int64')
    # 返回np.ndarray类型, 标签数据
    return labels

可以用下面的代码来测试图像数据的读入

# Test Code
import matplotlib.pyplot as plt
train_images = load_images('train-images-idx3-ubyte.gz')
test_images  = load_images('t10k-images-idx3-ubyte.gz')
train_labels = load_labels('train-labels-idx1-ubyte.gz').reshape(-1,1)
test_labels  = load_labels('t10k-labels-idx1-ubyte.gz').reshape(-1,1)
print(train_images.shape, train_labels.shape, test_images.shape, test_labels.shape)
idx = 5
fig = plt.figure()
plt.subplot(1,2,1)
plt.imshow(train_images[idx],cmap='rainbow')
plt.subplot(1,2,2)
plt.imshow(test_images[idx],cmap='rainbow')
plt.show()

(60000, 28, 28) (60000,1) (10000, 28, 28) (10000,1)

展示数据

第三步:使用飞桨提供的paddle.io.Dataset基类,将数据封装为可迭代的数据源。

train_images = load_images('train-images-idx3-ubyte.gz')
test_images  = load_images('t10k-images-idx3-ubyte.gz')
train_labels = load_labels('train-labels-idx1-ubyte.gz')
test_labels  = load_labels('t10k-labels-idx1-ubyte.gz')

# 图像数据归一化
train_images = train_images / 255.0
test_images  = test_images / 255.0

num_train_samples = train_images.shape[0]
num_test_samples = test_images.shape[0]

import paddle
from paddle.io import Dataset
class TrainDataSet(Dataset):
    """
    步骤一:继承paddle.io.Dataset类
    """
    def __init__(self, num_samples):
        """
        步骤二:实现构造函数,定义数据集大小
        """
        super().__init__()
        self.num_samples = num_samples

    def __getitem__(self, index):
        """
        步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
        """
        data = train_images[index]
        label = train_labels[index]

        return data, label

    def __len__(self):
        """
        步骤四:实现__len__方法,返回数据集总数目
        """
        return self.num_samples

class TestDataSet(Dataset):
    """
    步骤一:继承paddle.io.Dataset类
    """
    def __init__(self, num_samples):
        """
        步骤二:实现构造函数,定义数据集大小
        """
        super().__init__()
        self.num_samples = num_samples

    def __getitem__(self, index):
        """
        步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
        """
        data = test_images[index]
        label = test_labels[index]

        return data, label

    def __len__(self):
        """
        步骤四:实现__len__方法,返回数据集总数目
        """
        return self.num_samples

# 测试定义的数据集
train_dataset = TrainDataSet(num_train_samples)
test_dataset = TestDataSet(num_test_samples)

第四步:针对顺序的线性网络结构,使用飞桨提供的Sequential类来快速完成组网,这样可以减少类的定义等代码编写。

# 定义模型
mnist = paddle.nn.Sequential(
    paddle.nn.Flatten(),
    paddle.nn.Linear(784, 512),
    paddle.nn.ReLU(),
    paddle.nn.Dropout(0.2),
    paddle.nn.Linear(512, 10)
)

第五步:生成模型实例,并完成损失函数、优化方法和评估方法的配置。

# 预计模型结构生成模型实例,便于进行后续的配置、训练和验证
model = paddle.Model(mnist)

# 模型训练相关配置,准备损失计算方法,优化器和精度计算方法
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),
              paddle.nn.CrossEntropyLoss(),
              paddle.metric.Accuracy())

最后一步:用fit()方法启动训练,evaluate()方法实现评估,predict()方法实现预测

# 开始模型训练
model.fit(train_dataset,
          epochs=5,
          batch_size=100,
          verbose=1)

# 用 evaluate 在测试集上对模型进行验证
eval_result = model.evaluate(test_dataset, verbose=0)
print(eval_result)
# 用 predict 在测试集上对模型进行测试
test_result = model.predict(test_dataset)

完整可运行的代码如下

import gzip 
import struct 
import numpy as np 

# train-images-idx3-ubyte 文件格式, 参考:http://yann.lecun.com/exdb/mnist/
'''
[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000803(2051) magic number 
0004     32 bit integer  60000            number of images 
0008     32 bit integer  28               number of rows 
0012     32 bit integer  28               number of columns 
0016     unsigned byte   ??               pixel 
0017     unsigned byte   ??               pixel 
........ 
xxxx     unsigned byte   ??               pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 
0 means background (white), 255 means foreground (black).
'''
def load_images(image_file):
    # 读取*.gz格式文件
    with gzip.open(image_file) as f:
        buf = f.read()

    idx = 0
    # 读取文件信息
    magic, num_images, rows, cols = struct.unpack_from('>IIII', buf, idx)
    idx += struct.calcsize('>IIII')
    length = int(num_images*rows*cols)
    # 读取图像数据
    images = struct.unpack_from('>'+str(length)+'B', buf, idx)
    images = np.array(images).astype('float32')
    images = images.reshape(num_images, rows, cols)
    # 返回np.ndarray类型, N*r*c 图像数据
    return images


# train-labels-idx1-ubyte.gz 文件格式
'''
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000801(2049) magic number (MSB first)
0004     32 bit integer  60000            number of items
0008     unsigned byte   ??               label
0009     unsigned byte   ??               label
........
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.
'''
def load_labels(label_file):
    # 读取*.gz格式文件
    with gzip.open(label_file) as f:
        buf = f.read()
    # 读取文件信息
    idx = 0
    magic, num_labels = struct.unpack_from('>II', buf, idx)
    # 读取标签数据
    idx += struct.calcsize('>II')
    labels = struct.unpack_from('>'+str(num_labels)+'B',buf,idx)
 
    labels = np.array(labels).astype('int64')
    # 返回np.ndarray类型, 标签数据
    return labels

train_images = load_images('train-images-idx3-ubyte.gz')
test_images  = load_images('t10k-images-idx3-ubyte.gz')
train_labels = load_labels('train-labels-idx1-ubyte.gz').reshape(-1,1)
test_labels  = load_labels('t10k-labels-idx1-ubyte.gz').reshape(-1,1)

# 图像数据归一化
train_images = train_images / 255.0
test_images  = test_images / 255.0

num_train_samples = train_images.shape[0]
num_test_samples = test_images.shape[0]

import paddle
from paddle.io import Dataset
class TrainDataSet(Dataset):
    """
    步骤一:继承paddle.io.Dataset类
    """
    def __init__(self, num_samples):
        """
        步骤二:实现构造函数,定义数据集大小
        """
        super().__init__()
        self.num_samples = num_samples

    def __getitem__(self, index):
        """
        步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
        """
        data = train_images[index]
        label = train_labels[index]

        return data, label

    def __len__(self):
        """
        步骤四:实现__len__方法,返回数据集总数目
        """
        return self.num_samples

class TestDataSet(Dataset):
    """
    步骤一:继承paddle.io.Dataset类
    """
    def __init__(self, num_samples):
        """
        步骤二:实现构造函数,定义数据集大小
        """
        super().__init__()
        self.num_samples = num_samples

    def __getitem__(self, index):
        """
        步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
        """
        data = test_images[index]
        label = test_labels[index]

        return data, label

    def __len__(self):
        """
        步骤四:实现__len__方法,返回数据集总数目
        """
        return self.num_samples

# 测试定义的数据集
train_dataset = TrainDataSet(num_train_samples)
test_dataset = TestDataSet(num_test_samples)

# 定义模型
mnist = paddle.nn.Sequential(
    paddle.nn.Flatten(),
    paddle.nn.Linear(784, 512),
    paddle.nn.ReLU(),
    paddle.nn.Dropout(0.2),
    paddle.nn.Linear(512, 10)
)

# 预计模型结构生成模型实例,便于进行后续的配置、训练和验证
model = paddle.Model(mnist)

# 模型训练相关配置,准备损失计算方法,优化器和精度计算方法
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),
              paddle.nn.CrossEntropyLoss(),
              paddle.metric.Accuracy())

# 开始模型训练
model.fit(train_dataset,
          epochs=5,
          batch_size=100,
          verbose=1)

# 用 evaluate 在测试集上对模型进行验证
eval_result = model.evaluate(test_dataset, verbose=0)
print(eval_result)
# 用 predict 在测试集上对模型进行测试
test_result = model.predict(test_dataset)
# 展示预测结果
import matplotlib.pyplot as plt 
def show_img(img, predict):
    plt.title(f'predict:{predict}')
    plt.imshow(img.reshape([28,28]))
    plt.show()
idx = 2 #抽样索引
show_img(test_dataset[idx][0], np.argmax(test_result[0][idx]))

运行结果如下:
运行结果
预测值展示

下一节 《在PaddlePaddle中实现MNIST数据集训练:基础API》

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,539评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,911评论 3 391
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,337评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,723评论 1 290
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,795评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,762评论 1 294
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,742评论 3 416
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,508评论 0 271
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,954评论 1 308
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,247评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,404评论 1 345
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,104评论 5 340
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,736评论 3 324
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,352评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,557评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,371评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,292评论 2 352

推荐阅读更多精彩内容