TensorFlow:tf.case()& f_default应该什么都不做

时间:2017-01-28 13:01:19

标签: tensorflow

我正在使用4个特定于类的自动编码器(3层前馈)的网络,并且在训练迭代中,有一个案例检查来决定哪个自动编码器必须更新:

def f(k): return tf.train.AdamOptimizer(learning_rate=lernrate).minimize(Cost_List[k]), n_List[k].assign_add(1.0), Cost_List[k]

def g(): ???

nothing = g()



min_index = tf.argmin(Cost_List, 0) 

Case_0 = (tf.equal(min_index,0), lambda: f(0))
Case_1 = (tf.equal(min_index,1), lambda: f(1))
Case_2 = (tf.equal(min_index,2), lambda: f(2))
Case_3 = (tf.equal(min_index,3), lambda: f(3))


Case_List = [Case_0, Case_1, Case_2, Case_3]

[optimizer, update, cost] = tf.case(Case_List, nothing)

在这种情况下,没有条件得到满足,不应该做任何事情。在这种情况下,四个案例中的一个将被实现,所以这里没有实际问题,但我想修改代码,然后重要的是,在默认情况下将跳过训练样本。问题是,f_default和所有其他返回类型的返回类型必须相同,因为sess.run([optimizer,update,cost])期望某种类型。我怎么能这样做,在默认情况下真的什么都没发生?我已经尝试过使用tf.no_op(),但这不起作用......

谢谢,

Meridius

1 个答案:

答案 0 :(得分:1)

要使签名匹配,您可以按如下方式定义g()

def g():
  return tf.no_op(), tf.no_op(), tf.constant(0.0)

请注意,将g直接传递为f_default(而不是像当前代码那样传递g())会稍微高效,但行为应该相同。 / p>