我试图为GAN模型中的鉴别器实现自定义损失函数(三重损失),每当我尝试训练用下面编写的损失函数编译的任何模型时,都会出现此错误:< / p>
InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a double tensor but is a float tensor [Op:AddV2] name: loss/add/
我一直在想着如何解决此问题,但没有成功。如果我使用预先构建的Keras损失函数(例如,binary_crossentropy)来编译鉴别器,则该模型能够训练,因此我坚信损失函数是造成此问题的原因。如果我在其他任何模型上尝试使用此方法,也会发生相同的错误,因此我认为该错误并非源于我的鉴别器的某些方面。
def dTF(a,b):
a = K.eval(a)
b = K.eval(b)
dist = np.linalg.norm(a-b)
return dist
def tripletLossTF(p,n):
margin = 0.5
a = p*0 # anchor point at origin
loss = max(dTF(a,p) - dTF(a,n) + margin, 0.0)
return loss
siamese_net.compile(loss=tripletLossTF,optimizer=Adam(0.00006), run_eagerly=True)
这是完整的错误:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-92-3b448b661c74> in <module>
----> 1 ST.train(20)
<ipython-input-85-57f75616611f> in train(self, iterations)
166 # Y_disc = np.array([0])
167 # print(X_disc.shape, Y_disc.shape)
--> 168 d_loss = self.discriminator.train_on_batch(X_disc, Y_disc)
169
170 # train generator
/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight, reset_metrics)
1076 self, x, y=y, sample_weight=sample_weight,
1077 class_weight=class_weight, reset_metrics=reset_metrics,
-> 1078 standalone=True)
1079 outputs = (outputs['total_loss'] + outputs['output_losses'] +
1080 outputs['metrics'])
/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in train_on_batch(model, x, y, sample_weight, class_weight, reset_metrics, standalone)
431 y,
432 sample_weights=sample_weights,
--> 433 output_loss_metrics=model._output_loss_metrics)
434
435 if reset_metrics:
/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_eager.py in train_on_batch(model, inputs, targets, sample_weights, output_loss_metrics)
310 sample_weights=sample_weights,
311 training=True,
--> 312 output_loss_metrics=output_loss_metrics))
313 if not isinstance(outs, list):
314 outs = [outs]
/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_eager.py in _process_single_batch(model, inputs, targets, output_loss_metrics, sample_weights, training)
251 output_loss_metrics=output_loss_metrics,
252 sample_weights=sample_weights,
--> 253 training=training))
254 if total_loss is None:
255 raise ValueError('The model cannot be run '
/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_eager.py in _model_loss(model, inputs, targets, output_loss_metrics, sample_weights, training)
207 if custom_losses:
208 total_loss += losses_utils.scale_loss_for_distribution(
--> 209 math_ops.add_n(custom_losses))
210
211 return outs, total_loss, output_losses, masks
/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/ops/math_ops.py in binary_op_wrapper(x, y)
900 with ops.name_scope(None, op_name, [x, y]) as name:
901 if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
--> 902 return func(x, y, name=name)
903 elif not isinstance(y, sparse_tensor.SparseTensor):
904 try:
/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/ops/math_ops.py in _add_dispatch(x, y, name)
1192 return gen_math_ops.add(x, y, name=name)
1193 else:
-> 1194 return gen_math_ops.add_v2(x, y, name=name)
1195
1196
/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/ops/gen_math_ops.py in add_v2(x, y, name)
478 pass # Add nodes to the TensorFlow graph.
479 except _core._NotOkStatusException as e:
--> 480 _ops.raise_from_not_ok_status(e, name)
481 # Add nodes to the TensorFlow graph.
482 _, _, _op, _outputs = _op_def_library._apply_op_helper(
/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py in raise_from_not_ok_status(e, name)
6604 message = e.message + (" name: " + name if name is not None else "")
6605 # pylint: disable=protected-access
-> 6606 six.raise_from(core._status_to_exception(e.code, message), None)
6607 # pylint: enable=protected-access
6608
/opt/conda/lib/python3.7/site-packages/six.py in raise_from(value, from_value)
InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a double tensor but is a float tensor [Op:AddV2] name: loss/add/