How do I feed data for multiple inputs to a TPU in TensorFlow?

Date: 2019-04-18 17:24:00

Tags: python tensorflow

I am trying to feed data with multiple inputs (2 inputs) to a Keras (TensorFlow) model for TPU training, but I get this error:

ValueError: The dataset returned a non-Tensor type
  ((<class 'tensorflow.python.framework.ops.Tensor'>,
    <class 'tensorflow.python.framework.ops.Tensor'>)) at index 0

I followed the approach from this question: tf.data with multiple inputs / outputs in Keras

import tensorflow as tf

# x_train_h, x_train_l, Y_train, test_input_fn, tpu_model and tensorboard
# are defined elsewhere in the notebook.
def train_input_fn(batch_size=1024):
    # Each element is ((x_h, x_l), y): a pair of feature tensors plus the label.
    dataset_features = tf.data.Dataset.from_tensor_slices((x_train_h, x_train_l))
    dataset_label = tf.data.Dataset.from_tensor_slices(Y_train)
    dataset = tf.data.Dataset.zip((dataset_features, dataset_label)).batch(batch_size, drop_remainder=True)

    return dataset

history = tpu_model.fit(train_input_fn,
                        steps_per_epoch=30,
                        epochs=100,
                        validation_data=test_input_fn,
                        validation_steps=1,
                        callbacks=[tensorboard])

Model: https://take.ms/jO4P5

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-242-65c6d5a98fb7> in <module>()
    12                          validation_data = test_input_fn,
    13                          validation_steps = 1,
---> 14                          callbacks = [tensorboard, 
    15                                       #checkpointer
    16                                      ]

/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/tpu/python/tpu/keras_support.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
  1486                              'be None')
  1487           infeed_manager = TPUDatasetInfeedManager(
-> 1488               dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN)
  1489           # Use dummy numpy inputs for the rest of Keras' shape checking. We
  1490           # intercept them when building the model.

/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/tpu/python/tpu/keras_support.py in __init__(self, dataset, tpu_assignment, mode)
   722       mode: ModeKeys enum.
   723     """
--> 724     self._verify_dataset_shape(dataset)
   725 
   726     self._dataset = dataset

/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/tpu/python/tpu/keras_support.py in _verify_dataset_shape(self, dataset)
   783       if cls != ops.Tensor:
   784         raise ValueError('The dataset returned a non-Tensor type (%s) at '
--> 785                          'index %d.' % (cls, i))
   786     for i, shape in enumerate(dataset.output_shapes):
   787       if not shape:

ValueError: The dataset returned a non-Tensor type ((<class 'tensorflow.python.framework.ops.Tensor'>, <class 'tensorflow.python.framework.ops.Tensor'>)) at index 0.
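The traceback hints at why the error fires. `Dataset.zip((dataset_features, dataset_label))` yields elements shaped `((x_train_h, x_train_l), Y_train)`, so the first top-level entry of the dataset's output classes is itself a tuple of two `Tensor` classes rather than a single `Tensor`. That matches the "index 0" in the message. Below is a simplified pure-Python model of that check, assuming (the loop header is not shown in the traceback) that it iterates over the top-level entries of `dataset.output_classes`. `Tensor` and `verify_output_classes` are hypothetical stand-ins, not TensorFlow APIs.

```python
# Simplified model of the check seen in keras_support._verify_dataset_shape.
# Assumption: it loops over the top-level entries of dataset.output_classes.


class Tensor:
    """Hypothetical stand-in for tensorflow.python.framework.ops.Tensor."""


def verify_output_classes(output_classes):
    """Raise ValueError if any top-level entry is not the Tensor class."""
    for i, cls in enumerate(output_classes):
        # A nested tuple such as (Tensor, Tensor) is not equal to Tensor.
        if cls != Tensor:
            raise ValueError(
                "The dataset returned a non-Tensor type (%s) at index %d."
                % (cls, i)
            )


# Structure produced by Dataset.zip((dataset_features, dataset_label)):
# ((x_h, x_l), y) -> output classes ((Tensor, Tensor), Tensor).
nested = ((Tensor, Tensor), Tensor)

# A single-input (x, y) dataset has a flat structure: (Tensor, Tensor).
flat = (Tensor, Tensor)

verify_output_classes(flat)  # passes silently
try:
    verify_output_classes(nested)
except ValueError as err:
    print(err)  # reports a tuple of two Tensor classes at index 0
```

Under this model only flat element structures pass, which is consistent with the reported error: the nested feature tuple at index 0 is rejected before any shape checks run.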
