使用tf.data.Dataset时可以连接 批量数据集的方式不是第二个数据集 在第一个结尾处连接,但是这样 第一批第二个数据集后连接 第一批第二个数据集,依此类推。
我尝试了如下,但这给了我一个长度为40的数据集, 但是,我希望这里的长度为80.
train_data = train_data.batch(40).concatenate(augmentation_data.batch(40))
答案 0 :(得分:3)
不完全确定您的用例是什么,但您可能希望分别在批处理中连接功能和标签的张量,如下所示:
def concat_batches(x, y):
features1, labels1 = x
features2, labels2 = y
return ({feature: tf.concat([features1[feature], features2[feature]], axis=0) for feature in features1.keys()}, tf.concat([labels1, labels2], axis=0))
这是一个例子:
dataset = tf.data.Dataset.from_tensor_slices(({"test": [[1], [1], [1], [1]]}, [1, 1, 1, 1]))
b1 = dataset.repeat().batch(3).make_one_shot_iterator().get_next()
dataset2 = tf.data.Dataset.from_tensor_slices(({"test": [[2], [2], [2], [2]]}, [2, 2, 2, 2]))
b2 = dataset2.repeat().batch(3).make_one_shot_iterator().get_next()
b_con = concat_batches(b1, b2) #tensors of batches 1 and 2 have shape (3, 1), features of the concatenated batch (6, 1)
在评估您将看到的示例时,b_con将如下所示:
({'test': array([[1],
[1],
[1],
[2],
[2],
[2]], dtype=int32)}, array([1, 1, 1, 2, 2, 2], dtype=int32))
希望这有帮助!