我正在尝试使用tf.cond()
根据条件创建2个不同的图。在两个图上,我们都希望有权重正则化损失,因此我们使用tf.losses.get_regularization_loss()
。这是我们项目的伪代码
def net_1(x,y):
statement 1 (has trainable params)
statement 2 (has trainable params)
return
def net_2(x,y):
statement 1 (has trainable params)
statement 2 (has trainable params)
statement 3 (has trainable params)
return
step = tf.get_or_create_global_step()
tf.cond(tf.greater(step, 100), net_1, net_2)
loss = 0.0
loss += tf.losses.get_regularization_loss()
如果我们保留tf.losses.get_regularization_loss()
,则会收到错误消息:
Retval [0]没有值
否则,没有错误。
如果我们要强行使用tf.cond()
,是否需要特别注意tf.losses.get_regularization_loss()
。
答案 0 :(得分:0)
您的伪代码还不太清楚,但是,tf.cond
期望张量作为参数,根据您编写的内容,您可以为其提供功能。如果您的函数net_1
和net_2
返回张量(它们显然应该这样做),请在tf.cond
调用中使用输出张量,例如如下:
tf.cond(tf.greater(step, 100), net_1(x, y), net_2(x, y))
答案 1 :(得分:0)
同一问题,我通过用两个相似的函数替换tf.cond(与正则化有关)来解决它……现在找不到更好的解决方案。