Tensorflow:带有参数化可调用的tf.case,在for循环中定义的case列表

时间:2017-04-06 14:24:55

标签: python lambda machine-learning tensorflow

我正在尝试在训练循环中为自动编码器集合实现一个案例分支:根据某个条件,只应更新一个特定的自动编码器。我一直试图通过使用tf.case()来实现这一点,但它没有像我预期的那样工作......

def f(k_win):

    update_BW = tf.train.AdamOptimizer(learning_rate=learningrate).minimize(Cost_List[k_win])

    return update_MSE_winner(k_win) + [update_BW, update_n_List(k_win), update_n_alpha_List(k_win)] 

winner_index = tf.argmin(Cost_Alpha_List, 0)



Case_List = []

for k in range(N_Class): 

    Case = (tf.equal(winner_index,k), lambda: f(k))   

    Case_List.append(Case)


Execution_List = tf.case(Case_List, lambda: f(0))

winner_index:要更新的自动编码器索引

f(k_win):返回特定AE索引的所有更新callables

Case_List:包含布线和参数化函数对

Execution_List:在执行循环中为sess.run()调用。

for循环中的参数k应该定义Case_List,特别是'lambda:f(k)',但似乎在构建列表后,所有'lambda:f(k)'都设置为last k = N_Classes-1:效果是,只更新最后一个自动编码器,而不是那个带有'winner_index'的自动编码器。有谁有任何想法,这里发生了什么......?

感谢。

1 个答案:

答案 0 :(得分:1)

问题是你定义的lambda是使用全局变量k,在调用函数时,它具有循环中的最后一个值(N_Class - 1

一个更简单的例子:

lst = []
for k in range(10):
    lst.append(lambda: k * k)
print([lst_i() for lst_i in lst])

给出:

[81, 81, 81, 81, 81, 81, 81, 81, 81, 81]

而不是:

[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

This answer更好地解释了这个问题,并指出了一些方法来克服这个问题。在您的情况下,您可以执行以下操作:

def f(k_win):

    update_BW = tf.train.AdamOptimizer(learning_rate=learningrate).minimize(Cost_List[k_win])

    return update_MSE_winner(k_win) + [update_BW, update_n_List(k_win), update_n_alpha_List(k_win)] 

winner_index = tf.argmin(Cost_Alpha_List, 0)



Case_List = []

for k in range(N_Class): 

    Case = (tf.equal(winner_index,k), (lambda kk: lambda: f(kk))(k))   

    Case_List.append(Case)


Execution_List = tf.case(Case_List, lambda: f(0))