TensorFlow-4: tf.contrib.learn 快速入门

学习资料:
https://www.tensorflow.org/get_started/tflearn

相应的中文翻译:
http://studyai.site/2017/03/05/%E3%80%90Tensorflow%20r1.0%20%E6%96%87%E6%A1%A3%E7%BF%BB%E8%AF%91%E3%80%91%E3%80%90tf.contrib.learn%E5%BF%AB%E9%80%9F%E5%85%A5%E9%97%A8%E3%80%91/


今天学习用 tf.contrib.learn 来建立 DNN 对 Iris 数据集进行分类.

问题:
我们有 Iris 数据集,它包含150个样本数据,分别来自三个品种,每个品种有50个样本,每个样本具有四个特征,以及它属于哪一类,分别由 0,1,2 代表三个品种。
我们将这150个样本分为两份,一份是训练集具有120个样本,另一份是测试集具有30个样本。
我们要做的就是建立一个神经网络分类模型对每个样本进行分类,识别它是哪个品种。

一共有 5 步:

  • 导入 CSV 格式的数据集
  • 建立神经网络分类模型
  • 用训练数据集训练模型
  • 评价模型的准确率
  • 对新样本数据进行分类

代码:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import urllib

import numpy as np
import tensorflow as tf

# Data sets
IRIS_TRAINING = "iris_training.csv"
IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"

IRIS_TEST = "iris_test.csv"
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

def main():
  # If the training and test sets aren't stored locally, download them.
  if not os.path.exists(IRIS_TRAINING):
    raw = urllib.urlopen(IRIS_TRAINING_URL).read()
    with open(IRIS_TRAINING, "w") as f:
      f.write(raw)

  if not os.path.exists(IRIS_TEST):
    raw = urllib.urlopen(IRIS_TEST_URL).read()
    with open(IRIS_TEST, "w") as f:
      f.write(raw)

  # Load datasets.
  training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
      filename=IRIS_TRAINING,
      target_dtype=np.int,
      features_dtype=np.float32)
  test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
      filename=IRIS_TEST,
      target_dtype=np.int,
      features_dtype=np.float32)

  # Specify that all features have real-value data
  feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

  # Build 3 layer DNN with 10, 20, 10 units respectively.
  classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                              hidden_units=[10, 20, 10],
                                              n_classes=3,
                                              model_dir="/tmp/iris_model")
  # Define the training inputs
  def get_train_inputs():
    x = tf.constant(training_set.data)
    y = tf.constant(training_set.target)

    return x, y

  # Fit model.
  classifier.fit(input_fn=get_train_inputs, steps=2000)

  # Define the test inputs
  def get_test_inputs():
    x = tf.constant(test_set.data)
    y = tf.constant(test_set.target)

    return x, y

  # Evaluate accuracy.
  accuracy_score = classifier.evaluate(input_fn=get_test_inputs,
                                       steps=1)["accuracy"]

  print("\nTest Accuracy: {0:f}\n".format(accuracy_score))

  # Classify two new flower samples.
  def new_samples():
    return np.array(
      [[6.4, 3.2, 4.5, 1.5],
       [5.8, 3.1, 5.0, 1.7]], dtype=np.float32)

  predictions = list(classifier.predict(input_fn=new_samples))

  print(
      "New Samples, Class Predictions:    {}\n"
      .format(predictions))

if __name__ == "__main__":
    main()

从代码可以看出很简短的几行就可以完成之前学过的很长的代码所做的事情,用起来和用 sklearn 相似。

关于 tf.contrib.learn 可以查看:
https://www.tensorflow.org/api_guides/python/contrib.learn

可以看到里面也有 kmeans,logistic,linear 等模型:


在上面的代码中:

  • tf.contrib.learn.datasets.base.load_csv_with_header 可以导入 CSV 数据集。
  • 分类器模型只需要一行代码,就可以设置这个模型具有多少隐藏层,每个隐藏层有多少神经元,以及最后分为几类。
  • 模型的训练也是只需要一行代码,输入指定的数据,包括特征和标签,再指定迭代的次数,就可以进行训练。
  • 获得准确率也同样很简单,只需要输入测试集,调用 evaluate。
  • 预测新的数据集,只需要把新的样本数据传递给 predict。

关于代码里几个新的方法:

1. load_csv_with_header():

用于导入 CSV,需要三个必需的参数:

  • filename,CSV文件的路径
  • target_dtype,数据集的目标值的numpy数据类型。
  • features_dtype,数据集的特征值的numpy数据类型。

在这里,target 是花的品种,它是一个从 0-2 的整数,所以对应的numpy数据类型是np.int

2. tf.contrib.layers.real_valued_column:

所有的特征数据都是连续的,因此用 tf.contrib.layers.real_valued_column,数据集中有四个特征(萼片宽度,萼片高度,花瓣宽度和花瓣高度),因此 dimension=4 。

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

3. DNNClassifier:

  • feature_columns=feature_columns, 上面定义的一组特征
  • hidden_units=[10, 20, 10],三个隐藏层分别包含10,20,10个神经元。
  • n_classes=3,三个目标类,代表三个 Iris 品种。
  • model_dir=/tmp/iris_model,TensorFlow在模型训练期间将保存 checkpoint data。

在后面会学到关于 TensorFlow 的 logging and monitoring 的章节,可以 track 一下训练中的模型: “Logging and Monitoring Basics with tf.contrib.learn”。


推荐阅读 历史技术博文链接汇总
//www.greatytc.com/p/28f02bb59fe5
也许可以找到你想要的

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 197,966评论 5 462
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 83,170评论 2 375
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 144,909评论 0 327
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 52,959评论 1 268
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 61,851评论 5 358
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 46,583评论 1 275
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 36,956评论 3 388
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 35,590评论 0 254
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 39,878评论 1 293
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 34,892评论 2 314
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 36,719评论 1 328
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 32,501评论 3 316
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 37,957评论 3 300
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,124评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,440评论 1 255
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,003评论 2 343
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 41,211评论 2 339

推荐阅读更多精彩内容