TensorFlow: real-time prediction with the high-level Estimator.predict API, without reloading the computation graph

Estimator.predict is one of TensorFlow's high-level APIs, but in practice you often run into the following: predicting over one large file in a single call is reasonably fast, yet wrapping the model as an interface for real-time prediction is slow, because every call to predict reloads the computation graph from scratch.
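To see why that is slow, here is a minimal sketch of the naive per-request pattern, assuming TF 1.x; estimator stands for any trained tf.estimator.Estimator, and the feature key 'x' is an invented placeholder:

import numpy as np
import tensorflow as tf

def slow_serve(estimator, feature_batch):
    """Naive per-request prediction: every call builds a fresh input
    pipeline, reconstructs the graph and restores the checkpoint."""
    input_fn = tf.estimator.inputs.numpy_input_fn(
        x={'x': np.asarray(feature_batch, dtype=np.float32)},
        shuffle=False)
    # predict() returns a generator; draining it tears everything down,
    # so the full setup cost is paid again on the next request.
    return list(estimator.predict(input_fn=input_fn))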

So is there a fix? Yes. At the Estimator level you can make tensorflow Estimator.predict serve real-time requests, loading the computation graph only once and keeping it resident in memory afterwards. (A small rant here: this capability feels important to me, since a production interface simply has to support real-time prediction, yet Google does not seem to highlight this approach anywhere...)

While rewriting a BERT model recently, I picked this trick up from bert-as-service (strongly recommended: https://github.com/hanxiao/bert-as-service, many thanks to the author for generously open-sourcing it). This thread also discusses the problem: https://groups.google.com/a/tensorflow.org/forum/#!topic/discuss/rOP4VKcfphg

[Screenshot 1: the question as posted in the discussion thread]
which eventually received an answer:
[Screenshot 2: the answer from the discussion thread]

Here is a short explanation of how to solve it:
Use tf.data.Dataset.from_generator to read the input data, instead of tf.data.Dataset.from_tensor_slices or other methods. Then maintain a generator that keeps yielding a member variable, leaving the generator open the whole time. Whenever a real-time prediction is needed, assign the data to be predicted to that member variable, and the prediction result comes straight out, with no graph reload.

Also, if you do not read the data with tf.data.Dataset.from_generator, the alternatives are either the feed_dict approach (not recommended, since then you cannot use the Estimator) or building the generator yourself; see this code: https://github.com/marcsto/rl/blob/master/src/fast_predict2.py
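For contrast, here is a minimal sketch of the feed_dict route, with no Estimator at all; the model, tensor shapes and checkpoint path are invented for illustration. It also keeps the graph resident, but you have to manage the session and checkpoint by hand:

import tensorflow as tf

# Build the graph once and keep the session open across requests.
x = tf.placeholder(tf.float32, shape=[None, 1], name='x')
logits = tf.layers.dense(x, 1)

sess = tf.Session()
saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint('/tmp/model_dir'))  # illustrative path

def predict(feature_batch):
    # Feed each request into the long-lived session.
    return sess.run(logits, feed_dict={x: feature_batch})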


Let me just paste the code and walk through it:
Usage: wrap your Estimator in the FastPredict class below, or rework your own code along the same lines.
For example: classifier = FastPredict(learn.Estimator(model_fn=model_params.model_fn, model_dir=model_params.model_dir), my_input_fn)
To predict, just call classifier.predict(feature_batch). For as long as classifier lives, the computation graph stays resident in memory, and no call to predict reloads it.

"""
    Speeds up estimator.predict by preventing it from reloading the graph on each call to predict.
    加速estimator.predict,防止每次predict时重新加载计算图
    It does this by creating a python generator to keep the predict call open.
    原理:创建一个python生成器,保持predict进程处于一直开启状态
    This version supports tf 1.4 and above and can be used by pre-made Estimators like tf.estimator.DNNClassifier. 
    Author: Marc Stogaitis
 """
import tensorflow as tf


class FastPredict:

    def __init__(self, estimator, input_fn):
        self.estimator = estimator
        self.first_run = True
        self.closed = False
        self.input_fn = input_fn

    def _create_generator(self):
        # Yield the current feature batch until close() flips self.closed,
        # which keeps the underlying predict call (and the graph) alive.
        while not self.closed:
            yield self.next_features

    def predict(self, feature_batch):
        """ Runs a prediction on a set of features. Calling multiple times
            does *not* regenerate the graph which makes predict much faster.
            feature_batch a list of list of features. IMPORTANT: If you're only classifying 1 thing,
            you still need to make it a batch of 1 by wrapping it in a list (i.e. predict([my_feature]), not predict(my_feature) 
        """
        self.next_features = feature_batch
        # Every batch fed in must have the same size as the first one.
        if self.first_run:
            self.batch_size = len(feature_batch)
            self.predictions = self.estimator.predict(
                input_fn=self.input_fn(self._create_generator))
            self.first_run = False
        elif self.batch_size != len(feature_batch):
            raise ValueError("All batches must be of the same size. First batch: " +
                             str(self.batch_size) + ", this batch: " + str(len(feature_batch)))
        results = []
        for _ in range(self.batch_size):
            results.append(next(self.predictions))
        return results

    def close(self):
        self.closed = True
        try:
            # Pull one more item so the generator sees self.closed and exits cleanly.
            next(self.predictions)
        except StopIteration:
            print("Exception in fast_predict. This is probably OK")


def example_input_fn(generator):
    """ An example input function to pass to predict. It must take a generator as input """

    def _inner_input_fn():
        # from_generator is a classmethod; instantiating tf.data.Dataset() first is a bug.
        dataset = tf.data.Dataset.from_generator(generator, output_types=tf.float32).batch(1)
        iterator = dataset.make_one_shot_iterator()
        features = iterator.get_next()
        return {'x': features}

    return _inner_input_fn
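
Putting it together, here is a minimal end-to-end sketch, assuming TF 1.x; the toy model_fn, the model_dir path and the feature values are invented for illustration, and any trained Estimator works the same way:

import numpy as np
import tensorflow as tf

def _toy_model_fn(features, labels, mode):
    # A single dense unit over the 'x' feature.
    logits = tf.layers.dense(tf.reshape(features['x'], [-1, 1]), 1)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={'logit': logits})
    loss = tf.losses.mean_squared_error(labels, logits)
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

estimator = tf.estimator.Estimator(model_fn=_toy_model_fn,
                                   model_dir='/tmp/fast_predict_demo')
# Train one step so there is a checkpoint to restore from.
estimator.train(tf.estimator.inputs.numpy_input_fn(
    x={'x': np.array([[1.0]], dtype=np.float32)},
    y=np.array([[2.0]], dtype=np.float32),
    shuffle=False, num_epochs=1))

classifier = FastPredict(estimator, example_input_fn)
print(classifier.predict([1.0]))  # first call: graph built and checkpoint loaded once
print(classifier.predict([2.0]))  # later calls reuse the live predict generator
classifier.close()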
