我想知道是否有一些方法可以在Tensorflow中生成批量生成约束。例如,假设我们正在一个庞大的数据集上训练CNN来进行图像分类。是否有可能强制Tensorflow生成所有样本属于同一类的批次?比如,一批图像都用“Apple”标记,另一组图像都标有“橙色”。
我问这个问题的原因是我想做一些实验,看看不同级别的改组如何影响最终训练的模型。通常的做法是为CNN培训进行样本级改组,每个人都在这样做。我只是想亲自检查一下,从而获得更生动,更直接的知识。
谢谢!
答案 0 :(得分:2)
Dataset.filter()
:
labels = np.random.randint(0, 10, (10000))
data = np.random.uniform(size=(10000, 5))
ds = tf.data.Dataset.from_tensor_slices((data, labels))
ds = ds.filter(lambda data, labels: tf.equal(labels, 1)) #comment this line out for unfiltered case
ds = ds.batch(5)
iterator = ds.make_one_shot_iterator()
vals = iterator.get_next()
with tf.Session() as sess:
for _ in range(5):
py_data, py_labels = sess.run(vals)
print(py_labels)
ds.filter()
:
> [1 1 1 1 1]
[1 1 1 1 1]
[1 1 1 1 1]
[1 1 1 1 1]
[1 1 1 1 1]
没有ds.filter()
:
> [8 0 7 6 3]
[2 4 7 6 1]
[1 8 5 5 5]
[7 1 7 4 0]
[7 1 8 0 0]
修改。以下代码显示如何使用可馈送迭代器动态执行批处理标签选择。请参阅" Creating an iterator"
labels = ['Apple'] * 100 + ['Orange'] * 100
data = list(range(200))
random.shuffle(labels)
batch_size = 4
ds_apple = tf.data.Dataset.from_tensor_slices((data, labels)).filter(
lambda data, label: tf.equal(label, 'Apple')).batch(batch_size)
ds_orange = tf.data.Dataset.from_tensor_slices((data, labels)).filter(
lambda data, label: tf.equal(label, 'Orange')).batch(batch_size)
handle = tf.placeholder(tf.string, [])
iterator = tf.data.Iterator.from_string_handle(
handle, ds_apple.output_types, ds_apple.output_shapes)
batch = iterator.get_next()
apple_iterator = ds_apple.make_one_shot_iterator()
orange_iterator = ds_orange.make_one_shot_iterator()
with tf.Session() as sess:
apple_handle = sess.run(apple_iterator.string_handle())
orange_handle = sess.run(orange_iterator.string_handle())
# loop and switch back and forth between apples and oranges
for _ in range(3):
feed_dict = {handle: apple_handle}
print(sess.run(batch, feed_dict=feed_dict))
feed_dict = {handle: orange_handle}
print(sess.run(batch, feed_dict=feed_dict))
典型输出如下。请注意,data
值在Apple和Orange批次中单调增加,表明迭代器未重置。
> (array([2, 3, 6, 7], dtype=int32), array([b'Apple', b'Apple', b'Apple', b'Apple'], dtype=object))
(array([0, 1, 4, 5], dtype=int32), array([b'Orange', b'Orange', b'Orange', b'Orange'], dtype=object))
(array([ 9, 13, 15, 19], dtype=int32), array([b'Apple', b'Apple', b'Apple', b'Apple'], dtype=object))
(array([ 8, 10, 11, 12], dtype=int32), array([b'Orange', b'Orange', b'Orange', b'Orange'], dtype=object))
(array([21, 22, 23, 25], dtype=int32), array([b'Apple', b'Apple', b'Apple', b'Apple'], dtype=object))
(array([14, 16, 17, 18], dtype=int32), array([b'Orange', b'Orange', b'Orange', b'Orange'], dtype=object))