Retval [0]没有值:tf.cond(condition,net1,net2)

时间:2019-06-25 21:47:39

标签: tensorflow tensorflow-estimator

我正在尝试使用tf.cond()根据条件创建2个不同的图。在两个图上,我们都希望有权重正则化损失,因此我们使用tf.losses.get_regularization_loss()。这是我们项目的伪代码

 def net_1(x,y):
  statement 1 (has trainable params)
  statement 2 (has trainable params)
  return

def net_2(x,y):
 statement 1 (has trainable params)
 statement 2 (has trainable params)
 statement 3 (has trainable params)
 return
step = tf.get_or_create_global_step()


tf.cond(tf.greater(step, 100), net_1, net_2)
loss = 0.0
loss += tf.losses.get_regularization_loss()

如果我们保留tf.losses.get_regularization_loss(),则会收到错误消息:

  

Retval [0]没有值

否则,没有错误。

如果我们要强行使用tf.cond(),是否需要特别注意tf.losses.get_regularization_loss()

2 个答案:

答案 0 :(得分:0)

您的伪代码还不太清楚,但是,tf.cond期望张量作为参数,根据您编写的内容,您可以为其提供功能。如果您的函数net_1net_2返回张量(它们显然应该这样做),请在tf.cond调用中使用输出张量,例如如下:

tf.cond(tf.greater(step, 100), net_1(x, y), net_2(x, y))

答案 1 :(得分:0)

同一问题,我通过用两个相似的函数替换tf.cond(与正则化有关)来解决它……现在找不到更好的解决方案。