我正在使用tensorflow' dataset API。用简单的案例测试我的代码。下面显示了我使用的简单代码。问题是,当数据集大小很小时,似乎数据集API返回的大小不一致。我确定有一个正确的方法来处理它。但即使我阅读了该页面和教程中的所有功能,我也找不到。
import numpy as np
import tensorflow as tf
data_source = tf.zeros([24, 200, 64, 64, 1]) #[number_of_video, steps, pixel_w, pixel_h, channel]
dataset = tf.contrib.data.Dataset.from_tensor_slices(data_source)
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(16)
dataset = dataset.repeat()
iterator = tf.contrib.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(dataset)
with tf.Session() as sess:
sess.run(training_init_op)
next_elem = next_element.eval()
print(np.shape(next_elem))
next_elem = next_element.eval()
print(np.shape(next_elem))
next_elem = next_element.eval()
print(np.shape(next_elem))
next_elem = next_element.eval()
print(np.shape(next_elem))
next_elem = next_element.eval()
print(np.shape(next_elem))
next_elem = next_element.eval()
print(np.shape(next_elem))
next_elem = next_element.eval()
print(np.shape(next_elem))
数据集是灰度视频。共有24个视频序列,步长均为200.帧大小为64乘64和单通道。我将批量大小设置为16,缓冲区大小设置为100.但代码的结果是,
(16, 200, 64, 64, 1)
(8, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(8, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(8, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
返回的视频大小是16或8.我想这是因为原始数据量很小,24,当它到达数据末尾时,API只返回剩下的内容。
但我不明白。我还将缓冲区大小设置为100.这意味着应该使用小数据集预先填充缓冲区。从该缓冲区开始,API应该选择next_element,其批量大小为16。
当我在张量流中使用队列类型API时,我没有遇到这个问题。无论原始数据的大小是多少,无论如何都有迭代器到达数据集末尾的时刻。我想知道使用这个API的其他人如何解决这个问题。
答案 0 :(得分:6)
尝试在repeat()
之前致电batch()
:
data_source = tf.zeros([24, 200, 64, 64, 1]) #[number_of_video, steps, pixel_w, pixel_h, channel]
dataset = tf.contrib.data.Dataset.from_tensor_slices(data_source)
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.repeat()
dataset = dataset.batch(16)
我得到的结果:
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
答案 1 :(得分:0)
您可以使用以下代码解决问题:
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(128))