Reference
https://lijiancheng0614.github.io/2017/06/29/2017_06_29_TensorFlow-Slim-image-classification/
Dataset
Organize the dataset into one folder per class, and generate a labels.txt file.
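As a rough illustration of the labels.txt step, the sketch below derives the class names from the per-class folder names. It assumes labels.txt holds one class name per line and that line order defines the label index; the function name write_labels and that file layout are assumptions, not taken from the original scripts.

```python
import os


def write_labels(dataset_dir, labels_path):
    """Write one class name per line, derived from the sub-folder names.

    Assumption: each sub-folder of dataset_dir is one class, and the
    sorted folder order defines the label index (0, 1, 2, ...).
    """
    classes = sorted(
        name for name in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, name))
    )
    with open(labels_path, "w") as f:
        for name in classes:
            f.write(name + "\n")
    return classes
```

Calling write_labels('tedspn', 'labels.txt') would then list the folders under tedspn; sorting keeps the indices stable between runs.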
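Preprocessing scripts
- Under /home/user/zhaobing/teds/
  - Run creat_list.py to generate list.txt, which maps each image path to its label.
  - Run creat_train_val_list.py to randomly generate the training and validation list files, list_train.txt and list_val.txt. _NUM_VALIDATION = 350 in the script sets the number of validation images.
  - Run creat_train_val_tfrecord.py to generate the training and validation TFRecord files, which are saved in the train and val folders. Note that tedspn in the script must be changed to your own image folder:
    convert_dataset('list_train.txt', 'tedspn', 'train/')
    convert_dataset('list_val.txt', 'tedspn', 'val/')
- labels.txt
  List all class names in this file.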
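The random split done by creat_train_val_list.py can be sketched roughly as below. This is not the original script: the function name split_list and the fixed seed are assumptions for illustration, and only _NUM_VALIDATION = 350 comes from the notes.

```python
import random

_NUM_VALIDATION = 350  # validation-set size used in these notes
_RANDOM_SEED = 0       # hypothetical: a fixed seed makes the split reproducible


def split_list(lines, num_validation=_NUM_VALIDATION, seed=_RANDOM_SEED):
    """Shuffle the list entries and split them into (train, val)."""
    lines = list(lines)
    if num_validation > len(lines):
        raise ValueError("num_validation exceeds the number of entries")
    random.Random(seed).shuffle(lines)
    # The first num_validation shuffled entries become the validation set.
    return lines[num_validation:], lines[:num_validation]
```

With the 2280 images implied by the commands in these notes (1930 for training plus 350 for validation), the two returned lists would be written to list_train.txt and list_val.txt.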
slim directory
- /home/user/anaconda3/envs/tf1.4/lib/python3.6/site-packages/tensorflow/models/research/slim
Scripts involved
- Add dataset_classification.py under ./datasets/
- In ./, modify train_image_classifier.py and save it as train_image_classifier_teds.py
Run the following to train:
python train_image_classifier_teds.py \
--train_dir=train_logs \
--dataset_dir=/home/user/zhaobing/teds/train \
--num_samples=1930 \
--num_classes=6 \
--labels_to_names_path=/home/user/zhaobing/teds/labels.txt \
--model_name=inception_resnet_v2 \
--checkpoint_path=/home/user/zhaobing/teds/inception_resnet_v2_2016_08_30.ckpt \
--checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
--trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits
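Common pitfalls
--train_dir=train_logs
This directory is /home/user/anaconda3/envs/tf1.4/lib/python3.6/site-packages/tensorflow/models/research/slim/train_logs. Every training run writes its checkpoints here. If the folder already contains a checkpoint when the next run starts, that checkpoint is loaded automatically, which causes a tensor shape error when it was saved with a different model or number of classes. Remember to delete the contents of the train_logs and eval_logs folders before each new training run.
--num_samples=1930
The total number of images in the training set.
--num_classes=6
The number of classes.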
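Clearing the log folders before a fresh run can be scripted. The sketch below is one possible way to do it; the helper name reset_log_dirs is hypothetical, and it assumes both folders live under the same base directory (the slim directory in these notes).

```python
import os
import shutil


def reset_log_dirs(base_dir, names=("train_logs", "eval_logs")):
    """Delete each log directory under base_dir and recreate it empty."""
    for name in names:
        path = os.path.join(base_dir, name)
        # ignore_errors=True makes this a no-op if the folder does not exist.
        shutil.rmtree(path, ignore_errors=True)
        os.makedirs(path)
```

This removes every saved checkpoint in those folders, so copy any checkpoints worth keeping somewhere else first.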
- In ./, modify eval_image_classifier.py and save it as eval_image_classifier_teds.py
Run the following to validate:
python eval_image_classifier_teds.py \
--checkpoint_path=train_logs \
--eval_dir=eval_logs \
--dataset_dir=/home/user/zhaobing/teds/val \
--num_samples=350 \
--num_classes=6 \
--model_name=inception_resnet_v2
--num_samples=350
The number of images in the validation set.
--num_classes=6
The number of classes.
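The --num_samples values should match the number of entries in list_train.txt and list_val.txt. A small check like the one below can confirm this before training; count_samples is a hypothetical helper, and it assumes one image entry per non-empty line.

```python
def count_samples(list_path):
    """Count the non-empty lines (one image entry each) in a list file."""
    with open(list_path) as f:
        return sum(1 for line in f if line.strip())
```

For the split described in these notes, count_samples('list_train.txt') should give 1930 and count_samples('list_val.txt') should give 350.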
Script for running inference on selected images
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
# Choose the GPU before TensorFlow is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import json
import math
import time

import numpy as np
import tensorflow as tf

from nets import nets_factory
from preprocessing import preprocessing_factory

slim = tf.contrib.slim

tf.app.flags.DEFINE_string(
    'master', '', 'The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(
    'checkpoint_path', '/tmp/tfmodel/',
    'The directory where the model was written to or an absolute path to a '
    'checkpoint file.')

tf.app.flags.DEFINE_string(
    'test_list', '', 'Test image list.')

tf.app.flags.DEFINE_string(
    'test_dir', '.', 'Test image directory.')

tf.app.flags.DEFINE_integer(
    'batch_size', 16, 'Batch size.')

tf.app.flags.DEFINE_integer(
    'num_classes', 5, 'Number of classes.')

tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to evaluate.')

tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')

tf.app.flags.DEFINE_integer(
    'test_image_size', None, 'Eval image size')

FLAGS = tf.app.flags.FLAGS


def main(_):
    if not FLAGS.test_list:
        raise ValueError('You must supply the test list with --test_list')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        ####################
        # Select the model #
        ####################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(FLAGS.num_classes - FLAGS.labels_offset),
            is_training=False)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=False)

        test_image_size = FLAGS.test_image_size or network_fn.default_image_size

        # Accept either a checkpoint directory or a single checkpoint file.
        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
        else:
            checkpoint_path = FLAGS.checkpoint_path

        batch_size = FLAGS.batch_size
        tensor_input = tf.placeholder(tf.float32, [None, test_image_size, test_image_size, 3])
        logits, _ = network_fn(tensor_input)
        probabilities2 = tf.nn.softmax(logits)
        # From here on, `logits` holds the top-5 values and indices.
        logits = tf.nn.top_k(logits, 5)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        test_ids = [line.strip() for line in open(FLAGS.test_list)]
        tot = len(test_ids)
        results = list()
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_path)
            time_start = time.time()
            for idx in range(0, tot, batch_size):
                images = list()
                idx_end = min(tot, idx + batch_size)
                print(idx)
                for i in range(idx, idx_end):
                    # The list is reset for every image, so each image is fed
                    # to the network on its own (a batch of one).
                    images = list()
                    image_id = test_ids[i]
                    test_path = os.path.join(FLAGS.test_dir, image_id)
                    image = open(test_path, 'rb').read()
                    print(i, test_path)
                    image = tf.image.decode_jpeg(image, channels=3)
                    processed_image = image_preprocessing_fn(image, test_image_size, test_image_size)
                    processed_image = sess.run(processed_image)
                    images.append(processed_image)
                    images = np.array(images)
                    # Indices of the five highest-scoring classes.
                    predictions = sess.run(logits, feed_dict = {tensor_input : images}).indices
                    #predictions = sess.run(logits, feed_dict = {tensor_input : images})
                    print((image_id, predictions.tolist()))
                    #probabilities = sess.run(probabilities)
                    probabilities = sess.run(probabilities2, feed_dict = {tensor_input : images})
                    #print(image_id, probabilities)
                    #print(probabilities2)
                    probabilities = probabilities[0,0:]
                    # Class indices sorted by descending probability.
                    sorted_inds = [i[0] for i in sorted(enumerate(-probabilities),
                                                        key=lambda x:x[1])]
                    # Print the five highest probabilities.
                    for rank in range(5):
                        index = sorted_inds[rank]
                        print(probabilities[index])


if __name__ == '__main__':
    tf.app.run()
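Run the following command to test:
python test.py \
--checkpoint_path=./train_logs/ \
--test_list=./test.txt \
--test_dir=./tedspn/ \
--batch_size=16 \
--num_classes=8 \
--model_name=inception_resnet_v2
--checkpoint_path=./train_logs/
The folder containing the training checkpoints.
--test_list=./test.txt
A txt file listing the names of all images to run inference on.
--test_dir=./tedspn/
The directory containing the images for inference.
--num_classes=8
The number of classes in the classification task.
--model_name=inception_resnet_v2
The network model chosen for training.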
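The test list can be written by hand or generated. The sketch below is one possible way to generate it, assuming the list holds one image path per line relative to --test_dir and that only JPEG files are wanted (the inference script decodes images with decode_jpeg). The helper name write_test_list is hypothetical.

```python
import os


def write_test_list(test_dir, list_path, extensions=(".jpg", ".jpeg")):
    """Write the paths of all JPEG files under test_dir, relative to test_dir."""
    names = []
    for root, _dirs, files in os.walk(test_dir):
        for filename in files:
            if filename.lower().endswith(extensions):
                full_path = os.path.join(root, filename)
                names.append(os.path.relpath(full_path, test_dir))
    names.sort()
    with open(list_path, "w") as f:
        for name in names:
            f.write(name + "\n")
    return names
```

For example, write_test_list('./tedspn/', './test.txt') would list every JPEG under tedspn; trim the file afterwards if only a few images should be tested.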
After running test.py, the output looks like this:
0 ./tedspn/zhidongzhaqian/zz0034.jpg
('zhidongzhaqian/zz0034.jpg', [[7, 6, 2, 3, 5]])
0.50231224
0.46775842
0.018601459
0.006529054
0.002263928
1 ./tedspn/banluoshuan/bp0036.jpg
('banluoshuan/bp0036.jpg', [[1, 0, 4, 3, 5]])
0.819898
0.14817707
0.007549826
0.00671191
0.0063681067
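Each image in the output above is followed by its top-5 class indices and the five highest softmax probabilities. The ranking step can be reproduced without TensorFlow; the following is a minimal pure-Python sketch of it, not the code the script runs. The function names and the example logits are made up for illustration.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probabilities, k=5):
    """Return the k (class_index, probability) pairs with the highest probability."""
    ranked = sorted(enumerate(probabilities),
                    key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


if __name__ == "__main__":
    example_logits = [0.5, 3.0, -1.0, 2.0, 0.0, 1.0]  # hypothetical scores, 6 classes
    for index, probability in top_k(softmax(example_logits)):
        print(index, probability)
```

Pairing each probability with its class index, as here, makes the printed numbers easier to match against labels.txt.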