我正在尝试使用tensorflow数据集api创建预测脚本。以前我使用低级API和feed_dict执行此操作:
#import graph
saver = tf.train.import_meta_graph('...')
# Select variables to feed
x = graph.get_tensor_by_name("X:0")
predictions = graph.get_tensor_by_name("pred:0")
with tf.Session() as sess:
p = sess.run(predictions, feed_dict={x:x_feed})
现在我以下面的方式使用数据集API:
iterator =
tf.data.Iterator.from_structure(training_dataset.output_types,
training_dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
for _ in range(20):
# Initialize an iterator over the training dataset.
sess.run(training_init_op)
for _ in range(100):
sess.run(next_element
# Initialize an iterator over the validation dataset.
sess.run(validation_init_op)
for _ in range(50):
sess.run(next_element)
我正在保存.meta和.data文件。 如何使用这些来创建预测脚本?我无法从图表中提取操作并输入所需的值,并且没有定义占位符。一种方法是使用相同的脚本并使用测试数据,但必须有更好的方法吗?
由于