在线服务代码简介
本代码基于TensorFlow 1.1 的mnist with summaries例子实现在线推理服务https://github.com/tensorflow/tensorflow/blob/v1.1.0/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py。
完整的推理服务代码位于https://github.com/ucloud/uai-sdk/tree/master/examples/tensorflow/inference/mnist_1.1,推理服务的代码为mnist_inference.py,我们同时提供了conf.json和模型checkpoint_dir
mnist_inference.py
minst_inference.py 实现了load_model和execute两个函数。
创建 MnistModel 类
minst_inference.py首先需要实现一个在线服务的类,该类继承了TFAiUcloudModel(TensorFlow 在线服务基类)
"""A very simple MNIST inferencer. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function from PIL import Image import numpy as np import tensorflow as tf from uai.arch.tf_model import TFAiUcloudModel class MnistModel(TFAiUcloudModel): """Mnist example_tf model """ def __init__(self, conf): super(MnistModel, self).__init__(conf)
实现load\_model
load_model实现分为三个部分:
- 创建graph
- 使用tf.train.Saver() 加载模型,模型目录地址可以从self.model_dir获取,该变量由TFAiUcloudModel实现,并在初始化时从conf.json中获取
- 将执行推理所需的sess、x、y_ 三个变量保存到MnistModel.output全局变量中
def load_model(self): sess = tf.Session() """ 1 Define MNIST net y = x * W + b y_ = softmax(y) """ x = tf.placeholder(tf.float32, [None, 784]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) y = tf.matmul(x, W) + b y_ = tf.nn.softmax(y) """ 2 Load model from self.model_dir The default DIR name is checkpoint_dir/, it should include following files: checkpoint: tf checkpoint config file model.mod: model file model.mod.meta: model meta data file """ saver = tf.train.Saver() params_file = tf.train.latest_checkpoint(self.model_dir) saver.restore(sess, params_file) """ 3 Register ops into self.output dict. So func execute() can get these ops """ self.output['sess'] = sess self.output['x'] = x self.output['y_'] = y_
实现execute
实现execute实现分为四个部分:
- 从MnistModel.output全局变量中获取sess、x、y_ 三个变量
- 从data获取batching的请求数据,并转化为numpy list(我们将所有的请求batch成了一个矩阵imgs)
- 请求推理操作:predict_values = sess.run(y_, feed_dict={x: imgs})
- 将请求结果转化成string,并合并成results(results也是一个list,和data list是一一对应的关系)
def execute(self, data, batch_size): """ 1 """ sess = self.output['sess'] x = self.output['x'] y_ = self.output['y_'] """ 2 """ imgs = [] for i in range(batch_size): im = Image.open(data[i]).resize((28, 28)).convert('L') im = np.array(im) im = im.reshape(784) im = im.astype(np.float32) im = np.multiply(im, 1.0 / 255.0) imgs.append(im) """ 3 """ imgs = np.array(imgs) predict_values = sess.run(y_, feed_dict={x: imgs}) print(predict_values) """ 4 """ ret = [] for val in predict_values: ret_val = np.array_str(np.argmax(val)) + '\n' ret.append(ret_val) return ret