tf.while_loop并行运行时给出错误的结果

时间:2018-10-02 15:27:37

标签: python tensorflow

我想在张量流的tf.variable内逐行更新二维tf.while_loop。因此,我使用tf.assign方法。问题是我的实现和parallel_iterations>1的结果是错误的。使用parallel_iterations=1时结果正确。代码是这样的:

a = tf.Variable(tf.zeros([100, 100]), dtype=tf.int64)

i = tf.constant(0)
def condition(i, var):
    return tf.less(i, 100)

def body(i, var):
    updated_row = method() # This method returns a [1, 100] tensor which is the updated row for the variable
    temp = tf.assign(a[i], updated_row)
    return [tf.add(i, 1), temp]

z = tf.while_loop(condition, body, [i, a], back_prop=False, parallel_iterations=10)

迭代是完全独立的,我不知道出了什么问题。

奇怪的是,如果我这样更改代码:

a = tf.Variable(tf.zeros([100, 100]), dtype=tf.int64)

i = tf.constant(0)
def condition(i, var):
    return tf.less(i, 100)

def body(i, var):
    zeros = lambda: tf.zeros([100, 100], dtype=tf.int64)
    temp = tf.Variable(initial_value=zeros, dtype=tf.int64)
    updated_row = method() # This method returns a [1, 100] tensor which is the updated row for the variable
    temp = tf.assign(temp[i], updated_row)
    return [tf.add(i, 1), temp]

z = tf.while_loop(condition, body, [i, a], back_prop=False, parallel_iterations=10)

代码为parallel_iterations>1提供了正确的结果。有人可以解释一下这是怎么回事,并给我一个有效的解决方案来更新变量,因为我要更新的原始变量很大,而我发现的解决方案效率很低。

2 个答案:

答案 0 :(得分:0)

您不需要为此使用变量,只需在循环主体上生成行更新的张量即可:

PackageInfo info = application.getPackageManager().getPackageInfo(application.getPackageName(), 0);

输出:

java.lang.NullPointerException
    at digitu.com.osmos.ApplicationTest.testCorrectVersion(ApplicationTest.java:27)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at junit.framework.TestCase.runTest(TestCase.java:176)
    at junit.framework.TestCase.runBare(TestCase.java:141)
    at junit.framework.TestResult$1.protect(TestResult.java:122)
    at junit.framework.TestResult.runProtected(TestResult.java:142)
    at junit.framework.TestResult.run(TestResult.java:125)
    at junit.framework.TestCase.run(TestCase.java:129)
    at junit.framework.TestSuite.runTest(TestSuite.java:252)
    at junit.framework.TestSuite.run(TestSuite.java:247)
    at org.junit.internal.runners.JUnit38ClassRunner.run(JUnit38ClassRunner.java:86)
    at org.junit.runner.JUnitCore.run(JUnitCore.java:137)
    at com.intellij.junit4.JUnit4IdeaTestRunner.startRunnerWithArgs(JUnit4IdeaTestRunner.java:68)
    at com.intellij.rt.execution.junit.IdeaTestRunner$Repeater.startRunnerWithArgs(IdeaTestRunner.java:47)
    at com.intellij.rt.execution.junit.JUnitStarter.prepareStreamsAndStart(JUnitStarter.java:242)
    at com.intellij.rt.execution.junit.JUnitStarter.main(JUnitStarter.java:70)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)

答案 1 :(得分:0)

在tf.function中,我发现了以下内容:

要点:跟踪func时,任何Python副作用(追加到列表,使用打印进行打印等)都只会发生一次。要在您的tf.function中执行副作用,需要将其写为TF ops:

我很确定这就是发生的事情。您期望 a 发生变化,但这是一个“副作用”(https://runestone.academy/runestone/books/published/fopp/Functions/SideEffects.html),它没有完全支持张量流。当您将a更改为temp时,您不再依赖副作用,并且代码可以正常工作。