So far I have been using a pipeline like this in TensorFlow:
queue_filenames = tf.train.string_input_producer(data)
reader = tf.FixedLengthRecordReader(record_bytes=4*4)
class Record(object):
    pass

result = Record()
result.key, value = reader.read(queue_filenames)
record = tf.decode_raw(value, tf.float32)
image = tf.reshape(tf.strided_slice(record,[0],[1]),[1])
label = tf.reshape(tf.strided_slice(record,[1],[4]),[3])
x, y = tf.train.shuffle_batch([image, label],
                              batch_size=batch_size,
                              capacity=batch_size*3,
                              min_after_dequeue=batch_size*2)
But now I want to switch to the Dataset API. I wrote this:
dataset = tf.data.FixedLengthRecordDataset(filenames=data,
                                           record_bytes=4*4)
dataset.map(_generate_x_y)
dataset.shuffle(buffer_size=batch_size*2)
dataset.batch(batch_size=batch_size)
dataset.repeat()
iterator = dataset.make_one_shot_iterator()
x, y = iterator.get_next()
using:
def _generate_x_y(sample):
    features = {"x": tf.FixedLenFeature([1], tf.float32),
                "y": tf.FixedLenFeature([3], tf.float32)}
    parsed_features = tf.parse_single_example(sample, features)
    return parsed_features["x"], parsed_features["y"]
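Note that tf.parse_single_example expects serialized tf.Example protos, as found in TFRecord files, while FixedLengthRecordDataset yields raw byte strings. A parse function that mirrors the old reader may be what is actually needed here; a sketch, assuming the same four-float32 record layout as in the old pipeline:

def _generate_x_y(sample):
    # sample is the raw 16-byte string of one record
    record = tf.decode_raw(sample, tf.float32)
    image = tf.reshape(record[0:1], [1])  # first float: the image
    label = tf.reshape(record[1:4], [3])  # remaining three floats: the label
    return image, label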
My graph looks like this:
y_ = network(x)
and
loss = tf.losses.softmax_cross_entropy(y,y_)
train_step = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss=loss)
My session is:
with tf.Session(graph=graph_train) as sess:
    tf.global_variables_initializer().run()
    for i in range(100):
        _, = sess.run([train_step])
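As an aside: a queue-based pipeline like the first one normally also needs its queue runners started inside the session, or reads will block; presumably the full code contains something along these lines:

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# ... training loop ...
coord.request_stop()
coord.join(threads)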
It works with the old pipeline, but with the new Dataset version I get the following error:
File "C:/***/main.py", line 49, in <module>
x, y = iterator.get_next()
File "C:\***\python\framework\ops.py", line 396, in __iter__
"`Tensor` objects are not iterable when eager execution is not "
TypeError: `Tensor` objects are not iterable when eager execution is not enabled. To iterate over this tensor use `tf.map_fn`.
Thanks for your help :-)
Answer 0 (score: 1)
One obvious problem that could be causing this is that you are not using the transformed datasets. Basically, instead of
dataset = tf.data.FixedLengthRecordDataset(filenames=data,
                                           record_bytes=4*4)
dataset.map(_generate_x_y)
dataset.shuffle(buffer_size=batch_size*2)
you should do this:
dataset = tf.data.FixedLengthRecordDataset(filenames=data,
                                           record_bytes=4*4)
dataset = dataset.map(_generate_x_y)
dataset = dataset.shuffle(buffer_size=batch_size*2)
Every dataset transformation returns a new, transformed dataset. The original object is not modified by operations such as map and shuffle.
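Equivalently, since each call returns a new Dataset, the transformations can be chained; a minimal sketch of the corrected pipeline, assuming batch_size is defined as in the question:

dataset = (tf.data.FixedLengthRecordDataset(filenames=data, record_bytes=4*4)
           .map(_generate_x_y)
           .shuffle(buffer_size=batch_size*2)
           .batch(batch_size=batch_size)
           .repeat())
iterator = dataset.make_one_shot_iterator()
x, y = iterator.get_next()  # the element is now an (x, y) pair, so unpacking works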