tf.data.Dataset的输入形状不被model.fit()

时间:2020-07-13 14:14:11

标签: python tensorflow tensorflow-datasets

我想通过应用tf.data.Dataset来提供模型数据。

检查了TF 2.0的文档后,我发现.fit()函数(https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit)接受:

x-一个tf.data数据集。应该返回其中一个的元组(输入,目标) 或(输入,目标,sample_weights)。

因此,我编写了以下概念证明的最小代码:

from sklearn.datasets import make_blobs
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.metrics import Accuracy, AUC

X, Y = make_blobs(n_samples=500, n_features=2, cluster_std=3.0, random_state=1)

def define_model():
    model = Sequential()
    model.add(Dense(units=1, activation="sigmoid", input_shape=(2,)))
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=[AUC(), Accuracy()])
    return model

model = define_model()

X_ds = tf.data.Dataset.from_tensor_slices(X)
Y_ds = tf.data.Dataset.from_tensor_slices(Y)
dataset = tf.data.Dataset.zip((X_ds, Y_ds))

for elem in dataset.take(1):
    print(type(elem))
    print(elem)

model.fit(x=dataset) #<-- does not work
#model.fit(x=X, y=Y) <-- does work without any problems....

如第二条评论中所述,不应用tf.data.Dataset的代码可以正常工作。

但是,在应用数据集对象时,出现以下错误消息:

<class 'tuple'>
(<tf.Tensor: shape=(2,), dtype=float64, numpy=array([-10.42729974,  -0.85439721])>, <tf.Tensor: shape=(), dtype=int64, numpy=1>)
... other output here...
ValueError: Error when checking input: expected dense_19_input to have
shape (2,) but got array with shape (1,)

根据我对文档的理解,我构建的数据集应该恰好是fit方法期望的元组对象。

我不明白此错误消息。

我在做什么错了?

1 个答案:

答案 0 :(得分:1)

将数据集传递到fit时,可以预期它将直接生成批次,而不是单个示例。您只需要在训练之前对数据集进行批处理即可。

dataset = dataset.batch(batch_size)
model.fit(x=dataset)