尝试使用自定义损失函数时遇到InvalidArgumentError

时间:2020-06-04 02:53:02

标签: python python-3.x tensorflow keras loss-function

我试图为GAN模型中的鉴别器实现自定义损失函数(三重损失),每当我尝试训练用下面编写的损失函数编译的任何模型时,都会出现此错误:< / p>

InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a double tensor but is a float tensor [Op:AddV2] name: loss/add/

我一直在想着如何解决此问题,但没有成功。如果我使用预先构建的Keras损失函数(例如,binary_crossentropy)来编译鉴别器,则该模型能够训练,因此我坚信损失函数是造成此问题的原因。如果我在其他任何模型上尝试使用此方法,也会发生相同的错误,因此我认为该错误并非源于我的鉴别器的某些方面。

def dTF(a,b):
    a = K.eval(a)
    b = K.eval(b)
    dist = np.linalg.norm(a-b)
    return dist

def tripletLossTF(p,n):
    margin = 0.5
    a = p*0 # anchor point at origin
    loss = max(dTF(a,p) - dTF(a,n) + margin, 0.0)
    return loss
siamese_net.compile(loss=tripletLossTF,optimizer=Adam(0.00006), run_eagerly=True)

这是完整的错误:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-92-3b448b661c74> in <module>
----> 1 ST.train(20)

<ipython-input-85-57f75616611f> in train(self, iterations)
    166 #             Y_disc = np.array([0])
    167             # print(X_disc.shape, Y_disc.shape)
--> 168             d_loss = self.discriminator.train_on_batch(X_disc, Y_disc)
    169 
    170             # train generator

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight, reset_metrics)
   1076           self, x, y=y, sample_weight=sample_weight,
   1077           class_weight=class_weight, reset_metrics=reset_metrics,
-> 1078           standalone=True)
   1079       outputs = (outputs['total_loss'] + outputs['output_losses'] +
   1080                  outputs['metrics'])

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in train_on_batch(model, x, y, sample_weight, class_weight, reset_metrics, standalone)
    431       y,
    432       sample_weights=sample_weights,
--> 433       output_loss_metrics=model._output_loss_metrics)
    434 
    435   if reset_metrics:

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_eager.py in train_on_batch(model, inputs, targets, sample_weights, output_loss_metrics)
    310           sample_weights=sample_weights,
    311           training=True,
--> 312           output_loss_metrics=output_loss_metrics))
    313   if not isinstance(outs, list):
    314     outs = [outs]

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_eager.py in _process_single_batch(model, inputs, targets, output_loss_metrics, sample_weights, training)
    251               output_loss_metrics=output_loss_metrics,
    252               sample_weights=sample_weights,
--> 253               training=training))
    254       if total_loss is None:
    255         raise ValueError('The model cannot be run '

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_eager.py in _model_loss(model, inputs, targets, output_loss_metrics, sample_weights, training)
    207     if custom_losses:
    208       total_loss += losses_utils.scale_loss_for_distribution(
--> 209           math_ops.add_n(custom_losses))
    210 
    211   return outs, total_loss, output_losses, masks

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/ops/math_ops.py in binary_op_wrapper(x, y)
    900     with ops.name_scope(None, op_name, [x, y]) as name:
    901       if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
--> 902         return func(x, y, name=name)
    903       elif not isinstance(y, sparse_tensor.SparseTensor):
    904         try:

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/ops/math_ops.py in _add_dispatch(x, y, name)
   1192     return gen_math_ops.add(x, y, name=name)
   1193   else:
-> 1194     return gen_math_ops.add_v2(x, y, name=name)
   1195 
   1196 

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/ops/gen_math_ops.py in add_v2(x, y, name)
    478         pass  # Add nodes to the TensorFlow graph.
    479     except _core._NotOkStatusException as e:
--> 480       _ops.raise_from_not_ok_status(e, name)
    481   # Add nodes to the TensorFlow graph.
    482   _, _, _op, _outputs = _op_def_library._apply_op_helper(

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py in raise_from_not_ok_status(e, name)
   6604   message = e.message + (" name: " + name if name is not None else "")
   6605   # pylint: disable=protected-access
-> 6606   six.raise_from(core._status_to_exception(e.code, message), None)
   6607   # pylint: enable=protected-access
   6608 

/opt/conda/lib/python3.7/site-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a double tensor but is a float tensor [Op:AddV2] name: loss/add/

0 个答案:

没有答案
相关问题