tf.py_func()使用lambda函数在循环中意外输出

时间:2018-03-29 12:00:33

标签: python tensorflow

我正在观察Numpy和纯Tensorflow的不同行为,当实现相同的简单功能,在tf.py_func的for循环中的迭代中共享一个变量。

让我们从纯Numpy版本开始:

def my_func(x, k):
    return np.tile(x,k)

x = np.ones((1), np.int64)
for i in range(1,3):
    x = my_func(x, i)

print(x)

这会产生预期的输出。最初x[1]。在第一次迭代中,它被复制一次以生成[1]。然后在下一次迭代中,结果被复制两次,产生最终输出[1 1]

类似的方法也会在纯Tensorflow中产生相同的预期输出:

x = tf.constant([1], tf.int64)
for i in range(1,3):
    x = tf.tile(x, [i])

with tf.Session() as sess:
    xx = sess.run(x)
    print(xx)

输出为[1 1]

现在我试图使用tf.py_func基本上做同样的事情,我无法理解为什么我看到不同的输出。这段代码:

import tensorflow as tf
import numpy as np

def my_func(x, k):
    return np.tile(x,k)

x = tf.constant([1], tf.int64)
for i in range(1,3):
    x = tf.py_func(lambda y: my_func(y, i), [x], tf.int64)

with tf.Session() as sess:
    xx = sess.run(x)
    print(xx)

产生意外结果[1 1 1 1]

为什么会这样? py_func是否有一些属性,它对共享(张量)变量名称不起作用,在这种情况下是在每次循环迭代时更新的变量x

请注意,这是一个简化的示例,可以重现问题,并且其功能很容易在纯Tensorflow中重现。在我的实际应用程序中,需要使用tf.py_func,因为功能更复杂。

1 个答案:

答案 0 :(得分:2)

没有lambda函数,它按预期工作:

import tensorflow as tf
import numpy as np

def my_func(x, k):
    return np.tile(x,k)

x = tf.constant([1], tf.int64)
for i in range(1,3):
    x = tf.py_func(my_func, [x, i], tf.int64)

with tf.Session() as sess:
    xx = sess.run(x)
    print(xx)

返回[1 1]

修改

我找到了原因:lambda y: my_func(y, i)通过引用而非值来保存我。因此,for循环的最后一个i值将应用于循环中的所有py_func。这是一个更简单的示例,显示了问题:

import tensorflow as tf

def my_func(x, y):
  return x - y

x1 = tf.constant([0], tf.float32)
for i in range(2):
    x1 = tf.py_func(lambda y: my_func(y, i), [x1], tf.float32)

x2 = tf.constant([0], tf.float32)
x2 = tf.py_func(lambda y: my_func(y, 0), [x2], tf.float32)
x2 = tf.py_func(lambda y: my_func(y, 1), [x2], tf.float32)

with tf.Session() as sess:
    print(sess.run(x1))
    print(sess.run(x2))