我正在使用4个特定于类的自动编码器(3层前馈)的网络,并且在训练迭代中,有一个案例检查来决定哪个自动编码器必须更新:
def f(k): return tf.train.AdamOptimizer(learning_rate=lernrate).minimize(Cost_List[k]), n_List[k].assign_add(1.0), Cost_List[k]
def g(): ???
nothing = g()
min_index = tf.argmin(Cost_List, 0)
Case_0 = (tf.equal(min_index,0), lambda: f(0))
Case_1 = (tf.equal(min_index,1), lambda: f(1))
Case_2 = (tf.equal(min_index,2), lambda: f(2))
Case_3 = (tf.equal(min_index,3), lambda: f(3))
Case_List = [Case_0, Case_1, Case_2, Case_3]
[optimizer, update, cost] = tf.case(Case_List, nothing)
在这种情况下,没有条件得到满足,不应该做任何事情。在这种情况下,四个案例中的一个将被实现,所以这里没有实际问题,但我想修改代码,然后重要的是,在默认情况下将跳过训练样本。问题是,f_default和所有其他返回类型的返回类型必须相同,因为sess.run([optimizer,update,cost])期望某种类型。我怎么能这样做,在默认情况下真的什么都没发生?我已经尝试过使用tf.no_op(),但这不起作用......
谢谢,
Meridius
答案 0 :(得分:1)
要使签名匹配,您可以按如下方式定义g()
:
def g():
return tf.no_op(), tf.no_op(), tf.constant(0.0)
请注意,将g
直接传递为f_default
(而不是像当前代码那样传递g()
)会稍微高效,但行为应该相同。 / p>