我想在张量流的tf.variable
内逐行更新二维tf.while_loop
。因此,我使用tf.assign
方法。问题是我的实现和parallel_iterations>1
的结果是错误的。使用parallel_iterations=1
时结果正确。代码是这样的:
a = tf.Variable(tf.zeros([100, 100]), dtype=tf.int64)
i = tf.constant(0)
def condition(i, var):
return tf.less(i, 100)
def body(i, var):
updated_row = method() # This method returns a [1, 100] tensor which is the updated row for the variable
temp = tf.assign(a[i], updated_row)
return [tf.add(i, 1), temp]
z = tf.while_loop(condition, body, [i, a], back_prop=False, parallel_iterations=10)
迭代是完全独立的,我不知道出了什么问题。
奇怪的是,如果我这样更改代码:
a = tf.Variable(tf.zeros([100, 100]), dtype=tf.int64)
i = tf.constant(0)
def condition(i, var):
return tf.less(i, 100)
def body(i, var):
zeros = lambda: tf.zeros([100, 100], dtype=tf.int64)
temp = tf.Variable(initial_value=zeros, dtype=tf.int64)
updated_row = method() # This method returns a [1, 100] tensor which is the updated row for the variable
temp = tf.assign(temp[i], updated_row)
return [tf.add(i, 1), temp]
z = tf.while_loop(condition, body, [i, a], back_prop=False, parallel_iterations=10)
代码为parallel_iterations>1
提供了正确的结果。有人可以解释一下这是怎么回事,并给我一个有效的解决方案来更新变量,因为我要更新的原始变量很大,而我发现的解决方案效率很低。
答案 0 :(得分:0)
您不需要为此使用变量,只需在循环主体上生成行更新的张量即可:
PackageInfo info = application.getPackageManager().getPackageInfo(application.getPackageName(), 0);
输出:
java.lang.NullPointerException
at digitu.com.osmos.ApplicationTest.testCorrectVersion(ApplicationTest.java:27)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at junit.framework.TestCase.runTest(TestCase.java:176)
at junit.framework.TestCase.runBare(TestCase.java:141)
at junit.framework.TestResult$1.protect(TestResult.java:122)
at junit.framework.TestResult.runProtected(TestResult.java:142)
at junit.framework.TestResult.run(TestResult.java:125)
at junit.framework.TestCase.run(TestCase.java:129)
at junit.framework.TestSuite.runTest(TestSuite.java:252)
at junit.framework.TestSuite.run(TestSuite.java:247)
at org.junit.internal.runners.JUnit38ClassRunner.run(JUnit38ClassRunner.java:86)
at org.junit.runner.JUnitCore.run(JUnitCore.java:137)
at com.intellij.junit4.JUnit4IdeaTestRunner.startRunnerWithArgs(JUnit4IdeaTestRunner.java:68)
at com.intellij.rt.execution.junit.IdeaTestRunner$Repeater.startRunnerWithArgs(IdeaTestRunner.java:47)
at com.intellij.rt.execution.junit.JUnitStarter.prepareStreamsAndStart(JUnitStarter.java:242)
at com.intellij.rt.execution.junit.JUnitStarter.main(JUnitStarter.java:70)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
答案 1 :(得分:0)
在tf.function中,我发现了以下内容:
要点:跟踪func时,任何Python副作用(追加到列表,使用打印进行打印等)都只会发生一次。要在您的tf.function中执行副作用,需要将其写为TF ops:
我很确定这就是发生的事情。您期望 a 发生变化,但这是一个“副作用”(https://runestone.academy/runestone/books/published/fopp/Functions/SideEffects.html),它没有完全支持张量流。当您将a更改为temp时,您不再依赖副作用,并且代码可以正常工作。