Tensorflow concat tf.data.Dataset批次

时间:2018-04-23 12:50:07

标签: python tensorflow

使用tf.data.Dataset时可以连接 批量数据集的方式不是第二个数据集 在第一个结尾处连接,但是这样 第一批第二个数据集后连接 第一批第二个数据集,依此类推。

我尝试了如下,但这给了我一个长度为40的数据集, 但是,我希望这里的长度为80.

train_data = train_data.batch(40).concatenate(augmentation_data.batch(40))

1 个答案:

答案 0 :(得分:3)

不完全确定您的用例是什么,但您可能希望分别在批处理中连接功能和标签的张量,如下所示:

def concat_batches(x, y):
    features1, labels1 = x
    features2, labels2 = y
    return ({feature: tf.concat([features1[feature], features2[feature]], axis=0) for feature in features1.keys()}, tf.concat([labels1, labels2], axis=0))

这是一个例子:

dataset = tf.data.Dataset.from_tensor_slices(({"test": [[1], [1], [1], [1]]}, [1, 1, 1, 1]))
b1 = dataset.repeat().batch(3).make_one_shot_iterator().get_next()
dataset2 = tf.data.Dataset.from_tensor_slices(({"test": [[2], [2], [2], [2]]}, [2, 2, 2, 2]))
b2 = dataset2.repeat().batch(3).make_one_shot_iterator().get_next()

b_con = concat_batches(b1, b2) #tensors of batches 1 and 2 have shape (3, 1), features of the concatenated batch (6, 1)

在评估您将看到的示例时,b_con将如下所示:

({'test': array([[1],
       [1],
       [1],
       [2],
       [2],
       [2]], dtype=int32)}, array([1, 1, 1, 2, 2, 2], dtype=int32))

希望这有帮助!