Tensorflow如何生成不平衡的组合数据集

时间:2018-01-15 23:18:43

标签: python tensorflow

我对新数据集API(tensorflow 1.4)有疑问。我有两个数据集,我需要创建一个组合的不平衡数据集,即 每个批次应包含来自第一个的一定数量的元素和来自第二个数据集的一定数量的元素。例如,

dataset1 = tf.data.Dataset.from_tensor_slices(tf.constant([1,1,1,1,1,1]
dataset1 = tf.data.Dataset.from_tensor_slices(tf.constant([2,2,2,2,2,2]))

假设批量大小为4我希望组合数据集中的批处理看起来像[1,1,1,2]。我知道如何使用zip和flat_map生成平衡数据集 但我对此感到茫然。

提前致谢!

1 个答案:

答案 0 :(得分:2)

为了解决这个问题,我的解决方案是单独批处理数据集,压缩它们,然后在生成的数据集上映射tf.concat运算符。

在您的示例中,它会提供类似的内容(我重命名了第二个数据集dataset2):

def concat(*tensor_list):
    return tf.concat(tensor_list, axis=0)

zipped_ds = tf.data.Dataset.zip((dataset1.batch(3), dataset2))
unbalanced_ds = zipped_ds.map(concat)

如果数据集是张量的嵌套结构,则可以使用以下版本的concat:

def concat(*ds_elements):
    #Create one empty list for each component of the dataset
    lists = [[] for _ in ds_elements[0]]
    for element in ds_elements:
        for i, tensor in enumerate(element):
            #For each element, add all its component to the associated list
            lists[i].append(tensor)

    #Concatenate each component list
    return tuple(tf.concat(l, axis=0) for l in lists)

如果所有数据集元素(要组合的数据集的一部分)是仅最外层维度(相对批量大小)不同的张量,则可以使用。它为数据集元素的每个组件构建一个列表,并将这些组件彼此连接起来。

哪个处理一级嵌套。如果你需要更多,你可以使用重复来去解包嵌套的嵌套,但它可能会给出一个不那么干净的计算图...