传递无限重复的数据集时,必须指定`steps_per_epoch`参数

时间:2019-11-13 17:55:39

标签: python tensorflow tensorflow-datasets tensorflow2.0 tensorflow-lite

我正在尝试使用此Google的示例,但使用了自己的数据集:

https://github.com/tensorflow/examples/blob/master/tensorflow_examples/lite/model_customization/demo/text_classification.ipynb

我创建了一个文件夹,该文件夹与火车和测试文件夹以及txt文件中的代码类似。

在我的情况下,data_path如下: data_path = '/Users/developer/.keras/datasets/chat'

每当我尝试运行它时,model = text_classifier.create(train_data)都会引发错误 ValueError: When passing an infinitely repeating dataset, you must specify the `steps_per_epoch` argument. 那甚至意味着什么,我应该在哪里寻找问题?


import numpy as np
import os
import tensorflow as tf
assert tf.__version__.startswith('2')

from tensorflow_examples.lite.model_customization.core.data_util.text_dataloader import TextClassifierDataLoader
from tensorflow_examples.lite.model_customization.core.model_export_format import ModelExportFormat
import tensorflow_examples.lite.model_customization.core.task.text_classifier as text_classifier


# data_path = tf.keras.utils.get_file(
#       fname='aclImdb',
#       origin='http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',
#       untar=True)

data_path = '/Users/developer/.keras/datasets/chat'

train_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'train'), class_labels=['greeting', 'goodbye'])
test_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'test'), shuffle=False)

model = text_classifier.create(train_data)
loss, acc = model.evaluate(test_data)
model.export('movie_review_classifier.tflite', 'text_label.txt', 'vocab.txt')

2 个答案:

答案 0 :(得分:1)

我遇到了类似的问题,然后在model.fit下我添加了steps_per_epoch

history = single_step_model.fit(train_data_single,
                                epochs=100, 
                                callbacks=[lr_schedule], 
                                steps_per_epoch=EVALUATION_INTERVAL)

当然,我在此之前输入了EVALUATION_INTERVA L的值,因此它起作用了。希望对您有所帮助。

答案 1 :(得分:0)

问题是,当您为所需的时期数训练模型时,您的训练代码部分可能无法确定特定时期的开始时间和结束时间。

因此,在训练过程中,可以添加“ steps_per_epoch”参数,这样它将知道如何针对单个时期的特定有限步数进行操作和训练。

在进行验证的情况下,我们添加了特定的“ validation_steps”来解决同一问题。

model.fit

如上所述,我通过在我的tf.Keras model.fit()代码中添加了steps_per_epoch和validation_steps参数来解决了这个问题。

需要总结一下如何在代码中提供这些参数。

参考: