如何使用tf.saved_model API恢复tf.data.Dataset()中悬空的tf.py_func?

时间:2019-04-16 15:08:10

标签: python-3.x tensorflow tensorflow-serving tensorflow-datasets

在徒劳地使用save_model API进行了恢复tf.py_func()的研究后,我找不到除tensorflow中记录的信息以外的其他信息:

  

该操作必须在与调用tf.py_func()的Python程序相同的地址空间中运行。如果使用分布式TensorFlow,则必须在与调用tf.train.Server的程序相同的过程中运行tf.py_func(),并且必须将创建的操作固定到该服务器上的设备(例如,与{{1 }}:)

两个保存/加载片段有助于说明这种情况。

保存零件:

tf.device()

加载零件:

def wrapper(x, y):
    with tf.name_scope('wrapper'):
        return tf.py_func(Copy, [x, y], [tf.float32, tf.float32])

def Copy(x, y):
    return x, y

x_ph = tf.placeholder(tf.float32, [None], 'x_ph')
y_ph = tf.placeholder(tf.float32, [None], 'y_ph')

with tf.name_scope('input'):
    ds = tf.data.Dataset.from_tensor_slices((x_ph, y_ph))
    ds = ds.map(wrapper)
    ds = ds.batch(1)
    it = tf.data.Iterator.from_structure(ds.output_types, ds.output_shapes)
    it_init_op = it.make_initializer(ds, name='it_init_op')

x_it, y_it = it.get_next()

# Simple operation
with tf.name_scope('add'):
    res = tf.add(x_it, y_it)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), it_init_op], feed_dict={y_ph: [10] * 10, x_ph: [i for i in range(10)]})
    sess.run([res])
    tf.saved_model.simple_save(sess, './dummy/test', {'x_ph': x_ph, 'y_ph': y_ph}, {'res': res})

错误:

  

ValueError:未找到回调pyfunc_0

众所周知,graph = tf.Graph() graph.as_default() with tf.Session(graph=graph) as sess: tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], './dummy/test') res = graph.get_tensor_by_name('add/Add:0') it_init_op = graph.get_operation_by_name('input/it_init_op') x_ph = graph.get_tensor_by_name('x_ph:0') y_ph = graph.get_tensor_by_name('y_ph:0') sess.run([it_init_op], feed_dict={x_ph: [5] * 5, y_ph: [i for i in range(5)]}) for _ in range(5): sess.run([res]) 包装的函数未随模型一起保存。是否有人有解决方案,可以使用tf doc给出的小提示(应用tf.py_func()

来还原此问题)

1 个答案:

答案 0 :(得分:0)

只要没有答案,我会建议我的,其轮廓为pb而不是解决它。苦苦挣扎了很长时间,我最终通过修剪忽略了它。然后,以一种更简单的方式将占位符嫁接到新的输入/输出上。此外,此 py_func在TF2.0中已弃用