TensorFlow v1.10+: update a tf.estimator.BestExporter export and save the result?

Posted: 2019-06-19 12:41:34

Tags: python tensorflow tensorflow-serving tensorflow-estimator

I have a custom estimator that is exported with a tf.estimator.BestExporter:

exporter = tf.estimator.BestExporter(
    name="best_exporter",
    serving_input_receiver_fn=serving_input_receiver_fn,
    exports_to_keep=5
) # keeps at most the 5 most recent best exports (SavedModel directories)

So, under model_dir, I now have:

# (inside model_dir/)
...
export/
-   best_exporter/
    -   <timestamp>
        -   variables/
            -   variables.data-00000-of-00001
            -   variables.index
        -   saved_model.pb

I can load and use the exported estimator with:

from tensorflow.contrib import predictor  # TF 1.x only
predict_fn = predictor.from_saved_model(os.path.join(best_exporter_dir, timestamp))
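With exports_to_keep=5, several <timestamp> directories sit side by side, so it can be handy to pick one in code. A small sketch, assuming the sub-directory names are integer timestamps (a non-numeric name, such as an in-progress temporary export, is skipped) and that, since BestExporter only writes a new export when the evaluation result improves, the newest timestamp is the best model so far. latest_export_dir is a hypothetical helper, not a TensorFlow API:

```python
import os


def latest_export_dir(best_exporter_dir):
    """Hypothetical helper: path of the newest <timestamp> export directory."""
    candidates = [
        name
        for name in os.listdir(best_exporter_dir)
        # assumption: finished exports are directories named by integer timestamps
        if name.isdigit() and os.path.isdir(os.path.join(best_exporter_dir, name))
    ]
    if not candidates:
        raise FileNotFoundError(f"no timestamped exports in {best_exporter_dir}")
    # compare as integers so timestamps of different lengths still sort correctly
    return os.path.join(best_exporter_dir, max(candidates, key=int))
```

It could then be used as predictor.from_saved_model(latest_export_dir(best_exporter_dir)).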

I would like to be able to update values of this estimator (e.g. the weights of a layer, some_layer/kernel:0).
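To see which values an export contains and what they currently are, its variables can be read straight from the variables/ checkpoint files without building a graph. A sketch, assuming the standard SavedModel layout shown above (checkpoint prefix <export>/variables/variables) and TensorFlow's tf.train.list_variables / tf.train.load_variable. Note that checkpoint keys omit the :0 output suffix, so the key is some_layer/kernel rather than some_layer/kernel:0. read_exported_variable is a hypothetical helper name:

```python
import os

import tensorflow as tf


def read_exported_variable(export_dir, var_name):
    """Hypothetical helper: read one variable from a SavedModel's checkpoint.

    var_name is the checkpoint key, i.e. the graph name without ':0'.
    """
    # A SavedModel stores its variables under <export_dir>/variables/variables.*
    ckpt_prefix = os.path.join(export_dir, "variables", "variables")
    available = dict(tf.train.list_variables(ckpt_prefix))  # key -> shape
    if var_name not in available:
        raise KeyError(
            "%r not found; available keys: %s" % (var_name, sorted(available))
        )
    return tf.train.load_variable(ckpt_prefix, var_name)  # NumPy array
```

For example, read_exported_variable(os.path.join(best_exporter_dir, timestamp), 'some_layer/kernel') would return the exported kernel, which is a useful reference point when checking whether an update actually took effect.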

There is a related (but not identical) GitHub issue that addresses how to do this with model checkpoints (relevant part of issue); the poster confirmed that it works with TensorFlow v1.4.

I tried to stitch together the relevant parts of that code to at least be able to update some weights:


import os

import numpy as np
import tensorflow as tf


def load_estimator_graph(export_dir:str)->None:
    '''Solves import issues when using tf.estimator.(Best)Exporter for saving
    models rather than using the last checkpoint.

    Arguments:
        export_dir (str): the full path to exported tf.estimator model
    Returns:
        None
    '''
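    # Loads the SavedModel (graph + variable values) into a throwaway graph;
    # the restored values are discarded when this session closes.
    with tf.Session(graph=tf.Graph()) as sess:
        meta_graph = tf.saved_model.loader.load(sess, ['serve'], export_dir)
    # Re-creates only the graph structure (not the values) in the default graph.
    with tf.Session() as sess:
        loaded_graph = tf.train.import_meta_graph(meta_graph)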

def lazy_fetch_variable_values(variable_names:list)->dict:
    '''
    Notes:
        "lazy" refers to:
            1. the use of `tf.initialize_all_variables()` to ensure
                variables have values
            2. the use of `tf.trainable_variables()` to search the likely
                relevant values

    Arguments:
        variable_names (list): list of variable names (str) to retrieve from the
            default tensorflow graph

    Returns:
        variables (dict): key:value of the variables and the values as pythonic
            data types.
    '''
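    # NOTE: in this fresh session, the initializer gives every variable of the
    # default graph its *initial* value, so the values fetched below are
    # initializer values, not the trained weights from the export.
    init_op = tf.initialize_all_variables()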
    variables = {}
    with tf.Session() as sess:
        sess.run(init_op)

        tvars = tf.trainable_variables()
        tvars_vals = sess.run(tvars)

        for var, val in zip(tvars, tvars_vals):
            if var.name in variable_names:
                variables[var.name] = val
    return variables


def lazy_set_variable_values(variables_to_set:dict):
    '''
    Arguments:
        variables_to_set (dict): variable_name, variable_value pairs for which
            to be updated in the graph
    '''
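    init_op = tf.initialize_all_variables()  # resets every variable to its initial value
    with tf.Session() as sess:
        sess.run(init_op)
        tf_global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        for var_to_find, val_to_set in variables_to_set.items():
            # find the variable by its full graph name, e.g. 'some_layer/kernel:0'
            var = [v for v in tf_global_vars if v.name == var_to_find][0]
            sess.run(var)

            var = var.assign(val_to_set)
            sess.run(var)
    # NOTE: the assigned values exist only in the session above and are lost
    # when it closes; nothing is saved to disk here.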

And then something like:


load_estimator_graph(best_exported_model_dir)
layer_name = 'some_layer/kernel:0'
weights = lazy_fetch_variable_values([layer_name])[layer_name]
new_weights = np.zeros_like(weights)  # np.ndarray of zeros with the same shape
                                      # and dtype as the original weight tensor
lazy_set_variable_values({layer_name: new_weights})


with tf.Session() as sess:
    # NOTE: this resets every variable to its initial value, so the
    # checkpoint written below does not contain new_weights
    sess.run(tf.initialize_all_variables())
    saver = tf.train.Saver()
    # writes checkpoint files prefixed 'best_updated' next to the export directory
    saver.save(sess, os.path.join(best_exported_model_dir, '..', 'best_updated'))

Note that this reads in the tf.estimator.BestExporter model and tries to export it to a checkpoint.
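Zeroing a NumPy array needs an element-wise write: plain assignment such as new_weights = 0 rebinds the Python name to the integer 0 and discards the array. A quick check of the difference (the 2x3 weights array is a made-up stand-in for the fetched kernel):

```python
import numpy as np

# Made-up stand-in for a fetched kernel such as 'some_layer/kernel:0'.
weights = np.arange(6, dtype=np.float32).reshape(2, 3)

rebound = np.copy(weights)
rebound = 0  # rebinds the name to the int 0; the copied array is discarded

zeroed = np.copy(weights)
zeroed[...] = 0  # writes 0 into every element; shape and dtype are kept

also_zeroed = np.zeros_like(weights)  # same result without an explicit copy

print(type(rebound).__name__, zeroed.shape, zeroed.dtype)  # int (2, 3) float32
```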

So, if I try to restore from the checkpoint:

est = tf.estimator.Estimator(
    model_fn  = model_fn,
    model_dir = os.path.join(best_exported_model_dir, '..', 'best_updated'),
    config    = tf.estimator.RunConfig(**_config['RunConfig']), # same as runtime call
    params    = _config, # same as runtime call

)


eval_fn = lambda : input_fn(mode='eval')

est.evaluate(eval_fn)

I get:

ValueError Input 0 of layer some_layer is incompatible with the layer: : expected min_ndim=2, found ndim=1. Full shape received: [an_integer]

where, in the code above,

weights.shape[0] == new_weights.shape[0] == an_integer
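If the checkpoint route is kept, the updated values can only end up in the checkpoint if the assignment and Saver.save run in the same session that restored the export, with no initializer in between: tf.initialize_all_variables() resets every variable to its initial value, and values assigned in a session are lost once that session closes. A minimal sketch along those lines, assuming TF 1.13+ for the tf.compat.v1 aliases (on 1.10 to 1.12, use tf directly); export_to_checkpoint is a hypothetical helper, and it does not address the ValueError above, which is about the rank of the layer's input rather than its weights:

```python
import os

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF 1.13+


def export_to_checkpoint(export_dir, ckpt_dir, new_values, tags=("serve",)):
    """Hypothetical helper: SavedModel -> checkpoint, with some variables replaced.

    new_values maps graph variable names (e.g. 'some_layer/kernel:0') to arrays.
    """
    os.makedirs(ckpt_dir, exist_ok=True)  # Saver.save needs an existing directory
    with tf1.Graph().as_default() as graph, tf1.Session(graph=graph) as sess:
        # Restores graph *and* trained values; no initializer is run after this.
        tf1.saved_model.loader.load(sess, list(tags), export_dir)

        by_name = {v.name: v for v in tf1.global_variables()}
        for name, value in new_values.items():
            var = by_name[name]  # KeyError if the name is not in the export
            ph = tf1.placeholder(var.dtype.base_dtype, shape=var.shape)
            sess.run(var.assign(ph), feed_dict={ph: np.asarray(value)})

        # Save in the same session, so the assigned values are what gets written.
        saver = tf1.train.Saver(tf1.global_variables())
        return saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"))
```

Saver.save also writes a checkpoint state file into ckpt_dir, so that directory (rather than a file prefix) is what would be passed as model_dir; whether the Estimator can then restore from it depends on its eval graph using the same variable names as the export.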

Root of the problem

Ideally, I would like to save the updated model in the same format that tf.estimator.BestExporter and tf.estimator.Estimator.export_savedmodel produce.

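However, the export methods above require an estimator instance and the corresponding serving_input_receiver_fn, and predictor.from_saved_model(exported_dir) does not create an estimator! So there does not seem to be a direct way to do this.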

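Notes:
- predictor is imported via: from tensorflow.contrib import predictor
- I want to import an exported model, update some values (e.g. biases/weights), and then export it in the same form (without overwriting the original model).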
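One possible route that avoids instantiating the estimator is to work on the SavedModel directly: the SavedModel loader restores both the graph and the trained variable values into a session, the variables can be overwritten in that session, and a SavedModel builder can write the result, with the original signatures, to a new directory. The sketch below has not been tested against an actual Estimator export. update_saved_model is a hypothetical helper name, it assumes TF 1.13+ for the tf.compat.v1 aliases, and it does not explicitly carry over asset files or any init op of the export (for example lookup-table initializers), which may have to be passed to add_meta_graph_and_variables separately:

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF 1.13+ (on 1.10 to 1.12 use `tf` directly)


def update_saved_model(export_dir, new_export_dir, new_values, tags=("serve",)):
    """Hypothetical helper: copy a SavedModel with some variables overwritten.

    new_values maps graph variable names (e.g. 'some_layer/kernel:0') to arrays
    of the same shape. new_export_dir must not exist yet (or be empty), so the
    original export is never overwritten.
    """
    with tf1.Graph().as_default() as graph, tf1.Session(graph=graph) as sess:
        # Restores the graph *and* the trained variable values; running an
        # initializer afterwards would throw those values away.
        meta_graph = tf1.saved_model.loader.load(sess, list(tags), export_dir)

        by_name = {v.name: v for v in tf1.global_variables()}
        for name, value in new_values.items():
            var = by_name[name]  # KeyError if the name is not in the export
            # Feed the new value through a placeholder so the array is not
            # baked into the exported graph as a constant.
            ph = tf1.placeholder(var.dtype.base_dtype, shape=var.shape)
            sess.run(var.assign(ph), feed_dict={ph: np.asarray(value)})

        # Write graph + current variable values, reusing the original signatures.
        builder = tf1.saved_model.Builder(new_export_dir)
        builder.add_meta_graph_and_variables(
            sess, list(tags), signature_def_map=dict(meta_graph.signature_def)
        )
        builder.save()
    return new_export_dir
```

Usage would be along the lines of update_saved_model(os.path.join(best_exporter_dir, timestamp), os.path.join(best_exporter_dir, '..', 'best_updated'), {'some_layer/kernel:0': new_weights}), after which predictor.from_saved_model should be able to load the new directory.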
