Java调用Tensorflow-pb模型进行预测
首先来看下官网给的例子,把环境先整明白, tf-java, 基本步骤是先创建maven工程,然后配置pom,再reimport,然后运行下demo
-
pom配置
仿照官网给的信息,在pom中加入tf信息,这里使用的tf版本为1.12.0
-
HelloTensorFlow
输入官方样例代码后,reimport下工程,运行代码,成功了可以看到输出TF的版本信息
上面给出的代码只是把环境处理明白,并未涉及到加载pb文件进行预测,这里主要参考tensorflow-serving-tutorial,整个流程写的很明白,所以照着来
-
train-model
这里首先进行模型的训练,模型选择mnist,直接看官方使用docker部署时给出的样例 tensorflow_serving ,这里把github下的 mnist_saved_model.py 和 mnist_input_data.py 两个文件拷贝下来,改下模型保存路径即可,运行得到保存的pb文件
使用以下命令查看下模型的输入输出,用于后面java代码里指定节点的名字(如下图:看网上说的使用自定义的名字'images'即可,但是我在实验的过程中并不行,使用的是name节点 x:0 才成功的)
saved_model_cli show --dir ./logs --all
# --dir 指定保存pb文件的路径,这里即为 ./logs
-
Inference
使用mnist_input_data.py里的代码读取测试集数据,这里只取第一行数据,label为7,运行结果如下:
MnistPredict.java的完整代码如下
package SimpleTest;
import java.util.Arrays;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Graph;
import org.tensorflow.Tensor;
public class MnistPredict {
SavedModelBundle tensorflowModelBundle;
Session tensorflowSession;
void load(String modelPath){
this.tensorflowModelBundle = SavedModelBundle.load(modelPath, "serve");
this.tensorflowSession = tensorflowModelBundle.session();
}
public Tensor predict(Tensor tensorInput){
// feed()传参类似python端的feed_dict, fetch()指定输出节点的名称
// 输出可能会有多个节点,其类型list,这里只输出一个节点,所以 get(0)
Tensor output = this.tensorflowSession.runner().feed("x:0", tensorInput).fetch("y:0").run().get(0);
return output;
}
public static void main(String[] args){
// 创建输入tensor, 注意type、shape应和训练时一致
float[][] testvec = {{0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,84,185,159,151,60,36,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,222,254,254,254,254,241,198,198,198,198,198,198,198,198,170,52,0,0,0,0,0,0,
0,0,0,0,0,0,67,114,72,114,163,227,254,225,254,254,254,250,229,254,254,140,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,17,66,14,67,67,67,59,21,236,254,106,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,83,253,209,18,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,22,233,255,83,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,129,254,238,44,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,59,249,254,62,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,133,254,187,5,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,9,205,248,58,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,126,254,182,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,75,251,240,57,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,19,221,254,166,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,3,203,254,219,35,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,38,254,254,77,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,31,224,254,115,1,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,133,254,254,52,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,61,242,254,254,52,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,121,254,254,219,40,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,121,254,207,18,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}};
Tensor input = Tensor.create(testvec);
// load 模型
MnistPredict myModel = new MnistPredict();
String modelPath = "models/";
myModel.load(modelPath);
// 模型推理,注意resultValues的type、shape
Tensor out = myModel.predict(input);
float[][] resultValues = (float[][]) out.copyTo(new float[1][10]);
input.close(); // 防止内存泄漏,释放tensor内存
out.close();
System.out.println(Arrays.toString(resultValues[0]));
}
}