Java调用Tensorflow-pb模型进行预测

Java调用Tensorflow-pb模型进行预测

​ 首先来看下官网给的例子,把环境先整明白, tf-java, 基本步骤是先创建maven工程,然后配置pom,再reimport,然后运行下demo

  • pom配置

    仿照官网给的信息,在pom中加入tf信息,这里使用的tf版本为1.12.0

    pom.png
  • HelloTensorFlow

    输入官方样例代码后,reimport下工程,运行代码,成功了可以看到输出TF的版本信息

    1.png

​ 上面给出的代码只是把环境处理明白,并未涉及到加载pb文件进行预测,这里主要参考tensorflow-serving-tutorial,整个流程写的很明白,所以照着来

​ 使用以下命令查看下模型的输入输出,用于后面java代码里指定节点的名字(如下图:看网上说的使用自定义的名字'images'即可,但是我在实验的过程中并不行,使用的是name节点 x:0 才成功的)

saved_model_cli show --dir ./logs --all
# --dir 指定保存pb文件的路径,这里即为 ./logs
input-output.png
  • Inference

    使用mnist_input_data.py里的代码读取测试集数据,这里只取第一行数据,label为7,运行结果如下:

    predict.png

MnistPredict.java的完整代码如下

package SimpleTest;

import java.util.Arrays;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Graph;
import org.tensorflow.Tensor;

public class MnistPredict {
    SavedModelBundle tensorflowModelBundle;
    Session tensorflowSession;

    void load(String modelPath){
        this.tensorflowModelBundle = SavedModelBundle.load(modelPath, "serve");
        this.tensorflowSession = tensorflowModelBundle.session();
    }

    public Tensor predict(Tensor tensorInput){
        // feed()传参类似python端的feed_dict,  fetch()指定输出节点的名称
        // 输出可能会有多个节点,其类型list,这里只输出一个节点,所以 get(0)
        Tensor output = this.tensorflowSession.runner().feed("x:0", tensorInput).fetch("y:0").run().get(0);
        return output;
    }

    public static void main(String[] args){
        // 创建输入tensor, 注意type、shape应和训练时一致
        float[][] testvec = {{0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,84,185,159,151,60,36,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,222,254,254,254,254,241,198,198,198,198,198,198,198,198,170,52,0,0,0,0,0,0,
                0,0,0,0,0,0,67,114,72,114,163,227,254,225,254,254,254,250,229,254,254,140,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,17,66,14,67,67,67,59,21,236,254,106,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,83,253,209,18,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,22,233,255,83,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,129,254,238,44,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,59,249,254,62,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,133,254,187,5,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,9,205,248,58,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,126,254,182,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,0,0,75,251,240,57,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,0,19,221,254,166,0,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,3,203,254,219,35,0,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,38,254,254,77,0,0,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,31,224,254,115,1,0,0,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,133,254,254,52,0,0,0,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,61,242,254,254,52,0,0,0,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,121,254,254,219,40,0,0,0,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,121,254,207,18,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
                0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}};
        Tensor input = Tensor.create(testvec);

        // load 模型
        MnistPredict myModel = new MnistPredict();
        String modelPath = "models/";
        myModel.load(modelPath);

        // 模型推理,注意resultValues的type、shape
        Tensor out = myModel.predict(input);
        float[][] resultValues = (float[][]) out.copyTo(new float[1][10]);

        input.close();      // 防止内存泄漏,释放tensor内存
        out.close();
        System.out.println(Arrays.toString(resultValues[0]));
    }
}
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。