Passing additional inputs to a Keras deep network to compute a custom cost function

Date: 2018-12-25 10:41:15

Tags: python tensorflow keras

Our input images (input_x) are 1000x512x512x1 and the weight maps (input_w) are 1000x512x512x1. In fact, each image has its own weight map, generated before the network runs, so we have to pass it as a second input. Both are fed into the network, although the weight maps are only used to multiply into the loss function and are not real tensors (they do not come from any layer and remain model inputs until they reach the loss function). First, the model has two inputs and only one output:

 model = keras.models.Model(inputs=[input_x, input_w], outputs=final_output)

and the input shapes at the start of the network were changed accordingly:

input_x = layers.Input(shape=(512,512,1))
input_w = layers.Input(shape=(512,512,1))

input_x passes through the network layers, but input_w is only used in customLoss:

model.compile(optimizer=optimizer, loss=customLoss(input_w), metrics=[dice_coef, mean_iou])

customLoss is written as a wrapper because of the additional input_w argument:

def customLoss(input_w):
    def loss_fcn(y_true, y_pred):
        bce = keras.losses.binary_crossentropy(y_true, y_pred)
        dice_term = K.exp(1 + dice_coef(y_true, y_pred, 1.0))
        return input_w * (bce - dice_term)
    return loss_fcn
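
A note on the shapes inside this loss: keras.losses.binary_crossentropy averages over the last axis, so bce has shape (batch, 512, 512) while input_w keeps its channel dimension, (batch, 512, 512, 1). Below is a minimal sketch of the same wrapper with the shapes spelled out; the K.squeeze call is only an illustration of one way to make the two tensors broadcastable and is not part of the original code:

import keras
from keras import backend as K

def customLoss_sketch(input_w):  # illustrative variant of the wrapper above
    def loss_fcn(y_true, y_pred):
        # binary_crossentropy averages over the last axis:
        # y_pred (batch, 512, 512, 1) -> bce (batch, 512, 512)
        bce = keras.losses.binary_crossentropy(y_true, y_pred)
        # dice_coef is the project's own helper (defined elsewhere) and returns a scalar
        dice_term = K.exp(1 + dice_coef(y_true, y_pred, 1.0))
        # drop the channel axis of input_w so both factors are (batch, 512, 512)
        w = K.squeeze(input_w, axis=-1)
        return w * (bce - dice_term)
    return loss_fcn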

After generating X and W from the dataset, we call fit with the two inputs, where X is input_x (the images) and W is input_w (the weight maps):

history= my_model.fit([X,W],y,validation_split=0.1, epochs=5000,batch_size=8, callbacks=[best_check])
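
For reference, the array shapes this call expects would look roughly as follows; the dummy arrays are purely illustrative (the real X, W and y come from the dataset, and the shape of y is assumed to match the mask output):

import numpy as np

# Illustrative dummy arrays only, matching the shapes described above.
X = np.zeros((1000, 512, 512, 1), dtype=np.float32)  # images      -> input_x
W = np.ones((1000, 512, 512, 1), dtype=np.float32)   # weight maps -> input_w
y = np.zeros((1000, 512, 512, 1), dtype=np.float32)  # assumed ground-truth masks
# fit feeds the list elements to the model inputs in the same order as
# keras.models.Model(inputs=[input_x, input_w], ...).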

Everything looks correct to me, but I get the following error:

Epoch 1/5000
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-11-b756059772c5> in <module>()
      6                               patience=6,
      7                               verbose=1, mode='auto')
----> 8 history= my_model.fit([X,W],y,validation_split=0.1, epochs=5000,batch_size=8, callbacks=[best_check])

/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
   1037                                         initial_epoch=initial_epoch,
   1038                                         steps_per_epoch=steps_per_epoch,
-> 1039                                         validation_steps=validation_steps)
   1040 
   1041     def evaluate(self, x=None, y=None,

/usr/local/lib/python3.6/dist-packages/keras/engine/training_arrays.py in fit_loop(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)
    197                     ins_batch[i] = ins_batch[i].toarray()
    198 
--> 199                 outs = f(ins_batch)
    200                 outs = to_list(outs)
    201                 for l, o in zip(out_labels, outs):

/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2713                 return self._legacy_call(inputs)
   2714 
-> 2715             return self._call(inputs)
   2716         else:
   2717             if py_any(is_tensor(x) for x in inputs):

/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
   2673             fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
   2674         else:
-> 2675             fetched = self._callable_fn(*array_vals)
   2676         return fetched[:len(self.outputs)]
   2677 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
   1437           ret = tf_session.TF_SessionRunCallable(
   1438               self._session._session, self._handle, args, status,
-> 1439               run_metadata_ptr)
   1440         if run_metadata:
   1441           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
    526             None, None,
    527             compat.as_text(c_api.TF_Message(self.status.status)),
--> 528             c_api.TF_GetCode(self.status.status))
    529     # Delete the underlying status object from memory otherwise it stays alive
    530     # as there is a reference to status from this from the traceback due to

InvalidArgumentError: Incompatible shapes: [8,512,512,1] vs. [8,512,512]
     [[{{node loss_1/mask_output_loss/mul_2}} = Mul[T=DT_FLOAT, _class=["loc:@training_1/Adam/gradients/loss_1/mask_output_loss/mul_2_grad/Reshape_1"], _device="/job:localhost/replica:0/task:0/device:GPU:0"](_arg_input_4_0_1/_3633, loss_1/mask_output_loss/sub_2)]]
     [[{{node metrics_1/mean_iou/mean_iou_2/confusion_matrix/assert_non_negative_1/assert_less_equal/Assert/AssertGuard/Assert/Switch/_4057}} = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_6795_...ert/Switch", tensor_type=DT_BOOL, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Without passing the weight maps everything works fine, but now this error appears.

1 answer:

Answer 0 (score: 0):

All of this code was actually written correctly; however, since the project was developed in Google Colab, some strange errors occurred, and the problem was only resolved after reconnecting the page several times! The error was probably caused by the disconnection!