我开始使用Dataset API替换feed_dict系统。
但是,在创建数据集管道之后,如何在不使用feed_dict的情况下将数据集的数据提供给模型?
首先,我创建了一个单发迭代器。但是在这种情况下,您需要使用feed_dict将来自迭代器的数据提供给模型。
第二,我尝试直接从tf.placeholder创建我的数据集,然后使用initializable_iterator。但是在这里,我不明白如何摆脱feed_dict。另外,我不理解这种基于Plaeholders的数据集的目的是什么。
我的基本模型:
x = tf.placeholder(tf.float32, [None, 2])
dense = tf.layers.dense(x, 1)
init_dense = tf.global_variables_initializer()
我的数据:
np_data = np.random.sample((100,2))
方法1:
dataset = tf.data.Dataset.from_tensor_slices(np_data)
iterator = dataset.make_one_shot_iterator()
next_value = iterator.get_next()
with tf.Session() as sess:
sess.run(init_glob)
for i in range(100):
value = sess.run(next_value)
# Cannot get rid of feed_dict
result = sess.run(dense, feed_dict({x: value})
方法2:
dataset = tf.data.Dataset.from_tensor_slices(x)
iterator = dataset.make_initializable_iterator()
next_value = iterator.get_next()
with tf.Session() as sess:
sess.run(init_glob)
sess.run(iterator.initializer, feed_dict={x: np_data})
for i in range(100):
value = sess.run(next_value)
# Cannot get rid of feed_dict
result = sess.run(dense, feed_dict({x: value})
https://www.tensorflow.org/guide/performance/overview#input_pipeline
那么,我该如何“避免对所有琐碎的示例都使用feed_dict”? 我想我不理解数据集API的概念
答案 0 :(得分:2)
是的,如果使用数据集api,则无需使用feed_dict
。
相反,我们每次只能将密集层应用于next_value
。
类似这样的东西:
def model(x):
dense = tf.layers.dense(x, 1)
return dense
result_for_this_iteration = model(next_value)
因此,您的完整玩具示例可能看起来像这样:
def model(x):
dense = tf.layers.dense(x, 10)
return dense
dataset = tf.data.Dataset.from_tensor_slices(np.random.sample((100, 2, 2)))
iterator = dataset.make_one_shot_iterator()
next_value = iterator.get_next()
result_for_this_iteration = model(next_value)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
while(True):
try:
result = sess.run(result_for_this_iteration)
print (result)
except OutOfRangeError:
print ("no more data")
当然,还有许多其他配置选项。我们可以repeat()
,以便我们不到达数据的末尾,而是循环遍历它。我们可以batch(n)
分成大小为n
的批次。我们可以map(pre_process)
将pre_process
函数应用于每个元素,等等。