tensorflow的返回大小的数据集API不是常量

时间:2017-10-12 11:48:26

标签: python tensorflow iterator bigdata tensorflow-datasets

我正在使用tensorflow' dataset API。用简单的案例测试我的代码。下面显示了我使用的简单代码。问题是,当数据集大小很小时,似乎数据集API返回的大小不一致。我确定有一个正确的方法来处理它。但即使我阅读了该页面和教程中的所有功能,我也找不到。

import numpy as np
import tensorflow as tf

data_source = tf.zeros([24, 200, 64, 64, 1]) #[number_of_video, steps, pixel_w, pixel_h, channel]
dataset = tf.contrib.data.Dataset.from_tensor_slices(data_source)
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(16)
dataset = dataset.repeat()

iterator = tf.contrib.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(dataset)

with tf.Session() as sess:
    sess.run(training_init_op)
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))

数据集是灰度视频。共有24个视频序列,步长均为200.帧大小为64乘64和单通道。我将批量大小设置为16,缓冲区大小设置为100.但代码的结果是,

(16, 200, 64, 64, 1)
(8, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(8, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(8, 200, 64, 64, 1)
(16, 200, 64, 64, 1)

返回的视频大小是16或8.我想这是因为原始数据量很小,24,当它到达数据末尾时,API只返回剩下的内容。

但我不明白。我还将缓冲区大小设置为100.这意味着应该使用小数据集预先填充缓冲区。从该缓冲区开始,API应该选择next_element,其批量大小为16。

当我在张量流中使用队列类型API时,我没有遇到这个问题。无论原始数据的大小是多少,无论如何都有迭代器到达数据集末尾的时刻。我想知道使用这个API的其他人如何解决这个问题。

2 个答案:

答案 0 :(得分:6)

尝试在repeat()之前致电batch()

data_source = tf.zeros([24, 200, 64, 64, 1]) #[number_of_video, steps, pixel_w, pixel_h, channel]
dataset = tf.contrib.data.Dataset.from_tensor_slices(data_source)
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.repeat()
dataset = dataset.batch(16)

我得到的结果:

(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)

答案 1 :(得分:0)

您可以使用以下代码解决问题:

batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(128))