使用TensorFlow-Slim进行自己图像数据集分类

参考

https://lijiancheng0614.github.io/2017/06/29/2017_06_29_TensorFlow-Slim-image-classification/

数据集

将数据集按照类别整理为多个文件夹,生成labels.txt文件

预处理脚本

  • /home/user/zhaobing/teds/路径下
  • 执行creat_list.py用于生成list.txt来表示图片路径与标签的关系
  • 执行 creat_train_val_list.py用于随机生成训练集与验证集list文件list_train.txtlist_val.txt
    脚本中_NUM_VALIDATION = 350为验证集数目
  • 执行 creat_train_val_tfrecord.py用于生成训练/验证集tfrecord文件,保存在train和val文件夹中
    注意脚本中要将tedspn修改外自己的图片文件夹
convert_dataset('list_train.txt', 'tedspn', 'train/')
convert_dataset('list_val.txt', 'tedspn', 'val/')
  • labels.txt将所有类别防到此文件夹下
    labels.txt

slim文件路径

  • /home/user/anaconda3/envs/tf1.4/lib/python3.6/site-packages/tensorflow/models/research/slim
    涉及脚本文件
  • ./datasets/路径下添加dataset_classification.py文件
  • ./文件路径下修改train_image_classifier.pytrain_image_classifier_teds.py

训练时执行

python train_image_classifier_teds.py \
    --train_dir=train_logs \
    --dataset_dir=/home/user/zhaobing/teds/train \
    --num_samples=1930 \
    --num_classes=6 \
    --labels_to_names_path=/home/user/zhaobing/teds/labels.txt \
    --model_name=inception_resnet_v2 \
    --checkpoint_path=/home/user/zhaobing/teds/inception_resnet_v2_2016_08_30.ckpt \
    --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
    --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits

注意常见错误

--train_dir=train_logs此目录位于/home/user/anaconda3/envs/tf1.4/lib/python3.6/site-packages/tensorflow/models/research/slim/train_logs下每次训练会将检查点文件放在此目录下,下次进行训练时,如果此文件夹下有检查点,会自动加载此检查点,导致tensor shape 错误,记得每次进行新的训练时将train_logs和eval_logs文件夹文件删除
--num_samples=1930为训练集的图像总数
--num_classes=6为分类类别数

  • ./文件路径下修改eval_image_classifier.pyeval_image_classifier_teds.py

验证时执行

python eval_image_classifier_teds.py \
    --checkpoint_path=train_logs \
    --eval_dir=eval_logs \
    --dataset_dir=/home/user/zhaobing/teds/val \
    --num_samples=350 \
    --num_classes=6 \
    --model_name=inception_resnet_v2

--num_samples=350为验证集数目
--num_classes=6为分类类别


选取图片进行推理测试的脚本

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import json
import math
import time
import numpy as np
import tensorflow as tf
from nets import nets_factory
from preprocessing import preprocessing_factory
slim = tf.contrib.slim
tf.app.flags.DEFINE_string(
    'master', '', 'The address of the TensorFlow master to use.')
tf.app.flags.DEFINE_string(
    'checkpoint_path', '/tmp/tfmodel/',
    'The directory where the model was written to or an absolute path to a '
    'checkpoint file.')
tf.app.flags.DEFINE_string(
    'test_list', '', 'Test image list.')
tf.app.flags.DEFINE_string(
    'test_dir', '.', 'Test image directory.')
tf.app.flags.DEFINE_integer(
    'batch_size', 16, 'Batch size.')
tf.app.flags.DEFINE_integer(
    'num_classes', 5, 'Number of classes.')
tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')
tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to evaluate.')
tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')
tf.app.flags.DEFINE_integer(
    'test_image_size', None, 'Eval image size')
FLAGS = tf.app.flags.FLAGS
def main(_):
    if not FLAGS.test_list:
        raise ValueError('You must supply the test list with --test_list')
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()
        ####################
        # Select the model #
        ####################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(FLAGS.num_classes - FLAGS.labels_offset),
            is_training=False)
        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=False)
        test_image_size = FLAGS.test_image_size or network_fn.default_image_size
        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
        else:
            checkpoint_path = FLAGS.checkpoint_path
        batch_size = FLAGS.batch_size
        tensor_input = tf.placeholder(tf.float32, [None, test_image_size, test_image_size, 3])
        logits, _ = network_fn(tensor_input)
        
        probabilities2 = tf.nn.softmax(logits)
        
        logits = tf.nn.top_k(logits, 5)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        test_ids = [line.strip() for line in open(FLAGS.test_list)]
        tot = len(test_ids)
        results = list()
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_path)
            time_start = time.time()
            for idx in range(0, tot, batch_size):
                images = list()
                idx_end = min(tot, idx + batch_size)
                print(idx)
                for i in range(idx, idx_end):    
                    images = list()
                    image_id = test_ids[i]
                    test_path = os.path.join(FLAGS.test_dir, image_id)
                    image = open(test_path, 'rb').read()
                    print(i, test_path)
                    image = tf.image.decode_jpeg(image, channels=3)
                    processed_image = image_preprocessing_fn(image, test_image_size, test_image_size)
                    processed_image = sess.run(processed_image)
                    images.append(processed_image)
                    images = np.array(images)
                    predictions = sess.run(logits, feed_dict = {tensor_input : images}).indices
                    #predictions = sess.run(logits, feed_dict = {tensor_input : images})
                    print((image_id, predictions.tolist()))   
                    #probabilities = sess.run(probabilities)
                    probabilities = sess.run(probabilities2, feed_dict = {tensor_input : images})
                    #print(image_id, probabilities)  
                    #print(probabilities2)
                    
                    probabilities = probabilities[0,0:]
                    sorted_inds = [i[0] for i in sorted(enumerate(-probabilities),
                                            key=lambda x:x[1])]
                    for i in range(5):
                        index = sorted_inds[i]
                        print(probabilities[index])
                  
                
if __name__ == '__main__':
    tf.app.run()

测试时执行命令

python test.py \
    --checkpoint_path=./train_logs/ \   # 训练检查点所在文件夹
    --test_list=./test.txt \    # 写一个txt,将所有想进行推理检测图片的名字写入
    --test_dir=./tedspn/ \    # 推理图片所在路径
    --batch_size=16 \
    --num_classes=8 \     分类任务包含的种类
    --model_name=inception_resnet_v2    训练时选择的网络模型

执行后,效果如下:

0 ./tedspn/zhidongzhaqian/zz0034.jpg
('zhidongzhaqian/zz0034.jpg', [[7, 6, 2, 3, 5]])
0.50231224
0.46775842
0.018601459
0.006529054
0.002263928
1 ./tedspn/banluoshuan/bp0036.jpg
('banluoshuan/bp0036.jpg', [[1, 0, 4, 3, 5]])
0.819898
0.14817707
0.007549826
0.00671191
0.0063681067

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,383评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,522评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,852评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,621评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,741评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,929评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,076评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,803评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,265评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,582评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,716评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,395评论 4 333
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,039评论 3 316
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,798评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,027评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,488评论 2 361
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,612评论 2 350

推荐阅读更多精彩内容