I'm trying to feed data with multiple inputs (2 inputs) into a Keras (TensorFlow) model for TPU training, but I get this error:
ValueError: The dataset returned a non-Tensor type ((<class 'tensorflow.python.framework.ops.Tensor'>, <class 'tensorflow.python.framework.ops.Tensor'>)) at index 0
I tried the approach from this link: tf.data with multiple inputs / outputs in Keras
import tensorflow as tf

def train_input_fn(batch_size=1024):
    # Features are a tuple of the two inputs; labels are a single tensor.
    dataset_features = tf.data.Dataset.from_tensor_slices((x_train_h, x_train_l))
    dataset_label = tf.data.Dataset.from_tensor_slices(Y_train)
    # Each element is ((x_h, x_l), y); drop_remainder=True keeps the batch size static, as the TPU requires.
    dataset = tf.data.Dataset.zip((dataset_features, dataset_label)).batch(batch_size, drop_remainder=True)
    return dataset

history = tpu_model.fit(train_input_fn,
                        steps_per_epoch=30,
                        epochs=100,
                        validation_data=test_input_fn,
                        validation_steps=1,
                        callbacks=[tensorboard])
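To see why this dataset is rejected, here is a pure-Python sketch (no TensorFlow needed) that mimics the per-index class check visible in `_verify_dataset_shape` in the traceback below. It is a simplified approximation: `FakeTensor`, `output_classes`, and `verify_dataset_classes` are hypothetical stand-ins, not TensorFlow APIs, and the tuple and length checks are assumptions about the surrounding code. The zipped dataset above yields elements shaped `((x_h, x_l), y)`, so index 0 is a tuple rather than a Tensor.

```python
class FakeTensor:
    """Hypothetical stand-in for tf.Tensor, used only to mimic the type check."""


def output_classes(structure):
    """Map each leaf of a nested tuple to its class, similar to Dataset.output_classes."""
    if isinstance(structure, tuple):
        return tuple(output_classes(s) for s in structure)
    return type(structure)


def verify_dataset_classes(classes):
    """Simplified mimic of the class check in keras_support._verify_dataset_shape."""
    if not isinstance(classes, tuple):
        raise ValueError("The dataset must return a tuple of tf.Tensors")
    if len(classes) != 2:
        raise ValueError("The dataset must return a 2-element tuple")
    for i, cls in enumerate(classes):
        # A nested tuple of classes at index i is not itself a Tensor class.
        if cls is not FakeTensor:
            raise ValueError(
                "The dataset returned a non-Tensor type (%s) at index %d." % (cls, i))


# Structure produced by zipping ((x_train_h, x_train_l), Y_train): ((x_h, x_l), y)
two_input_element = ((FakeTensor(), FakeTensor()), FakeTensor())
# Structure with a single feature tensor: (x, y)
one_input_element = (FakeTensor(), FakeTensor())

for name, element in [("(x, y)", one_input_element), ("((x_h, x_l), y)", two_input_element)]:
    try:
        verify_dataset_classes(output_classes(element))
        print(name, "accepted")
    except ValueError as err:
        print(name, "rejected:", err)
```

The flat `(x, y)` element passes, while the two-input element fails at index 0 with the same doubled parentheses seen in the real error message, because the class at that index is a tuple of two classes.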
Model: https://take.ms/jO4P5
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-242-65c6d5a98fb7> in <module>()
12 validation_data = test_input_fn,
13 validation_steps = 1,
---> 14 callbacks = [tensorboard,
15 #checkpointer
16 ]
/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/tpu/python/tpu/keras_support.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
1486 'be None')
1487 infeed_manager = TPUDatasetInfeedManager(
-> 1488 dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN)
1489 # Use dummy numpy inputs for the rest of Keras' shape checking. We
1490 # intercept them when building the model.
/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/tpu/python/tpu/keras_support.py in __init__(self, dataset, tpu_assignment, mode)
722 mode: ModeKeys enum.
723 """
--> 724 self._verify_dataset_shape(dataset)
725
726 self._dataset = dataset
/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/tpu/python/tpu/keras_support.py in _verify_dataset_shape(self, dataset)
783 if cls != ops.Tensor:
784 raise ValueError('The dataset returned a non-Tensor type (%s) at '
--> 785 'index %d.' % (cls, i))
786 for i, shape in enumerate(dataset.output_shapes):
787 if not shape:
ValueError: The dataset returned a non-Tensor type ((<class 'tensorflow.python.framework.ops.Tensor'>, <class 'tensorflow.python.framework.ops.Tensor'>)) at index 0.
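One possible workaround, assuming both inputs can be flattened to 2-D arrays with the same number of samples: pack them into a single feature array so the dataset yields `(x, y)` with two plain Tensors, then split the array back apart inside the model. The NumPy sketch below shows only the packing and unpacking round trip; `pack_inputs`, `unpack_inputs`, and `h_dim` are hypothetical names, and the toy arrays stand in for `x_train_h` and `x_train_l`.

```python
import numpy as np


def pack_inputs(x_h, x_l):
    """Hypothetical helper: concatenate two 2-D inputs along the feature axis."""
    if x_h.shape[0] != x_l.shape[0]:
        raise ValueError("both inputs need the same number of samples")
    return np.concatenate([x_h, x_l], axis=1)


def unpack_inputs(packed, h_dim):
    """Hypothetical helper: split a packed array; h_dim is the width of the first input."""
    return packed[:, :h_dim], packed[:, h_dim:]


# Toy stand-ins for x_train_h and x_train_l: 8 samples with 3 and 5 features.
x_h = np.arange(24, dtype=np.float32).reshape(8, 3)
x_l = np.arange(40, dtype=np.float32).reshape(8, 5)

packed = pack_inputs(x_h, x_l)  # shape (8, 8): one feature row per sample
h_back, l_back = unpack_inputs(packed, h_dim=x_h.shape[1])
print(packed.shape, np.array_equal(h_back, x_h), np.array_equal(l_back, x_l))
```

Inside the model, the same split could be done after a single `Input` layer with `tf.keras.layers.Lambda(lambda t: t[:, :h_dim])` and `tf.keras.layers.Lambda(lambda t: t[:, h_dim:])`, feeding the two original branches, with the dataset built as `tf.data.Dataset.from_tensor_slices((x_train_packed, Y_train))`. This has not been checked against the TPU contrib API in this traceback. Note that `np.concatenate` upcasts to a common dtype, so inputs of different dtypes (for example integer token IDs and floats) may need separate handling.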