How do I get the intermediate output of a Keras model with multiple inputs?

Time: 2019-05-07 11:42:58

Tags: python keras neural-network deep-learning

I have a model defined by the code below. cat_inps is a 4x3495 array (my data has 3495 rows):

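import numpy as np
import keras
from keras import backend as K
from keras.layers import Input, Embedding, BatchNormalization, Dense, Dropout, Flatten
from keras.models import Model

model_cat_inps = [Input(shape=(1,)) for _ in cat_inps]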
model_cont_inp = Input(shape=(1, 37), name='cont_inp')

embeddings = [Embedding(input_dim=len(np.unique(x)),
                        output_dim=round(1.6 * len(np.unique(x)) ** 0.56)
                       )(y) for x, y in zip(cat_inps, model_cat_inps)]
bn1 = BatchNormalization(name='first_bn')(model_cont_inp)

concat = keras.layers.concatenate([*embeddings, bn1], name='concatenate')

relu = Dense(5, activation='relu', name='dense1')(concat)
bn = BatchNormalization(name='bn1')(relu)
drop = Dropout(0.2, name='dropout1')(bn)

relu = Dense(5, activation='relu', name='dense2')(drop)
bn = BatchNormalization()(relu)
drop = Dropout(0.2)(bn)

flat = Flatten()(drop)
out = Dense(3, activation='softmax', name='dense3')(flat)

model = Model(inputs=[*model_cat_inps, model_cont_inp], outputs=out) 
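To make the layer widths in the model above easier to follow, here is a small worked example of the embedding-size rule `round(1.6 * n ** 0.56)` used in the code. The cardinalities are hypothetical stand-ins (the question does not give the real ones), and `embedding_dim` is a helper name introduced only for this illustration.

```python
def embedding_dim(n_categories):
    """Embedding width rule from the question: round(1.6 * n ** 0.56)."""
    return round(1.6 * n_categories ** 0.56)


# Hypothetical number of distinct values in each of the four categorical columns.
cardinalities = [2, 3, 10, 50]
dims = [embedding_dim(n) for n in cardinalities]

# Each embedding output has shape (batch, 1, dim) and the continuous branch
# has shape (batch, 1, 37), so concatenating on the last axis adds the widths.
n_cont = 37
concat_width = sum(dims) + n_cont

print("embedding widths:", dims)
print("width after concatenate:", concat_width)
```

With these widths, the concatenate layer would produce a tensor of shape (batch, 1, concat_width). The two Dense(5) layers then reduce the last axis to 5, so Flatten should yield (batch, 5).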

I need to get the output of the second-to-last layer (Flatten). To do that, I defined a function as follows:

func = K.function(model.inputs, [model.layers[-2].output])  

To get the intermediate output, I call the function:

Kz = func([*cat_inps, np.array(train_df[cont_vars]).reshape(3495, 1, 37)])  
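As far as I can tell, a function built with K.function feeds the arrays to the input placeholders as they are, whereas model.fit expands 1-D arrays to shape (n, 1) first. It may therefore help to pass arrays whose shapes match the Input declarations exactly: (n, 1) for each categorical input and (n, 1, 37) for the continuous one. The sketch below shows that reshaping with NumPy only. `as_function_inputs` is a hypothetical helper name, and the demo data is random stand-in data, not the data from the question.

```python
import numpy as np


def as_function_inputs(cat_inps, cont_values):
    """Reshape raw arrays to the shapes declared by the model's Input layers.

    cat_inps:    array-like of shape (n_cat_columns, n_rows)
    cont_values: array-like of shape (n_rows, n_cont_features)
    """
    # One (n_rows, 1) array per categorical column, matching Input(shape=(1,)).
    cat_cols = [np.asarray(col).reshape(-1, 1) for col in cat_inps]
    # (n_rows, 1, n_cont_features), matching Input(shape=(1, 37)) in the question.
    cont = np.asarray(cont_values, dtype="float32")
    cont = cont.reshape(cont.shape[0], 1, -1)
    return [*cat_cols, cont]


# Stand-in data: 4 categorical columns, 10 rows, 37 continuous features.
demo_cat = np.random.randint(0, 5, size=(4, 10))
demo_cont = np.random.rand(10, 37)
demo_inputs = as_function_inputs(demo_cat, demo_cont)
print([a.shape for a in demo_inputs])
```

Under these assumptions, the call from the question would become `Kz = func(as_function_inputs(cat_inps, train_df[cont_vars]))`. This is untested against the original model.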

This call is part of the function passed to a LearningRateScheduler object, so when I call model.fit I get the following error:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-161-dcb47104ce2d> in <module>
----> 1 model.fit(x=[*cat_inps, np.array(train_df[cont_vars]).reshape(3495, 1, 37)], y=y, epochs=1, callbacks=[lr_scheduler])

/usr/local/lib/python3.5/dist-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
   1037                                         initial_epoch=initial_epoch,
   1038                                         steps_per_epoch=steps_per_epoch,
-> 1039                                         validation_steps=validation_steps)
   1040 
   1041     def evaluate(self, x=None, y=None,

/usr/local/lib/python3.5/dist-packages/keras/engine/training_arrays.py in fit_loop(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)
    144         for m in model.stateful_metric_functions:
    145             m.reset_states()
--> 146         callbacks.on_epoch_begin(epoch)
    147         epoch_logs = {}
    148         if steps_per_epoch is not None:

/usr/local/lib/python3.5/dist-packages/keras/callbacks.py in on_epoch_begin(self, epoch, logs)
     63         logs = logs or {}
     64         for callback in self.callbacks:
---> 65             callback.on_epoch_begin(epoch, logs)
     66         self._delta_t_batch = 0.
     67         self._delta_ts_batch_begin = deque([], maxlen=self.queue_length)

/usr/local/lib/python3.5/dist-packages/keras/callbacks.py in on_epoch_begin(self, epoch, logs)
    651         lr = float(K.get_value(self.model.optimizer.lr))
    652         try:  # new API
--> 653             lr = self.schedule(epoch, lr)
    654         except TypeError:  # old API for backward compatibility
    655             lr = self.schedule(epoch)

<ipython-input-157-de7e7cdbe2a1> in lr_schedule(epoch, _)
      6         lr (float32): learning rate
      7     """
----> 8     Kz = func([*cat_inps, np.array(train_df[cont_vars]).reshape(3495, 1, 37)])
      9     print(Kz)
     10 

/usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2713                 return self._legacy_call(inputs)
   2714 
-> 2715             return self._call(inputs)
   2716         else:
   2717             if py_any(is_tensor(x) for x in inputs):

/usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
   2673             fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
   2674         else:
-> 2675             fetched = self._callable_fn(*array_vals)
   2676         return fetched[:len(self.outputs)]
   2677 

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
   1437           ret = tf_session.TF_SessionRunCallable(
   1438               self._session._session, self._handle, args, status,
-> 1439               run_metadata_ptr)
   1440         if run_metadata:
   1441           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
    526             None, None,
    527             compat.as_text(c_api.TF_Message(self.status.status)),
--> 528             c_api.TF_GetCode(self.status.status))
    529     # Delete the underlying status object from memory otherwise it stays alive
    530     # as there is a reference to status from this from the traceback due to

InvalidArgumentError: ConcatOp : Expected concatenating dimensions in the range [-2, 2), but got 2
     [[{{node concatenate_1/concat}} = ConcatV2[N=5, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](embedding_5/embedding_lookup, embedding_6/embedding_lookup, embedding_7/embedding_lookup, embedding_8/embedding_lookup, first_bn_1/cond/Merge, concatenate_1/concat/axis)]]

How can I fix this?
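One way to read the error message: the concatenate op was asked to join along axis 2, but the valid axes were -2 to 1. That suggests the tensors it received at run time had rank 2, while the layer was built for rank-3 tensors of shape (batch, 1, features). The NumPy analogue below is only an illustration of that rank mismatch. It is not the TensorFlow op itself, and the widths are made up.

```python
import numpy as np

n = 6  # stand-in batch size

# Rank-3 inputs, as the concatenate layer was built for: (batch, 1, features).
emb_a = np.zeros((n, 1, 3))
emb_b = np.zeros((n, 1, 2))
cont = np.zeros((n, 1, 37))
merged = np.concatenate([emb_a, emb_b, cont], axis=2)
print("rank-3 concat shape:", merged.shape)

# Rank-2 inputs, as an embedding lookup would produce from ids of shape (batch,).
rank2 = [np.zeros((n, 3)), np.zeros((n, 2))]
try:
    np.concatenate(rank2, axis=2)
    rank2_failed = False
except ValueError as exc:  # NumPy's AxisError is a subclass of ValueError
    rank2_failed = True
    print("rank-2 concat failed:", exc)
```

If this reading is right, the arrays passed to func in lr_schedule would be the place to check first, for example by printing the shape of each one before the call.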

0 answers:

No answers yet.