Tensorflow Dataset API是否完全摆脱了feed_dict参数?

时间:2019-08-02 08:37:26

标签: python tensorflow tensorflow-datasets

我开始使用Dataset API替换feed_dict系统。

但是,在创建数据集管道之后,如何在不使用feed_dict的情况下将数据集的数据提供给模型?

首先,我创建了一个单发迭代器。但是在这种情况下,您需要使用feed_dict将来自迭代器的数据提供给模型。

第二,我尝试直接从tf.placeholder创建我的数据集,然后使用initializable_iterator。但是在这里,我不明白如何摆脱feed_dict。另外,我不理解这种基于Plaeholders的数据集的目的是什么。

我的基本模型:

x = tf.placeholder(tf.float32, [None, 2])
dense = tf.layers.dense(x, 1)
init_dense = tf.global_variables_initializer()

我的数据:

np_data = np.random.sample((100,2))

方法1:

dataset = tf.data.Dataset.from_tensor_slices(np_data)
iterator = dataset.make_one_shot_iterator()
next_value = iterator.get_next()

with tf.Session() as sess:
  sess.run(init_glob)

  for i in range(100):
    value = sess.run(next_value)
    # Cannot get rid of feed_dict
    result = sess.run(dense, feed_dict({x: value})

方法2:

dataset = tf.data.Dataset.from_tensor_slices(x)
iterator = dataset.make_initializable_iterator()
next_value = iterator.get_next()

with tf.Session() as sess:
  sess.run(init_glob)
  sess.run(iterator.initializer, feed_dict={x: np_data})

  for i in range(100):
    value = sess.run(next_value)
    # Cannot get rid of feed_dict
    result = sess.run(dense, feed_dict({x: value})

https://www.tensorflow.org/guide/performance/overview#input_pipeline

那么,我该如何“避免对所有琐碎的示例都使用feed_dict”? 我想我不理解数据集API的概念

1 个答案:

答案 0 :(得分:2)

是的,如果使用数据集api,则无需使用feed_dict

相反,我们每次只能将密集层应用于next_value

类似这样的东西:

def model(x):
  dense = tf.layers.dense(x, 1)
  return dense

result_for_this_iteration = model(next_value)

因此,您的完整玩具示例可能看起来像这样:

def model(x):
  dense = tf.layers.dense(x, 10)
  return dense

dataset = tf.data.Dataset.from_tensor_slices(np.random.sample((100, 2, 2)))
iterator = dataset.make_one_shot_iterator()
next_value = iterator.get_next()

result_for_this_iteration = model(next_value)


with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  while(True):
    try:
      result = sess.run(result_for_this_iteration)
      print (result)
    except OutOfRangeError:
      print ("no more data")

当然,还有许多其他配置选项。我们可以repeat(),以便我们不到达数据的末尾,而是循环遍历它。我们可以batch(n)分成大小为n的批次。我们可以map(pre_process)pre_process函数应用于每个元素,等等。