Python+Android进行TensorFlow开发

tensorflow

Tensorflow是Google开源的一套机器学习框架,支持GPU、CPU、Android等多种计算平台。本文将介绍在Tensorflow在Android上的使用。

Android使用Tensorflow框架需要引入两个文件libtensorflow_inference.so、libandroid_tensorflow_inference_java.jar。这两个文件可以使用官方预编译的文件。如果预编译的so不满足要求(比如不支持训练模型中的某些操作符运算),也可以自己通过bazel编译生成这两个文件。

将libandroid_tensorflow_inference_java.jar放在app下的libs目录下,so文件命名为libtensorflow_jni.so放在src/main/jniLibs目录下对应的ABI文件夹下。目录结构如下:

android目录结构

同时在app的build.gradle中的dependencies模块下添加如下配置:

dependencies {
    ...
    compile files('libs/libandroid_tensorflow_inference_java.jar')
    ...
}

使用tensorflow框架进行机器学习分为四个步骤:

  1. 构造神经网络
  2. 训练神经网络模型
  3. 将训练好的模型输出为pb文件
  4. 在Android上加载pb模型进行计算

前三步是模型的构造,我们通过python实现,下面给出了一个二分类的简单模型的构造过程,首先是训练过程:

# -*-coding:utf-8 -*-
from __future__ import print_function
import os
import tensorflow as tf
from numpy.random import RandomState

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

"""
训练模型
"""
def train():
    # 定义训练数据集batch大小为8
    batch_size = 8

    # 定义神经网络参数,参数体现出神经网络结构,一个输入层,一个输出层,一个隐藏层
    w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1), name="w1_val")
    w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1), name="w2_val")

    # 定义输入输出格式
    x = tf.placeholder(tf.float32, shape=(None, 2), name='x_input')
    y_ = tf.placeholder(tf.float32, shape=(None, 1))

    # 定义神经网络前向传播过程
    a = tf.matmul(x, w1)
    y = tf.matmul(a, w2, name="cal_node")

    # 定义交叉熵和反向传播算法
    cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
    train_step = tf.train.AdadeltaOptimizer(0.001).minimize(cross_entropy)

    # 生成随机训练集
    rdm = RandomState(1)
    dataset_size = 128

    # 定义映射关系
    X = rdm.rand(dataset_size, 2)
    Y = [[int(x1 + x2 < 1)] for (x1, x2) in X]

    with tf.Session() as sess:
        # 初始化所有参数
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # print sess.run(w1)
        # print sess.run(w2)

        STEPS = 500
        for i in range(STEPS):
            start = (i * batch_size) % dataset_size
            end = min(start + batch_size, dataset_size)

            # 训练神经网络,更新神经网络参数
            sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})

            if i % 100 == 0:
                total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
                print("After %d training step(s), cross entropy on all data is %g" % (i, total_cross_entropy))

            print(sess.run(w1))
            print(sess.run(w2))

        # 保存check point
        saver = tf.train.Saver(tf.trainable_variables())
        saver.save(sess, './model/checpt')

上面的代码首先定义神经网络,初始化训练数据,进行500次训练过程,并将训练结果checkpoints保存到model文件夹下,checkpoints包含了训练模型得到的参数信息,共生成四个相关的文件,如下图:

checkpoint相关文件

由于checkpoint文件众多,为了方便使用,我们通过下面的代码将它们生成一个pb文件,在android上只需要这个pb文件即可使用这个训练好的模型:

"""
存储pb模型
"""
def dump_graph_to_pb(pb_path):
    with tf.Session() as sess:
        check_point = tf.train.get_checkpoint_state("./model/")
        if check_point:
            saver = tf.train.import_meta_graph(check_point.model_checkpoint_path + '.meta')
            saver.restore(sess, check_point.model_checkpoint_path)
        else:
            raise ValueError("Model load failed from {}".format(check_point.model_checkpoint_path))

        graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), "cal_node".split(","))

        with tf.gfile.GFile(pb_path, "wb") as f:
            f.write(graph_def.SerializeToString())

拿到生成的pb模型,我们可以在android上使用了。将pb文件在这main/assets下:

image.png

接下来就可以载入pb,进行计算了:

public class MainActivity extends AppCompatActivity {
    private Graph graph_;
    private Session session_;
    private AssetManager assetManager;

    private static ExecutorService executorService;
    private static Handler handler;
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        executorService = Executors.newFixedThreadPool(5);

        // 初始化tensorflow
        initTensorFlow("outmodel.pb");

        // 使用tensorflow进行计算
        runTensorFlow();
    }
    ...
}

通过如下方式载入pb模型,初始化tensorflow:

private boolean initTensorFlow(String modelFile) {
        assetManager = getAssets();
        // 新建Graph
        graph_ = new Graph();

        InputStream is = null;
        try {
            // 读取Assets pb文件
            is = assetManager.open(modelFile);
        } catch (IOException e) {
            e.printStackTrace();
            return false;
        }

        try {
            // 加载pb到Graph
            TensorUtil.loadGraph(is, graph_);
            is.close();
        } catch (IOException e) {
            e.printStackTrace();
            return false;
        }
        // 初始化session
        session_ = new Session(graph_);
        if (session_ == null) {
            return false;
        }

        return true;
    }

然后就可以使用tensorflow API进行运算了:

private void runTensorFlow() {
        executorService.execute(generatePredictRunnable(handler));
    }

    private Runnable generatePredictRunnable(Handler handler) {
        return new Runnable() {
            @Override
            public void run() {
                float[][] input = new float[1][2];

                input[0][0] = 1;
                input[0][1] = 2;

                // 定义输入tensor
                Tensor inputTensor = Tensor.create(input);

                // 指定输入,输出节点,运行并得到结果
                Tensor resultTensor = session_.runner()
                        .feed("x_input", inputTensor)
                        .fetch("cal_node")
                        .run()
                        .get(0);

                float[][] dst = new float[1][1];
                resultTensor.copyTo(dst);

                // 处理结果
                ArrayList<Float> resultList = new ArrayList<>();
                for (float val : dst[0]) {
                    if (val != 0) {
                        resultList.add(val);
                    } else {
                        break;
                    }
                }
            }
        };
    }

上面就是通过python训练机器学习模型,并在android平台进行调用的完整流程。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,816评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,729评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,300评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,780评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,890评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,084评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,151评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,912评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,355评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,666评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,809评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,504评论 4 334
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,150评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,882评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,121评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,628评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,724评论 2 351

推荐阅读更多精彩内容

  • Android 自定义View的各种姿势1 Activity的显示之ViewRootImpl详解 Activity...
    passiontim阅读 171,870评论 25 707
  • 作者: 陈迪豪,就职小米科技,深度学习工程师,TensorFlow代码提交者。 TensorFlow深度学习框架 ...
    Kimichen7764阅读 2,567评论 1 47
  • 像风走了八千里 不问归期与你相遇一长大,你好! 清风拂渭水,明月满长安 不知不觉,就在这待了一个月 然而并不能 简...
    长安大学阅读 789评论 1 14
  • 评《认真的老去》 “你在池塘里活得很好,泥鳅很丑但会说喜庆话,癞蛤蟆马虎但很有趣,田螺是个温柔的自闭症,小鲫鱼是你...
    你好啊艾伦郭阅读 693评论 0 2
  • 橙子小姐从来没想到自己有一天会落到如斯境地,她一直以为在分手后还留恋前任是一件非常丢脸并且愚蠢的事情。可是刚刚,就...
    橙子那么酸阅读 290评论 5 2