TensorFlow学习笔记:Retrain Inception_v3(二)


接前文:TensorFlow学习笔记:Retrain Inception_v3(一)


5. 代码改写

通过阅读源码可以发现,程序的超参数都是通过命令行传入,当然,每个命令行参数都含有默认值。如果我们想要直接运行,稍微改动一下代码的最后一个部分中命令行参数的默认地址即可。将模型下载、数据集、文件存贮地址修改为特定的地址:

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--image_dir',
      type=str,
      default='C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/flower_photos',
      #default='',
      help='Path to folders of labeled images.'
  )
  parser.add_argument(
      '--output_graph',
      type=str,
      #default='/tmp/output_graph.pb',
      default='C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/output_graph.pb',
      help='Where to save the trained graph.'
  )
  parser.add_argument(
      '--intermediate_output_graphs_dir',
      type=str,
      #default='/tmp/intermediate_graph/',
      default='C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/intermediate_graph/',
      help='Where to save the intermediate graphs.'
  )
  parser.add_argument(
      '--intermediate_store_frequency',
      type=int,
      default=0,
      help="""\
         How many steps to store intermediate graph. If "0" then will not
         store.\
      """
  )
  parser.add_argument(
      '--output_labels',
      type=str,
      #default='/tmp/output_labels.txt',
      default='C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/output_labels.txt',
      help='Where to save the trained graph\'s labels.'
  )
  parser.add_argument(
      '--summaries_dir',
      type=str,
      #default='/tmp/retrain_logs',
      default='C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/retrain_logs',
      help='Where to save summary logs for TensorBoard.'
  )
  parser.add_argument(
      '--how_many_training_steps',
      type=int,
      default=4000,
      help='How many training steps to run before ending.'
  )
  parser.add_argument(
      '--learning_rate',
      type=float,
      default=0.01,
      help='How large a learning rate to use when training.'
  )
  parser.add_argument(
      '--testing_percentage',
      type=int,
      default=10,
      help='What percentage of images to use as a test set.'
  )
  parser.add_argument(
      '--validation_percentage',
      type=int,
      default=10,
      help='What percentage of images to use as a validation set.'
  )
  parser.add_argument(
      '--eval_step_interval',
      type=int,
      default=10,
      help='How often to evaluate the training results.'
  )
  parser.add_argument(
      '--train_batch_size',
      type=int,
      default=100,
      help='How many images to train on at a time.'
  )
  parser.add_argument(
      '--test_batch_size',
      type=int,
      default=-1,
      help="""\
      How many images to test on. This test set is only used once, to evaluate
      the final accuracy of the model after training completes.
      A value of -1 causes the entire test set to be used, which leads to more
      stable results across runs.\
      """
  )
  parser.add_argument(
      '--validation_batch_size',
      type=int,
      default=100,
      help="""\
      How many images to use in an evaluation batch. This validation set is
      used much more often than the test set, and is an early indicator of how
      accurate the model is during training.
      A value of -1 causes the entire validation set to be used, which leads to
      more stable results across training iterations, but may be slower on large
      training sets.\
      """
  )
  parser.add_argument(
      '--print_misclassified_test_images',
      default=False,
      help="""\
      Whether to print out a list of all misclassified test images.\
      """,
      action='store_true'
  )
  parser.add_argument(
      '--model_dir',
      type=str,
      #default='/tmp/imagenet',
      default='C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/inception-2015-12-05',
      help="""\
      Path to classify_image_graph_def.pb,
      imagenet_synset_to_human_label_map.txt, and
      imagenet_2012_challenge_label_map_proto.pbtxt.\
      """
  )
  parser.add_argument(
      '--bottleneck_dir',
      type=str,
      #default='/tmp/bottleneck',
      default='C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/bottleneck',
      help='Path to cache bottleneck layer values as files.'
  )
  parser.add_argument(
      '--final_tensor_name',
      type=str,
      default='final_result',
      help="""\
      The name of the output classification layer in the retrained graph.\
      """
  )
  parser.add_argument(
      '--flip_left_right',
      default=False,
      help="""\
      Whether to randomly flip half of the training images horizontally.\
      """,
      action='store_true'
  )
  parser.add_argument(
      '--random_crop',
      type=int,
      default=0,
      help="""\
      A percentage determining how much of a margin to randomly crop off the
      training images.\
      """
  )
  parser.add_argument(
      '--random_scale',
      type=int,
      default=0,
      help="""\
      A percentage determining how much to randomly scale up the size of the
      training images by.\
      """
  )
  parser.add_argument(
      '--random_brightness',
      type=int,
      default=0,
      help="""\
      A percentage determining how much to randomly multiply the training image
      input pixels up or down by.\
      """
  )
  parser.add_argument(
      '--architecture',
      type=str,
      default='inception_v3',
      help="""\
      Which model architecture to use. 'inception_v3' is the most accurate, but
      also the slowest. For faster or smaller models, chose a MobileNet with the
      form 'mobilenet_<parameter size>_<input_size>[_quantized]'. For example,
      'mobilenet_1.0_224' will pick a model that is 17 MB in size and takes 224
      pixel input images, while 'mobilenet_0.25_128_quantized' will choose a much
      less accurate, but smaller and faster network that's 920 KB on disk and
      takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
      for more information on Mobilenet.\
      """)

6. 运行

程序首先会载入inception模型,然后将数据集输入生成Bottleneck文件存放在本地。然后再使用Bottleneck文件训练最后的softmax layer。训练结果如下:


i7-6600U Win10

此图为使用Win10,Anaconda,Python 3.5 CPU训练,大概耗时10~20分钟吧,忘记在代码中加入计时了。最后几轮训练输出如下,最终精度为91.8%:

INFO:tensorflow:2017-10-13 13:30:55.851475: Step 3980: Train accuracy = 97.0%
INFO:tensorflow:2017-10-13 13:30:55.852477: Step 3980: Cross entropy = 0.105862
INFO:tensorflow:2017-10-13 13:30:55.973800: Step 3980: Validation accuracy = 93.0% (N=100)
INFO:tensorflow:2017-10-13 13:30:57.164969: Step 3990: Train accuracy = 99.0%
INFO:tensorflow:2017-10-13 13:30:57.165970: Step 3990: Cross entropy = 0.083922
INFO:tensorflow:2017-10-13 13:30:57.287292: Step 3990: Validation accuracy = 85.0% (N=100)
INFO:tensorflow:2017-10-13 13:30:58.387826: Step 3999: Train accuracy = 94.0%
INFO:tensorflow:2017-10-13 13:30:58.388830: Step 3999: Cross entropy = 0.175846
INFO:tensorflow:2017-10-13 13:30:58.514203: Step 3999: Validation accuracy = 90.0% (N=100)
INFO:tensorflow:Final test accuracy = 91.8% (N=732)
INFO:tensorflow:Froze 2 variables.
Converted 2 variables to const ops.
An exception has occurred, use %tb to see the full traceback.

SystemExit

C:\Users\Dexter\Anaconda2\envs\TensorFlow_Py35\lib\site-packages\IPython\core\interactiveshell.py:2870: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.
  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)

注意程序在最后训练结束后会将softmax layer(包含一个全连接层)的两个参数固化,结合剩余的inception模型参数,一起固化为新的模型。
后来又使用1080ti+Ubuntu重新跑了一遍,大概5分钟足矣,不过运气不好,精度只有88.0%。

文件列表
程序将产生bottleneck文件,retrain_logs,output_labels和最重要的output_graph.pb文件。后续使用训练好的模型来预测,就需要载入output_graph.pb文件,相关代码与TensorFlow学习笔记:使用Inception v3进行图像分类类似。

7. 测试

7.1 下载数据

从网上随便下载一些图片,特地挑选了在5个类别中,但是又很难分辨出的图片:


daisy.jpg
roses.jpg

第一个感觉杂糅了雏菊和向日葵,第二个玫瑰还是郁金香真是傻傻分不清。总共准备了20张图片:


test_images

7.2 测试脚本

主要参考TensorFlow学习笔记:使用Inception v3进行图像分类中的代码:

# -*- coding: utf-8 -*-
"""
Created on Fri Oct 13 16:15:16 2017
use_output_graph
使用retrain所训练的迁移后的inception模型来测试
@author: Dexter
"""
import tensorflow as tf
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt

model_dir = 'C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3'
model_name = 'output_graph.pb'
image_dir = 'C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/use_output_graph/test_images'
label_dir = 'C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/'
label_filename = 'output_labels.txt'

# 读取并创建一个图graph来存放Google训练好的Inception_v3模型(函数)
def create_graph():
    with tf.gfile.FastGFile(os.path.join(
            model_dir, model_name), 'rb') as f:
        # 使用tf.GraphDef()定义一个空的Graph
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Imports the graph from graph_def into the current default Graph.
        tf.import_graph_def(graph_def, name='')

# 读取标签labels
def load_labels(label_file_dir):
    if not tf.gfile.Exists(label_file_dir):
        # 预先检测地址是否存在
        tf.logging.fatal('File does not exist %s', label_file_dir)
    else:
        # 读取所有的标签返并回一个list
        labels = tf.gfile.GFile(label_file_dir).readlines()
        for i in range(len(labels)):
            labels[i] = labels[i].strip('\n')
    return labels

# 创建graph
create_graph()

# 创建会话,因为是从已有的Inception_v3模型中恢复,所以无需初始化
with tf.Session() as sess:
    # Inception_v3模型的最后一层final_result:0的输出
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    
    # 遍历目录
    for root, dirs, files in os.walk(image_dir):
        for file in files:
            # 载入图片
            image_data = tf.gfile.FastGFile(os.path.join(root, file), 'rb').read()
            # 输入图像(jpg格式)数据,得到softmax概率值(一个shape=(1,1008)的向量)
            predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})
            # 将结果转为1维数据
            predictions = np.squeeze(predictions)
    
            # 打印图片路径及名称
            image_path = os.path.join(root, file)
            print(image_path)
            # 显示图片
            img = Image.open(image_path)
            plt.imshow(img)
            plt.axis('off')
            plt.show()
            
            # 排序,取出前5个概率最大的值(top-5),本数据集一共就5个
            # argsort()返回的是数组值从小到大排列所对应的索引值
            top_5 = predictions.argsort()[-5:][::-1]
            for label_index in top_5:
                # 获取分类名称
                label_name = load_labels(os.path.join(
                        label_dir, label_filename))[label_index]
                # 获取该分类的置信度
                label_score = predictions[label_index]
                print('%s (score = %.5f)' % (label_name, label_score))
            print()

7.3 输出

部分结果如下:

C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/use_output_graph/test_images\18.jpg
roses (score = 0.99892)
tulips (score = 0.00065)
sunflowers (score = 0.00032)
dandelion (score = 0.00007)
daisy (score = 0.00003)

C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3/use_output_graph/test_images\19.jpg
roses (score = 0.99359)
tulips (score = 0.00633)
sunflowers (score = 0.00008)
daisy (score = 0.00000)
dandelion (score = 0.00000)

当然,并不是所有的图都预测的非常准确,有几张图,就差点翻车:


sunflowers

roses

不过这个扭曲的雏菊倒是毫无压力:


daisy

7.4 label_image.py

发现Google同样提供了测试脚本:
tensorflow / tensorflow /examples / image_retraining / label_image.py
不过这个脚本只能预测一张照片,感觉不如自己写的,使用时修改命令行参数的默认值即可,源码(未修改)如下,仅用做Mark:

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple image classification with Inception.

Run image classification with your model.

This script is usually used with retrain.py found in this same
directory.

This program creates a graph from a saved GraphDef protocol buffer,
and runs inference on an input JPEG image. You are required
to pass in the graph file and the txt file.

It outputs human readable strings of the top 5 predictions along with
their probabilities.

Change the --image_file argument to any jpg image to compute a
classification of that image.

Example usage:
python label_image.py --graph=retrained_graph.pb
  --labels=retrained_labels.txt
  --image=flower_photos/daisy/54377391_15648e8d18.jpg

NOTE: To learn to use this file and retrain.py, please see:

https://codelabs.developers.google.com/codelabs/tensorflow-for-poets
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import tensorflow as tf

parser = argparse.ArgumentParser()
parser.add_argument(
    '--image', required=True, type=str, help='Absolute path to image file.')
parser.add_argument(
    '--num_top_predictions',
    type=int,
    default=5,
    help='Display this many predictions.')
parser.add_argument(
    '--graph',
    required=True,
    type=str,
    help='Absolute path to graph file (.pb)')
parser.add_argument(
    '--labels',
    required=True,
    type=str,
    help='Absolute path to labels file (.txt)')
parser.add_argument(
    '--output_layer',
    type=str,
    default='final_result:0',
    help='Name of the result operation')
parser.add_argument(
    '--input_layer',
    type=str,
    default='DecodeJpeg/contents:0',
    help='Name of the input operation')


def load_image(filename):
  """Read in the image_data to be classified."""
  return tf.gfile.FastGFile(filename, 'rb').read()


def load_labels(filename):
  """Read in labels, one label per line."""
  return [line.rstrip() for line in tf.gfile.GFile(filename)]


def load_graph(filename):
  """Unpersists graph from file as default graph."""
  with tf.gfile.FastGFile(filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')


def run_graph(image_data, labels, input_layer_name, output_layer_name,
              num_top_predictions):
  with tf.Session() as sess:
    # Feed the image_data as input to the graph.
    #   predictions will contain a two-dimensional array, where one
    #   dimension represents the input image count, and the other has
    #   predictions per class
    softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
    predictions, = sess.run(softmax_tensor, {input_layer_name: image_data})

    # Sort to show labels in order of confidence
    top_k = predictions.argsort()[-num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = labels[node_id]
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))

    return 0


def main(argv):
  """Runs inference on an image."""
  if argv[1:]:
    raise ValueError('Unused Command Line Args: %s' % argv[1:])

  if not tf.gfile.Exists(FLAGS.image):
    tf.logging.fatal('image file does not exist %s', FLAGS.image)

  if not tf.gfile.Exists(FLAGS.labels):
    tf.logging.fatal('labels file does not exist %s', FLAGS.labels)

  if not tf.gfile.Exists(FLAGS.graph):
    tf.logging.fatal('graph file does not exist %s', FLAGS.graph)

  # load image
  image_data = load_image(FLAGS.image)

  # load labels
  labels = load_labels(FLAGS.labels)

  # load graph, which is stored in the default session
  load_graph(FLAGS.graph)

  run_graph(image_data, labels, FLAGS.input_layer, FLAGS.output_layer,
            FLAGS.num_top_predictions)


if __name__ == '__main__':
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=sys.argv[:1]+unparsed)

8. TensorBoard 可视化

打开命令行,输入:

tensorboard --logdir /tmp/retrain_logs
# 将/tmp/retrain_logs改为实际存储目录,如果之前有修改命令行参数默认值


然而,在浏览器中输入http://Dexter:6006并不能打开:

解决方法:输入localhost:6006即可。
tensorboard


参考资料:

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 210,914评论 6 490
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 89,935评论 2 383
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 156,531评论 0 345
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,309评论 1 282
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,381评论 5 384
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,730评论 1 289
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,882评论 3 404
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,643评论 0 266
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,095评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,448评论 2 325
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,566评论 1 339
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,253评论 4 328
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,829评论 3 312
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,715评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,945评论 1 264
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,248评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,440评论 2 348

推荐阅读更多精彩内容