Building CIFAR-10 binary-format data from your own images
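Summary
- This article covers how to build data in the CIFAR-10 binary version with Python
- Environment: Windows, Python 3.6
- Last updated 2018.04.13; please excuse any problems caused by version differences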
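Main text
The convolutional neural network section of the TensorFlow tutorials uses the CIFAR-10 data for classification and gets fairly good results. Can we use the same model to train a classifier on other images?
TensorFlow tutorial on convolutional neural networks (Chinese version)
TensorFlow tutorial on convolutional neural networks (English version)
The answer is yes, but the tutorial's model reads its input in the CIFAR-10 binary format, so we first need to convert our own images into that format.
Model code provided with the tutorial
CIFAR-10 dataset website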
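The CIFAR-10 binary version is laid out as follows:
Since CIFAR-10 images are 32*32, each image has 1024 pixels, which makes 3072 bytes across the three RGB channels. Each image can therefore be written as a (1 + 1024 + 1024 + 1024)-byte record: the first byte is the label, the next 1024 bytes are the Red channel, the next 1024 the Green channel, and the last 1024 the Blue channel, with each channel's values stored in row-major order. Writing every image in this format and concatenating the records back to back, with no header or separators, produces a CIFAR-10 data file.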
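As a minimal sketch of the record layout just described (not code from the tutorial or the referenced post), the helpers below pack one 32*32 RGB image and its label into 3073 bytes and unpack it again. The names `pack_record` and `unpack_record` are hypothetical, and only NumPy is assumed.

```python
import numpy as np

SIZE = 32                          # CIFAR-10 images are 32 x 32
CHANNEL_LEN = SIZE * SIZE          # 1024 bytes per color channel
RECORD_LEN = 1 + 3 * CHANNEL_LEN   # 1 label byte + 3072 pixel bytes = 3073


def pack_record(label, rgb):
    """Pack one (32, 32, 3) uint8 RGB image and its label into 3073 bytes."""
    assert rgb.shape == (SIZE, SIZE, 3) and rgb.dtype == np.uint8
    assert 0 <= label <= 255          # the label must fit in a single byte
    # HWC -> CHW: all R bytes first, then G, then B,
    # each channel flattened row by row (row-major order)
    planes = rgb.transpose(2, 0, 1).reshape(-1)
    return bytes([label]) + planes.tobytes()


def unpack_record(record):
    """Inverse of pack_record: return (label, (32, 32, 3) RGB array)."""
    assert len(record) == RECORD_LEN
    arr = np.frombuffer(record, dtype=np.uint8)
    label = int(arr[0])
    rgb = arr[1:].reshape(3, SIZE, SIZE).transpose(1, 2, 0)
    return label, rgb


if __name__ == "__main__":
    np.random.seed(0)
    img = np.random.randint(0, 256, size=(SIZE, SIZE, 3), dtype=np.uint8)
    rec = pack_record(7, img)
    print(len(rec))                          # 3073
    lab, back = unpack_record(rec)
    print(lab, np.array_equal(back, img))    # 7 True
    # byte 1 is the red value of pixel (0, 0); byte 1 + 1024 is its green value
    print(rec[1] == img[0, 0, 0], rec[1 + CHANNEL_LEN] == img[0, 0, 1])  # True True
```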
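To do this I referred to a blog post on building your own CIFAR-10-like dataset in the python version. That post covers the python version of the format; with a few changes to its code, it produces the binary version instead.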
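To illustrate why only the output step has to change, here is a hedged sketch that writes the same N x 3072 pixel rows in both versions. It assumes the python version is a pickled dict with `data` and `labels` entries, as the CIFAR-10 page describes for its official python batches; the `filenames` key and the function names `write_python_version` / `write_binary_version` are assumptions for illustration, not the referenced post's code.

```python
import pickle

import numpy as np


def write_python_version(path, data, labels, filenames):
    """Python version: a pickled dict whose 'data' is an N x 3072 uint8 array."""
    batch = {"data": np.asarray(data, dtype=np.uint8),
             "labels": list(labels),
             "filenames": list(filenames)}   # assumed extra key
    with open(path, "wb") as f:
        pickle.dump(batch, f)


def write_binary_version(path, data, labels):
    """Binary version: N records of <1 label byte><3072 pixel bytes>, no header."""
    data = np.asarray(data, dtype=np.uint8)
    labels = np.asarray(labels, dtype=np.uint8).reshape(-1, 1)
    with open(path, "wb") as f:
        f.write(np.hstack((labels, data)).tobytes())
```

Both writers take the same rows (R plane, G plane, B plane per image), so the image-reading part of the code can stay as it is.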
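# -*- coding: UTF-8 -*-
import os
import cv2
import numpy as np
DATA_LEN = 3072     # bytes of pixel data per image: 32 * 32 * 3
CHANNEL_LEN = 1024  # bytes per color channel: 32 * 32
SHAPE = 32          # CIFAR-10 images are 32 x 32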
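def imread(im_path, shape=None, color="RGB", mode=cv2.IMREAD_COLOR):
    """Read an image, optionally convert BGR to RGB and resize to shape x shape."""
    # IMREAD_COLOR always returns a 3-channel BGR image; IMREAD_UNCHANGED can
    # return a 2-D grayscale or 4-channel image that breaks the channel split
    im = cv2.imread(im_path, mode)
    if im is None:  # cv2.imread returns None instead of raising on failure
        raise IOError("Can't read image: " + im_path)
    if color == "RGB":
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # OpenCV loads images as BGR
    if shape is not None:
        assert isinstance(shape, int)
        im = cv2.resize(im, (shape, shape))
    return im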
def read_data(filename, data_path, shape=None, color='RGB'):
    """
    filename (str): the image list file, one image per line in the format:
        image_name label
    data_path (str): image data folder
    shape: None or 32, since CIFAR-10 binary records are always 32 x 32
    return: an (N, 3072) uint8 array of images, a list of labels, a list of image names
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError("Can't find data file: " + filename)
    with open(filename) as f:
        lines = [ln for ln in f.read().splitlines() if ln.strip()]  # skip blank lines
    count = len(lines)
    data = np.zeros((count, DATA_LEN), dtype=np.uint8)
    assert shape in (None, SHAPE), "CIFAR-10 binary images must be 32 x 32"
    s, c = SHAPE, CHANNEL_LEN
    lst, label = [], []
    for idx, ln in enumerate(lines):
        fname, lab = ln.rsplit(' ', 1)  # split on the last space only
        im = imread(os.path.join(data_path, fname), shape=s, color=color)
        # Red, Green, Blue planes, each flattened in row-major order
        data[idx, :c] = np.reshape(im[:, :, 0], c)
        data[idx, c:2 * c] = np.reshape(im[:, :, 1], c)
        data[idx, 2 * c:] = np.reshape(im[:, :, 2], c)
        lst.append(fname)
        label.append(int(lab))
    return data, label, lst
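def py2bin(data, label, save_file='./bin/train_batch'):
    """Write <1 label byte><3072 pixel bytes> records back to back (CIFAR-10 binary version)."""
    # labels must fit in one byte (CIFAR-10 uses 0-9)
    label_uint8 = np.array(label, dtype=np.uint8).reshape(len(label), 1)
    arr = np.hstack((label_uint8, data))  # one 3073-byte row per image
    os.makedirs(os.path.dirname(save_file) or '.', exist_ok=True)  # create ./bin if missing
    with open(save_file, 'wb') as f:
        f.write(arr.tobytes())  # whole array in C order, instead of one write per byte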
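def imagelist(directory_normal=r"data/normal", file_train_list=r"data/image_train_list.txt"):
    # directory_normal: source image folder (32*32 pixels; other sizes are resized when read)
    # file_train_list: output path of the generated image list
    with open(file_train_list, "w") as f:  # "w" rather than "a", so reruns don't duplicate entries
        for filename in sorted(os.listdir(directory_normal)):
            f.write(filename + " " + "0" + "\n")  # every image is labelled 0 here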
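if __name__ == '__main__':
    data_path = './data/normal'
    file_list = './data/image_train_list.txt'
    imagelist()  # build the image list
    data, label, lst = read_data(file_list, data_path, shape=32)  # pixel data as a matrix, plus the label list
    py2bin(data, label)  # write the pixel matrix and labels as a CIFAR-10 binary version file
    # note: the tutorial's input code may expect file names such as data_batch_1.bin,
    # so the output file may need renaming before training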
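To sanity-check the output, a file in this format can be read back with NumPy alone. The sketch below (the name `load_cifar_bin` is hypothetical) splits the file into 3073-byte records and rebuilds (N, 32, 32, 3) RGB images, assuming the 32*32 layout produced by the script above.

```python
import numpy as np

RECORD_LEN = 3073  # 1 label byte + 32 * 32 * 3 pixel bytes


def load_cifar_bin(path):
    """Read a CIFAR-10-style binary file into (labels, images).

    labels: (N,) uint8; images: (N, 32, 32, 3) uint8 in RGB order.
    """
    raw = np.fromfile(path, dtype=np.uint8)
    if raw.size % RECORD_LEN != 0:
        raise ValueError("file size %d is not a multiple of %d" % (raw.size, RECORD_LEN))
    records = raw.reshape(-1, RECORD_LEN)
    labels = records[:, 0]
    # each record holds the R plane, G plane, B plane -> (N, 3, 32, 32) -> (N, 32, 32, 3)
    images = records[:, 1:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    return labels, images
```

For example, `labels, images = load_cifar_bin('./bin/train_batch')` should give `len(labels)` equal to the number of lines in the image list. To view an image with OpenCV, convert it back with `cv2.cvtColor(images[0], cv2.COLOR_RGB2BGR)` before calling `cv2.imwrite`.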