使用数据集迭代器作为张量流模型的输入时出错。

时间:2020-06-10 21:00:25

标签: python tensorflow keras tensorflow2.0

当我运行此代码时:

x_train = tfds.load('ucf101', split='train', shuffle_files=True, batch_size = 64)

dim = lambda x: x['video'][:,30:40, ...]
x_train = x_train.map(dim)

model.compile(loss='mse',
              optimizer=tf.keras.optimizers.Adam(1e-4),
              metrics=['accuracy'])

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

history = model.fit(x_train, x_train,
                    epochs=100,
                    verbose=1,
                    callbacks=[cp_callback])

我收到错误消息:

ValueError: You passed a dataset or dataset iterator (<MapDataset shapes: (None, None, 256, 256, 3), types: tf.uint8>) as input x to your model. In that case, you should not specify a target ( y ) argument, since the dataset or dataset iterator generates both input data and target data. Received: <MapDataset shapes: (None, None, 256, 256, 3), types: tf.uint8>

这是一个自动编码器,因此有意提供x_train作为输入和目标。 MapDataset的尺寸为(批,框架,高度,宽度,RGB),并且不包含任何目标数据。

1 个答案:

答案 0 :(得分:0)

通过https://www.tensorflow.org/api_docs/python/tf/keras/Model

如果x是数据集,生成器或keras.utils.Sequence实例,则不应指定y(因为将从迭代器/数据集获取目标)。

如果要使用model.fit(),则需要在tf.data数据集中指定目标,该目标可能类似于:

object.vpnPort = response.data