Keras model.fit()引发有关未指定参数steps_per_epoch的错误

时间:2019-06-20 13:52:59

标签: python tensorflow keras

我正在尝试将tf.Dataset作为我的数据集来拟合Keras模型。我指定了参数steps_per_epoch。但是,会引发此错误: ValueError: When using iterators as input to a model, you should specify the 'steps_per_epoch' argument.这个错误使我感到困惑,因为我为数据集的长度指定了steps_per_epoch参数。我尝试过None以及小于我的数据集长度的整数都无济于事。

这是我的代码:

def build_model():
    '''
    Function to build a LSTM RNN model that takes in quantitiy, converted week; outputs predicted price
    '''
    # define model
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.LSTM(128, activation='relu', input_shape=(num_steps,num_features*input_size)))
    model.add(tf.keras.layers.Dense(128*input_size, input_shape=(num_steps,num_features*input_size)))
    model.add(tf.keras.layers.Dense(input_size))
    model.compile(optimizer='adam', loss='mse')
    print(train_data[0].shape, train_data[1].shape)

    #cast data
    features_type = tf.float32
    target_type = tf.float32

    train_dataset = tf.data.Dataset.from_tensor_slices((
        tf.cast(train_data[0], features_type),
        tf.cast(train_data[1], target_type))
    )

    validation_dataset = tf.data.Dataset.from_tensor_slices((
        tf.cast(val_data[0], features_type),
        tf.cast(val_data[1], target_type))
    )

    # fit model
    es = tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=1)
    model.fit(train_dataset, epochs=500,steps_per_epoch = 134,verbose=1, validation_data = validation_dataset)
    # validation_data = (val_data[0], val_data[1])
    print(model.summary())
    return model

1 个答案:

答案 0 :(得分:0)

您的 train_dataset validation_dataset 是数据集(请查看函数 from_tensor_slices 的tensorflow文档):https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensor_slices

我认为您需要使用数据集中的数据。例如,您可以使用以下函数在整个数据集上进行迭代:

iterator =数据集.make_one_shot_iterator()

看看tensorflow有关如何使用数据集对象中的数据的文档:https://www.tensorflow.org/guide/datasets#batching_dataset_elements

还请参阅这篇标题为如何正确组合TensorFlow的数据集API和Keras的帖子?How to Properly Combine TensorFlow's Dataset API and Keras?