如何使用Dateset馈送数据数组以使用Tensorflow进行推理?

时间:2019-03-24 15:14:36

标签: tensorflow

我是Tensorflow Dataset API的新手,我无法完全理解其设计的简单性,因此我需要一些帮助。

这是一个简单的例子

import tensorflow as tf

x = tf.placeholder(tf.int32, shape=[])
y = tf.square(x)

with tf.Session() as sess:
  print(sess.run(y, {x: 2}))

# result is 4, simple

如果我有一个整数数组arr_x=[2, 3, 5, 8, 10],如何使用Dateset API迭代该数组?

我正在尝试

p = tf.placeholder(tf.int32, shape=[None])
d = tf.data.Dataset.from_tensor_slices(p)
d = d.map(lambda x: x)
iter = d.make_initializable_iterator()
next_element = iter.get_next()

with tf.Session() as sess:
  sess.run(iter.initializer, feed_dict={p: [2, 3, 4]})
  while True:
    try:
      print sess.run(y, next_element)
    except tf.errors.OutOfRangeError:
      break

但是没有运气,有什么主意吗?

1 个答案:

答案 0 :(得分:1)

那又怎么样:

arr_x = np.array([2, 3, 5, 8, 10])
arr_y = np.array([[0,1],[1,0],[1,0],[0,1],[1,0]])
dataset = tf.data.Dataset.from_tensor_slices((arr_x, arr_y))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
while True:
    try:
        print(sess.run(next_element))
    except tf.errors.OutOfRangeError:
        break