Switching from GRU to CuDNNGRU gives an "Unknown input node" error

Time: 2018-04-24 15:51:39

Tags: python-3.x tensorflow recurrent-neural-network keras-layer

Hi, I am trying out GRU and CuDNNGRU models in Keras. The GRU model works perfectly, but when I switch to CuDNNGRU an error is raised. Here is my code:

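```python
from keras import optimizers
from keras.layers import (Input, Embedding, Bidirectional, GRU,
                          GlobalMaxPool1D, Dense, Dropout)
from keras.models import Model

# maxlen, max_features and embedding_matrix are defined earlier in the notebook.

def get_model():
    input_words = Input((maxlen, ))
    # Frozen pretrained 300-dimensional word embeddings
    x_words = Embedding(max_features, 300,
                        weights=[embedding_matrix],
                        trainable=False)(input_words)
    #x_words = SpatialDropout1D(0.5)(x_words)
    # This GRU is the layer that gets swapped for CuDNNGRU
    x_words = Bidirectional(GRU(50, return_sequences=True))(x_words)
    #x_words = Convolution1D(100, 3, activation="relu")(x_words)
    x_words = GlobalMaxPool1D()(x_words)
    # Classification head: Dense -> Dropout -> 6 sigmoid outputs
    x = Dense(50, activation="relu")(x_words)
    x = Dropout(0.25)(x)
    predictions = Dense(6, activation="sigmoid")(x)
    model = Model(inputs=input_words, outputs=predictions)
    model.compile(optimizer=optimizers.Adam(0.0005, decay=1e-6),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model
```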

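When I run it with GRU it works flawlessly, but when I change it to CuDNNGRU it shows the error below. I'm not sure whether Keras needs more parameters or whether I simply can't use CuDNNGRU here. Should I dig deeper into TensorFlow? Any suggestions are appreciated. Thanks.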
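For reference, here is a minimal sketch of the same model with the recurrent layer made switchable, assuming Keras 2.1.x on the TensorFlow backend (where `keras.layers.CuDNNGRU` exists) and that `maxlen`, `max_features` and `embedding_matrix` are passed in rather than read from globals. The `use_cudnn` flag and the `K.clear_session()` call are my own additions, not something from the question: the layer names in the traceback (`embedding_8`, `bidirectional_3`) suggest the model was built several times in the same notebook session, so starting from a fresh graph may rule out stale variables from an earlier build. CuDNNGRU itself only runs on a GPU with cuDNN available.

```python
def get_model(maxlen, max_features, embedding_matrix, use_cudnn=True):
    """Build the question's model; use_cudnn picks CuDNNGRU (GPU only) or GRU."""
    # Imports are local so this file can be loaded without Keras installed.
    # Assumes Keras 2.1.x with the TensorFlow backend, where CuDNNGRU exists.
    from keras import backend as K
    from keras import optimizers
    from keras.layers import (Input, Embedding, Bidirectional, GRU, CuDNNGRU,
                              GlobalMaxPool1D, Dense, Dropout)
    from keras.models import Model

    # Assumption: resetting the graph avoids clashes with layers created by
    # earlier runs in the same session. Any model built before this call
    # becomes unusable afterwards.
    K.clear_session()

    rnn = CuDNNGRU if use_cudnn else GRU
    input_words = Input((maxlen,))
    # 300 must match embedding_matrix.shape[1]
    x = Embedding(max_features, 300,
                  weights=[embedding_matrix], trainable=False)(input_words)
    x = Bidirectional(rnn(50, return_sequences=True))(x)
    x = GlobalMaxPool1D()(x)
    x = Dense(50, activation="relu")(x)
    x = Dropout(0.25)(x)
    predictions = Dense(6, activation="sigmoid")(x)

    model = Model(inputs=input_words, outputs=predictions)
    model.compile(optimizer=optimizers.Adam(0.0005, decay=1e-6),
                  loss="binary_crossentropy", metrics=["accuracy"])
    return model
```

Calling `get_model(maxlen, max_features, embedding_matrix, use_cudnn=False)` should reproduce the working GRU setup, which makes it easy to confirm that only the layer swap differs between the two runs.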

    ---------------------------------------------------------------------------
    InvalidArgumentError                      Traceback (most recent call last)
    /opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
       1326     try:
    -> 1327       return fn(*args)
       1328     except errors.OpError as e:

    /opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
       1309       # Ensure any changes to the graph are reflected in the runtime.
    -> 1310       self._extend_graph()
       1311       return self._call_tf_sessionrun(

    /opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _extend_graph(self)
       1357             tf_session.TF_ExtendGraph(self._session,
    -> 1358                                       graph_def.SerializeToString(), status)
       1359           self._opened = True

    /opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
        515             compat.as_text(c_api.TF_Message(self.status.status)),
    --> 516             c_api.TF_GetCode(self.status.status))
        517     # Delete the underlying status object from memory otherwise it stays alive

    InvalidArgumentError: Node 'embedding_8/IsVariableInitialized': Unknown input node 'bidirectional_3/forward_cu_dnngru_1/kernel'

    During handling of the above exception, another exception occurred:

    InvalidArgumentError                      Traceback (most recent call last)
    <ipython-input-39-98e1902305d7> in <module>()
    ----> 1 model = get_model()

    <ipython-input-38-8788d950b75d> in get_model()
          7     x_words = Embedding(max_features, 300,
          8                             weights=[embedding_matrix],
    ----> 9                             trainable=False)(input_words)
         10     #x_words = SpatialDropout1D(0.5)(x_words)
         11     x_words =Bidirectional(GRU(50, return_sequences=True))(x_words)

    /opt/conda/lib/python3.6/site-packages/Keras-2.1.5-py3.6.egg/keras/engine/topology.py in __call__(self, inputs, **kwargs)
        597                 # Load weights that were specified at layer instantiation.
        598                 if self._initial_weights is not None:
    --> 599                     self.set_weights(self._initial_weights)
        600 
        601             # Raise exceptions in case the input is not compatible

    /opt/conda/lib/python3.6/site-packages/Keras-2.1.5-py3.6.egg/keras/engine/topology.py in set_weights(self, weights)
       1211             return
       1212         weight_value_tuples = []
    -> 1213         param_values = K.batch_get_value(params)
       1214         for pv, p, w in zip(param_values, params, weights):
       1215             if pv.shape != w.shape:

    /opt/conda/lib/python3.6/site-packages/Keras-2.1.5-py3.6.egg/keras/backend/tensorflow_backend.py in batch_get_value(ops)
       2325     """
       2326     if ops:
    -> 2327         return get_session().run(ops)
       2328     else:
       2329         return []

    /opt/conda/lib/python3.6/site-packages/Keras-2.1.5-py3.6.egg/keras/backend/tensorflow_backend.py in get_session()
        191                 # not already marked as initialized.
        192                 is_initialized = session.run(
    --> 193                     [tf.is_variable_initialized(v) for v in candidate_vars])
        194                 uninitialized_vars = []
        195                 for flag, v in zip(is_initialized, candidate_vars):

    /opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
        903     try:
        904       result = self._run(None, fetches, feed_dict, options_ptr,
    --> 905                          run_metadata_ptr)
        906       if run_metadata:
        907         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

    /opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
       1138     if final_fetches or final_targets or (handle and feed_dict_tensor):
       1139       results = self._do_run(handle, final_targets, final_fetches,
    -> 1140                              feed_dict_tensor, options, run_metadata)
       1141     else:
       1142       results = []

    /opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
       1319     if handle is None:
       1320       return self._do_call(_run_fn, feeds, fetches, targets, options,
    -> 1321                            run_metadata)
       1322     else:
       1323       return self._do_call(_prun_fn, handle, feeds, fetches)

    /opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
       1338         except KeyError:
       1339           pass
    -> 1340       raise type(e)(node_def, op, message)
       1341 
       1342   def _extend_graph(self):

    InvalidArgumentError: Node 'embedding_8/IsVariableInitialized': Unknown input node 'bidirectional_3/forward_cu_dnngru_1/kernel'


history = model.fit( X_train_words, X_train_target, valida

1 Answer:

Answer 0 (score: 0)

我之前有过它,是因为您没有配置Cudnn ..