存储元组Tensorflow数据集,其中每个元组元素具有不同的形状

时间:2019-03-24 01:42:11

标签: tensorflow deep-learning batch-processing tensorflow-datasets bucket

我正在尝试修改现有的tensorflow代码。首先,将二维单词矩阵从dataset转换为geneartor,并通过map_strings_to_ints函数转换为vocab索引。然后调用以下函数。

dataset = dataset.apply(tf.contrib.data.bucket_by_sequence_length(element_length_func=lambda d: tf.shape(d)[0],
                                                                     bucket_boundaries=bucket_boundaries,
                                                                     bucket_batch_sizes=bucket_batch_sizes,
                                                                     padded_shapes=dataset.output_shapes,
                                                                     padding_values=constants.PAD_VALUE))

其中每个dataset元素都是一个大小为[None,None](即2d mat)的数组。

现在,对于每个元素,我想添加另一个文本序列。因此,每个元素都是先前二维垫的元组,而每个新数据集元素的对应句子/序列是([None,None],[None])的元组,那么如何修改上述函数? / p>

我尝试过

dataset = dataset.apply(tf.contrib.data.bucket_by_sequence_length(element_length_func=lambda d,t: tf.shape(d)[0],
                                                                     bucket_boundaries=bucket_boundaries,
                                                                     bucket_batch_sizes=bucket_batch_sizes,
                                                                     padded_shapes=dataset.output_shapes,
                                                                     padding_values=constants.PAD_VALUE))

还有其他一些技巧,但是得到了

TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: <class ‘int’>

请注意,dataset元素是映射到vocab索引(即int)的单词

1 个答案:

答案 0 :(得分:0)

这应该对您有帮助:

X = np.array([[[1,2,3],[4,5,6]],[[7,8,9], [1,2,3], [4,5,6], [7,8,9]], [[1,2,3], [4,5,6]]])
Y = np.array([0,1,0])

def elements_gen():
    for x,y in zip(X,Y):
        yield (x,y)

dataset = tf.data.Dataset.from_generator(generator=elements_gen, output_shapes=([None, None], []), output_types=(tf.int32, tf.int32))

dataset = dataset.apply(tf.contrib.data.bucket_by_sequence_length(element_length_fun =lambda x,y: tf.shape(x)[0], bucket_boundaries=[4,7], bucket_batch_sizes=[2,2,2], padding_values=(0,0)))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

问题恰恰是错误所说明的,因为要填充的结构是一个序列,所以用于填充该结构的值也必须是一个序列。