我想使用使用import tensorflow as tf
def make_batch_generator_fn(batch_size=10, dset_size=100):
feats, targs = range(dset_size), range(1, dset_size + 1)
def batch_generator_fn():
start_idx, stop_idx = 0, batch_size
while True:
# if stop_idx > dset_size: --- stop action?
yield feats[start_idx: stop_idx], targs[start_idx: stop_idx]
start_idx, stop_idx = start_idx + batch_size, stop_idx + batch_size
return batch_generator_fn
def test(batch_size=10):
dgen = make_batch_generator_fn(batch_size)
features_shape, targets_shape = [None], [None]
ds = tf.data.Dataset.from_generator(
dgen, (tf.int32, tf.int32),
(tf.TensorShape(features_shape), tf.TensorShape(targets_shape))
)
feats, targs = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
counter = 0
try:
while True:
f, t = sess.run([feats, targs])
print(f, t)
counter += 1
if counter > 15:
break
except tf.errors.OutOfRangeError:
print('end of dataset at counter = {}'.format(counter))
if __name__ == '__main__':
test()
构建的TensorFlow数据集来访问格式化文件。除了我不知道如何在生成器数据耗尽时停止数据集迭代器(当你超出范围时,生成器只会永远返回空列表),大多数都可以工作。
我的实际代码非常复杂,但我可以通过这个简短的程序来模拟这种情况:
stop action?
如果我事先知道记录的数量,我可以调整批次的数量,但我不会总是知道。我已经尝试在上面的代码段中添加一些代码,我在其中有IndexError
这样的注释行。特别是,我尝试过提升catch
,但TensorFlow并不喜欢这样,即使我在执行代码中明确tf.errors.OutOfRangeError
。我也试过提出sort((x,y) => x.description < y.description ? -1 : 1)
,但我不确定如何实例化它:构造函数需要三个参数 - &#39; node_def&#39;,&#39; op&#39;和& #39;消息&#39;,我不太确定要使用什么节点&#39; node_def&#39;和&#39; op&#39;总的来说。
我对此问题的任何想法或意见表示感谢。谢谢!
答案 0 :(得分:0)
它适用于以下几行:
dataset_size = your dataset size
batch_size = your batch size
dataset = your tf.data.Dataset
steps_per_epoch = dataset_size // batch_size
for data, _ in zip(dataset, range(steps_per_epoch)):
# your train_step
迭代将结束。