我正在编写损失函数,假设损失是针对双精度计算的,而不是针对张量计算的。这是我的功能:
def prediction_loss(a,b):
IGNORE=.0025
EPSILON=.00001
if IGNORE > abs(a) and IGNORE > abs(b) and np.sign(a)==np.sign(b):
return 0
scale=min(abs(a),abs(b))
distance=abs(a-b)
if abs(scale)<EPSILON:
scale=max(abs(a),abs(b))
if abs(scale)<EPSILON:
scale=1
distance**=2
return min(distance,distance/scale)
在model.compile中使用它时,出现以下错误:
OperatorNotAllowedInGraphError Traceback (most recent call last)
<ipython-input-44-92af3f50a682> in <module>()
9 keras.layers.Dense(1)
10 ])
---> 11 model.compile(loss=prediction_loss, optimizer=keras.optimizers.SGD(lr=0.001, momentum=0.9, nesterov=True))
~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
455 self._self_setattr_tracking = False # pylint: disable=protected-access
456 try:
--> 457 result = method(self, *args, **kwargs)
458 finally:
459 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in compile(self, optimizer, loss, metrics, loss_weights, sample_weight_mode, weighted_metrics, target_tensors, distribute, **kwargs)
371
372 # Creates the model loss and weighted metrics sub-graphs.
--> 373 self._compile_weights_loss_and_weighted_metrics()
374
375 # Functions for train, test and predict will
~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
455 self._self_setattr_tracking = False # pylint: disable=protected-access
456 try:
--> 457 result = method(self, *args, **kwargs)
458 finally:
459 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in _compile_weights_loss_and_weighted_metrics(self, sample_weights)
1651 # loss_weight_2 * output_2_loss_fn(...) +
1652 # layer losses.
-> 1653 self.total_loss = self._prepare_total_loss(masks)
1654
1655 def _prepare_skip_target_masks(self):
~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in _prepare_total_loss(self, masks)
1711
1712 if hasattr(loss_fn, 'reduction'):
-> 1713 per_sample_losses = loss_fn.call(y_true, y_pred)
1714 weighted_losses = losses_utils.compute_weighted_loss(
1715 per_sample_losses,
~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/losses.py in call(self, y_true, y_pred)
219 y_pred, y_true = tf_losses_util.squeeze_or_expand_dimensions(
220 y_pred, y_true)
--> 221 return self.fn(y_true, y_pred, **self._fn_kwargs)
222
223 def get_config(self):
<ipython-input-43-4630edd6290a> in prediction_loss(a, b)
14 EPSILON=.00001
15
---> 16 if IGNORE > abs(a) and IGNORE > abs(b) and np.sign(a)==np.sign(b):
17 return 0
18
~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py in __bool__(self)
763 `TypeError`.
764 """
--> 765 self._disallow_bool_casting()
766
767 def __nonzero__(self):
~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py in _disallow_bool_casting(self)
532 else:
533 # Default: V1-style Graph execution.
--> 534 self._disallow_in_graph_mode("using a `tf.Tensor` as a Python `bool`")
535
536 def _disallow_iteration(self):
~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py in _disallow_in_graph_mode(self, task)
521 raise errors.OperatorNotAllowedInGraphError(
522 "{} is not allowed in Graph execution. Use Eager execution or decorate"
--> 523 " this function with @tf.function.".format(task))
524
525 def _disallow_bool_casting(self):
OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
显然,张量流将tf.Tensor作为参数a
和b
传递,并且这些逻辑不能在逻辑运算中使用。为使功能正常工作,我应该更改什么?我想忽略具有相同符号的小尺寸的a
和b
答案 0 :(得分:1)
是的,tf.tensor无法使用python bool。将keras.backend.switch()用于条件语句。
请参考以下链接:
它列出了所有可用于拟合等式的函数,例如更大,更大_,等于等。
使用keras后端功能更改语句“如果IGNORE> abs(a)和IGNORE> abs(b)和np.sign(a)== np.sign(b):”,它将解决您的问题。