Error when feeding a tf.data.Dataset to fit(): KeyError: 'embedding_input'

Date: 2019-11-01 22:22:14

Tags: python tensorflow keras tensorflow-datasets tensorflow2.0

I am using a TensorFlow 2.0 dataset (tf.data) to feed my model's fit() function. Here is the code:

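import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dropout, Dense

def build_model(self):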
    self.g_Model = Sequential()
    self.g_Model.add(Embedding(self.g_Max_features, output_dim=256))
    self.g_Model.add(LSTM(128))
    self.g_Model.add(Dropout(0.5))
    self.g_Model.add(Dense(1, activation='sigmoid'))
    self.g_Model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

def train_model(self, filenames):
    lstm_feature_description = {
        'X': tf.io.FixedLenFeature(CONFIG.g_keras_lstm_max_document_length, tf.float32),
        'y': tf.io.FixedLenFeature((), tf.int64),
    }

    def _parse_lstm_function(example_proto):
        return tf.io.parse_single_example(serialized=example_proto, features=lstm_feature_description)

    self.build_model()

    # Start Preparing The Data
    raw_lstm_dataset = tf.data.TFRecordDataset(CONFIG.g_record_file_lstm)

    parsed_lstm_dataset = raw_lstm_dataset.map(_parse_lstm_function)
    parsed_lstm_dataset = parsed_lstm_dataset.shuffle(CONFIG.g_shuffle_s).batch(CONFIG.g_Batch_size)

    self.g_Model.fit(parsed_lstm_dataset, epochs=2)

However, I get the following error:

Traceback (most recent call last):
  File "keras_lstm_v2.py", line 79, in train_model
    self.g_Model.fit(parsed_lstm_dataset, epochs=2)
  File "venv_tf_new\lib\site-packages\tensorflow_core\python\keras\engine\training.py", line 728, in fit
    use_multiprocessing=use_multiprocessing)
  File "venv_tf_new\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py", line 324, in fit
    total_epochs=epochs)
  File "venv_tf_new\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py", line 123, in run_one_epoch
    batch_outs = execution_function(iterator)
  File "venv_tf_new\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 86, in execution_function
    distributed_function(input_fn))
  File "venv_tf_new\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 457, in __call__
    result = self._call(*args, **kwds)
  File "venv_tf_new\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 503, in _call
    self._initialize(args, kwds, add_initializers_to=initializer_map)
  File "venv_tf_new\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 408, in _initialize
    *args, **kwds))
  File "venv_tf_new\lib\site-packages\tensorflow_core\python\eager\function.py", line 1848, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "venv_tf_new\lib\site-packages\tensorflow_core\python\eager\function.py", line 2150, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "venv_tf_new\lib\site-packages\tensorflow_core\python\eager\function.py", line 2041, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "venv_tf_new\lib\site-packages\tensorflow_core\python\framework\func_graph.py", line 915, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "venv_tf_new\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 358, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "venv_tf_new\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 66, in distributed_function
    model, input_iterator, mode)
  File "venv_tf_new\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 118, in _prepare_feed_values
    inputs = [inputs[key] for key in model._feed_input_names]
  File "venv_tf_new\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 118, in <listcomp>
    inputs = [inputs[key] for key in model._feed_input_names]
KeyError: 'embedding_input'

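I have already looked at this thread, but it did not clarify the problem for me. As far as I understand, something is wrong with the loaded data, but according to the dataset documentation this should work out of the box, so I don't know how to fix it.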
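The last frame of the traceback shows where the KeyError comes from. The parse function returns a dict ({'X': ..., 'y': ...}), so each dataset element reaches fit() as one dict, and Keras then looks up each of the model's input names in that dict. Below is a simplified pure-Python model of that lookup. The function prepare_feed_values is hypothetical and only mirrors the list comprehension printed in the traceback, not the full Keras code; the input name embedding_input is taken from the error message, where the Sequential model named its input after its first layer, embedding.

```python
# Simplified, pure-Python model of the lookup in training_v2_utils._prepare_feed_values.
# Hypothetical helper for illustration; this is NOT the real Keras implementation.

def prepare_feed_values(inputs, feed_input_names):
    """Mimic `[inputs[key] for key in model._feed_input_names]` from the traceback."""
    if isinstance(inputs, dict):
        # Dict inputs are matched to the model's inputs by name.
        return [inputs[key] for key in feed_input_names]
    # A single value (or a list) is matched to the inputs by position instead.
    return inputs if isinstance(inputs, list) else [inputs]


# Input name reported in the error: "<first layer name>_input".
feed_input_names = ["embedding_input"]

# Made-up stand-in for one parsed batch: a dict keyed by feature name.
batch = {"X": [[1.0, 2.0, 3.0]], "y": [1]}

try:
    prepare_feed_values(batch, feed_input_names)
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'embedding_input'
```

Because neither 'X' nor 'y' equals 'embedding_input', the lookup fails before any training step runs.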

Thanks in advance for any help!

2 answers:

Answer 0 (score: 0)

You need to declare the model's inputs somewhere, usually with

model = tf.keras.Model(inputs=inputs, outputs=outputs)

Try removing the last line from your model-building function

self.g_Model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

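and moving it into the training function. Once the model is declared, you can use the model's output as the output layer and declare the input layer with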
 input_size = CONFIG.g_keras_lstm_max_document_length
 input_layer = tf.keras.layers.Input(input_size)
 output_layer = self.build_model()
 model = tf.keras.Model(inputs=input_layer, outputs=output_layer )

 model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) 
 model.fit( .... ) 

Answer 1 (score: 0)

I found the solution myself today. There were actually two bugs in my code:

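  1. The .fit() function could not tell where the inputs and the labels were, as @NiallJG pointed out above. However, the proposed solution did not solve the problem, so I fixed it as follows: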

    1.1 In my build_model function, I added a name to the Embedding layer:

    self.g_Model.add(Embedding(input_dim=self.g_Max_features, output_dim=256, name='X'))
    

    1.2 To match this name, I had to change my lstm_feature_description so that the key carries the '_input' suffix:

    def train_model(self, filenames):
        lstm_feature_description = {
            'X_input': tf.io.FixedLenFeature(CONFIG.g_keras_lstm_max_document_length, tf.float32),
            'y': tf.io.FixedLenFeature((), tf.int64),
        }
    
  2. My _parse_lstm_function was returning the data in the wrong form, which caused an "IndexError: list index out of range" error. The modified function looks like this:

    def _parse_lstm_function(example_proto):
        # Parse the input tf.Example proto using the dictionary above.
        parsed = tf.io.parse_single_example(serialized=example_proto, features=lstm_feature_description)
        return parsed["X_input"], parsed["y"]
    
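A minimal pure-Python sketch of why step 2 is what lets fit() run, assuming that Keras in TF 2.0 reads a tuple element as (x, y) and any other element as x alone with no labels, looks up dict inputs by input name, and matches a single value by position. The helpers split_batch, parse_as_dict and parse_as_tuple are hypothetical stand-ins, not Keras or TensorFlow functions, and the record is made-up sample data.

```python
# Hypothetical helpers: a simplified model of how fit() (TF 2.0) interprets one
# dataset element. Sample weights are ignored. NOT the real Keras implementation.

def split_batch(element, feed_input_names):
    """Return (inputs_list, targets) the way fit() would read `element`."""
    # A 2-tuple is read as (x, y); anything else is x alone, with no labels.
    if isinstance(element, tuple):
        x, y = element
    else:
        x, y = element, None
    # Dict inputs are looked up by input name; a single value is matched by position.
    if isinstance(x, dict):
        inputs = [x[name] for name in feed_input_names]
    else:
        inputs = [x]
    return inputs, y


def parse_as_dict(record):
    # Like the original _parse_lstm_function: features and label in one dict.
    return {"X": record["X"], "y": record["y"]}


def parse_as_tuple(record):
    # Like the modified _parse_lstm_function: returns (features, label).
    return record["X"], record["y"]


record = {"X": [4.0, 8.0, 15.0], "y": 1}  # stand-in for one parsed tf.Example
names = ["embedding_input"]               # input name from the original error

try:
    split_batch(parse_as_dict(record), names)
except KeyError as err:
    print("dict element -> missing key", err)  # dict element -> missing key 'embedding_input'

inputs, label = split_batch(parse_as_tuple(record), names)
print("tuple element ->", inputs, label)       # tuple element -> [[4.0, 8.0, 15.0]] 1
```

One consequence, which is an inference from this model rather than something the answer states: once the parse function returns a plain (features, label) tuple, the features are matched to the input by position, so the renaming in steps 1.1 and 1.2 is probably no longer required. Conversely, even with matching key names, a dict element carries no labels, which is why step 2 was still needed.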

With these changes the model's .fit() runs correctly, except that I now get an OOM error, but that is a matter for another question.