我是TensorFlow的新手,我正在尝试在一个函数中打印一个向量的形状,该函数将从TensorFlow会话中调用。
问题是这条线(显示已注释掉)仅在最初定义此函数模板时执行(而不是在TensorFlow会话期间的每次迭代)。如何添加一个print语句,以便在每次TensorFlow迭代时调用它?
def Q(X):
# f_debug.write('Q(X) :: X.shape :: ' + str(X.shape) + '\n')
h = tf.nn.relu(tf.matmul(X, Q_W1) + Q_b1)
z = tf.matmul(h, Q_W2) + Q_b2
return z
答案 0 :(得分:1)
这是一个值得注意的重点,也是TF中常见的混淆。该函数不会被会话中的tensorflow调用,除了tf.py_func
之外,没有python函数可以解决您的问题。
Tensorflow仅调用函数Q
来获取符号操作,然后将这些操作添加到依赖关系图中。在会话期间,依赖图是依赖于执行计算的全部内容。即使您使用tf.while
,tf.cond
或其他控制流操作。这些都不会在会话期间调用python,它们只是按照您定义的方式遍历依赖关系图中的元素。
一般来说,如果没有使用Tensorflow调试器(根本不难配置),就没有好办法停止执行张量流中图执行。但作为一种解决方法,您可能会定义一个tf.py_func
python函数。这个函数将一个张量编组成一个python对象,并在会话执行期间调用python(它没有效率或任何东西,但在某些情况下它很方便)。
您可能需要使用with tf.control_dependencies(...):
来强制运行tf.py_func
操作(因为如果它只有一个打印语句,它就不会有任何依赖)。
Disclamer:我没有以这种方式使用tf.py_func
,也没有按照这个意图建立。
答案 1 :(得分:0)
这是 TensorFlow 1 的具体解释。我确信 TF 2 中的 Eager Execution 改变了其中的一些事情。请参阅docs for tf.print。
我在 py_func 函数中使用常规 print()
语句没有任何运气。您可以使用 tf.print
在执行 TF 图(py_func 或其他)的“内部”执行打印语句。我只熟悉它在 TF 1 中的使用,但它的工作原理是创建一个新的打印操作并将其添加到 TF 图中:
def Q(X):
# Note that at least sometimes, X.shape can be resolved before the
# graph is executed, so you may only need this for the *value* of X
# Create a new print op
printop = tf.print('Q:', X)
# Force printop to be added to the graph by setting it as a
# dependency for at least one operation that will be run
with tf.control_dependencies([printop]):
h = tf.nn.relu(tf.matmul(X, Q_W1) + Q_b1)
z = tf.matmul(h, Q_W2) + Q_b2
return z
我遇到过一些情况,这并没有像我希望的那样工作(我相信是由于多线程)。如果您找不到图表的节点来附加 printop
操作,这也可能会很棘手——例如,如果您试图在函数的最后一行打印某些内容。
但这似乎对我来说大部分时间都有效。