使用lambda时tf.case获得意外结果

时间:2018-10-25 07:39:39

标签: python tensorflow

看这个例子。

import tensorflow as tf

tf.reset_default_graph()
LENGTH = 25
M_list = []
for i in range(LENGTH):
    M_list.append(tf.get_variable('M'+str(i), shape=[1], initializer=tf.constant_initializer(i)))

choose_mat = tf.placeholder(tf.int32, shape=[LENGTH])
case_set = [(tf.equal(choose_mat[i], 1), lambda: M_list[i]) for i in range(LENGTH)]
M = tf.case(case_set)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    CM1 = [0] * LENGTH
    CM1[0] = 1
    CM2 = [0] * LENGTH
    CM2[1] = 1

    m1 = sess.run(M, feed_dict={choose_mat: CM1})
    m2 = sess.run(M, feed_dict={choose_mat: CM2})
    print(m1) # [24.]
    print(m2) # [24.]

    m1_ = sess.run(M_list[0])
    m2_ = sess.run(M_list[1])
    print(m1_) # [0.]
    print(m2_) # [1.]

我们期望m1,m2是0,1 但是我们有24。 而且M_list的结果是正确的,就像m1_和m2_一样,很奇怪。

尽管我已经修复了该错误(请参见我的答案),但是我仍然有一个问题,我不知道为什么这段代码会导致关闭,case_set没有任何功能,有人知道为什么这是关闭吗?

1 个答案:

答案 0 :(得分:0)

实际上,此错误不是由tensorflow引起的,真正的原因是python的关闭。 see this link 这样这段代码将获得预期的结果。

import tensorflow as tf

tf.reset_default_graph()
LENGTH = 25
M_list = []
for i in range(LENGTH):
    M_list.append(tf.get_variable('M'+str(i), shape=[1], initializer=tf.constant_initializer(i)))

choose_mat = tf.placeholder(tf.int32, shape=[LENGTH])
case_set = [(tf.equal(choose_mat[i], 1), lambda i=i: M_list[i]) for i in range(LENGTH)]
M = tf.case(case_set)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    CM1 = [0] * LENGTH
    CM1[0] = 1
    CM2 = [0] * LENGTH
    CM2[1] = 1

    m1 = sess.run(M, feed_dict={choose_mat: CM1})
    m2 = sess.run(M, feed_dict={choose_mat: CM2})
    print(m1) # [0.]
    print(m2) # [1.]

尽管我已经修复了该错误,但我仍然不知道为什么这段代码会导致关闭,case_set没有任何功能,有人知道为什么这是关闭吗?