我创建了一个Tensorflow模型,该模型使用Dataset API来将数据馈送到网络中。
在训练阶段之后,我想恢复此模型并不时对它进行推断。
当前,我每次都在重新初始化数据集迭代器,但是我想知道是否还有另一种方法。 此外,在训练时,我的数据集包含x和y数据,而在预测时,我只有x。作为一个临时解决方案,我提供了一个伪造的y,但是,这似乎并不是最好的解决方案。
这是我在做什么的伪代码:
#### NETWORK
input_x = tf.placeholder(tf.int32, [None, None], name="input_x")
input_y = tf.placeholder(tf.int32, [None, 2], name="input_y")
dataset = tf.data.Dataset.from_tensor_slices((input_x, input_y))
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
x_data, y_data = iterator.get_next()
output = tf.variable(x_data, name='output')
.....
### INFERENCE
while (true):
x = new_input
x_operation = session.graph.get_operation_by_name("input_x").outputs[0]
y_operation = session.graph.get_operation_by_name("input_y").outputs[0]
dataset_operation = session.graph.get_operation_by_name("dataset_init")
output_operation = session.graph.get_operation_by_name("output").outputs[0]
fake_y = np.array([[0, 0]])
dic = {input_x: x, input_y: y}
session.run(dataset_operation, feed_dict=dic)
prediction = session.run(output_operation)
谢谢您的帮助