我是Tensorflow Dataset API的新手,我无法完全理解其设计的简单性,因此我需要一些帮助。
这是一个简单的例子
import tensorflow as tf
x = tf.placeholder(tf.int32, shape=[])
y = tf.square(x)
with tf.Session() as sess:
print(sess.run(y, {x: 2}))
# result is 4, simple
如果我有一个整数数组arr_x=[2, 3, 5, 8, 10]
,如何使用Dateset API迭代该数组?
我正在尝试
p = tf.placeholder(tf.int32, shape=[None])
d = tf.data.Dataset.from_tensor_slices(p)
d = d.map(lambda x: x)
iter = d.make_initializable_iterator()
next_element = iter.get_next()
with tf.Session() as sess:
sess.run(iter.initializer, feed_dict={p: [2, 3, 4]})
while True:
try:
print sess.run(y, next_element)
except tf.errors.OutOfRangeError:
break
但是没有运气,有什么主意吗?
答案 0 :(得分:1)
那又怎么样:
arr_x = np.array([2, 3, 5, 8, 10])
arr_y = np.array([[0,1],[1,0],[1,0],[0,1],[1,0]])
dataset = tf.data.Dataset.from_tensor_slices((arr_x, arr_y))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
while True:
try:
print(sess.run(next_element))
except tf.errors.OutOfRangeError:
break