获取Tensorflow中数据集的长度

时间:2017-12-10 04:33:54

标签: python-3.x tensorflow dataset

source_dataset = tf.data.TextLineDataset('primary.csv')
target_dataset = tf.data.TextLineDataset('secondary.csv')
dataset = tf.data.Dataset.zip((source_dataset, target_dataset))
dataset = dataset.shard(10000, 0)
dataset = dataset.map(lambda source, target: (tf.string_to_number(tf.string_split([source], delimiter=',').values, tf.int32),
                                              tf.string_to_number(tf.string_split([target], delimiter=',').values, tf.int32)))
dataset = dataset.map(lambda source, target: (source, tf.concat(([start_token], target), axis=0), tf.concat((target, [end_token]), axis=0)))
dataset = dataset.map(lambda source, target_in, target_out: (source, tf.size(source), target_in, target_out, tf.size(target_in)))

dataset = dataset.shuffle(NUM_SAMPLES)  #This is the important line of code

我想完全整理整个数据集,但shuffle()需要提取大量样本,而tf.Size()不能与tf.data.Dataset一起使用。

我怎样才能正常洗牌?

2 个答案:

答案 0 :(得分:1)

我正在使用tf.data.FixedLengthRecordDataset()并遇到了类似的问题。 在我的情况下,我试图只占一定比例的原始数据。 由于我知道所有记录都有固定的长度,因此我的解决方法是:

totalBytes = sum([os.path.getsize(os.path.join(filepath, filename)) for filename in os.listdir(filepath)])
numRecordsToTake = tf.cast(0.01 * percentage * totalBytes / bytesPerRecord, tf.int64)
dataset = tf.data.FixedLengthRecordDataset(filenames, recordBytes).take(numRecordsToTake)

在你的情况下,我的建议是直接在python中计算'primary.csv'和'secondary.csv'中的记录数。或者,我认为为了您的目的,设置buffer_size参数并不需要计算文件。根据{{​​3}},一个大于数据集中元素数量的数字将确保整个数据集中的统一混洗。因此,只需输入一个非常大的数字(您认为将超过数据集大小)就可以了。

答案 1 :(得分:0)

从TensorFlow 2开始,可以通过cardinality()函数轻松地检索数据集的长度。

dataset = tf.data.Dataset.range(42)
#both print 42 
dataset_length_v1 = tf.data.experimental.cardinality(dataset).numpy())
dataset_length_v2 = dataset.cardinality().numpy()

注意:使用谓词(例如filter)时,长度的返回值可能为-2。您可以咨询一种解释here,否则只需阅读以下段落:

如果使用过滤谓词,则基数可能返回值-2,因此未知;如果您确实对数据集使用了过滤谓词,请确保以其他方式计算出数据集的长度(例如,在对数据集应用.from_tensor_slices()之前,熊猫数据框的长度。