如果tf.io.gfile.glob在函数之外,则与tf.data.TFRecordDataset的dataset.interleave导致平坦丢失和准确性。

时间:2020-05-10 20:25:19

标签: python-3.x google-cloud-storage cloud tensorflow2.0

我将Tensorflow 2.2.0python 3.7一起使用。我标记了一些文本并将tf.data.dasets保存为TFRecord文件。这是一个小的数据集(SST2,67k带有小文本的条目),我只有一个.tfrecord文件。我有一个使用Keras和Huggingface的标准分类模型。代码和模型正在本地和GCP上运行。

def build_dataset(input_tfrecords, batch_size, shuffle_buffer=2048):
    dataset = tf.data.Dataset.list_files(file_pattern)
    dataset = dataset.interleave(tf.data.TFRecordDataset,
                                  cycle_length=tf.data.experimental.AUTOTUNE,
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.map(pp.parse_tfrecord_glue_files, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.cache()
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset 


train_files = tf.io.gfile.glob(FLAGS.input_train_tfrecords+'/*.tfrecord')

train_dataset = tf_bert.build_dataset(train_files, FLAGS.batch_size_train, 2048)

这给模型带来了一些奇怪的行为。损失和准确性均保持不变! enter image description here

现在,如果我将列出文件的函数放入函数中,它将按预期工作:

def build_dataset(input_tfrecords, batch_size, shuffle_buffer=2048):
    file_pattern = tf.io.gfile.glob(input_tfrecords+'/*.tfrecord')
    dataset = tf.data.Dataset.list_files(file_pattern)
    dataset = dataset.interleave(tf.data.TFRecordDataset,
                                  cycle_length=tf.data.experimental.AUTOTUNE,
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.map(pp.parse_tfrecord_glue_files, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.cache()
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

    train_dataset = tf_bert.build_dataset(FLAGS.input_train_tfrecords, FLAGS.batch_size_train, 2048)

现在模型正在收敛:
enter image description here

我正在云上运行,每个示例都是一个新的实验。我真的很想念可能导致模型什么都不学的原因。除了放置列出TFRecod文件的功能之外,其他所有方面都是相同的。函数和调用在2个不同的文件中完成。

在进行优化之前,我正在使用它,即使列出文件的功能不在功能范围内,它也可以正常工作:

dataset = tf.data.TFRecordDataset(file_pattern)
dataset = dataset.shuffle(shuffle_buffer)
dataset = dataset.map(pp.parse_tfrecord_glue_files, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size)
dataset = dataset.cache()
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE

似乎是interleave的问题所在:

dataset = tf.data.Dataset.list_files(file_pattern)
dataset = dataset.interleave(tf.data.TFRecordDataset ...

使用时遇到相同的问题:

dataset = tf.data.Dataset.from_tensor_slices(file_pattern)
dataset = dataset.interleave(lambda x: tf.data.TFRecordDataset(x) ...

问题1:为什么tf.io.gfile.glob必须在函数内部?
问题2:数据如何以模型不学习的任何方式馈入模型?
问题3:还有其他建议作为“最佳实践”吗?

在我的实验中,数据存储在GCP存储中,但在本地计算机上却看到相同的数据。这给出了完全相同的输出:

list(tf.data.Dataset.list_files(tf.io.gfile.glob(input_tfrecords+'/*.tfrecord')))
list(tf.data.Dataset.list_files(input_tfrecords+'/*.tfrecord'))

0 个答案:

没有答案