定义一个函数在运行阶段之前遍历张量

时间:2018-10-19 08:45:04

标签: python tensorflow tensor

我正在定义一个损失函数,在其中必须迭代张量的值。当然,这是在训练阶段之前编码的python函数,我正在努力如何在张量上定义for循环。在这个张量中,有一些零和其他数字,但我不知道有多少,这取决于当前批次的训练文件。可能是2、5、10,...我不知道,所以我不能使用固定值。这是一个例子

points = tf.constant([[[0], [1], [0], [2], [1], [0], [2], [2], [0], [1]],
                      [[1], [2], [0], [0], [1], [0], [1], [2], [0], [2]],
                      [[0], [2], [1], [0], [2], [0], [1], [2], [0], [1]]], dtype=tf.float32)
max_indices = tf.reduce_max(points[:1])
for index in xrange(max_indices): # error here
    # do stuff

这是错误

TypeError: an integer is required

所以我尝试了另一种方法

points = tf.constant([[[0], [1], [0], [2], [1], [0], [2], [2], [0], [1]],
                      [[1], [2], [0], [0], [1], [0], [1], [2], [0], [2]],
                      [[0], [2], [1], [0], [2], [0], [1], [2], [0], [1]]], dtype=tf.float32)
items, _ = tf.unique(tf.reshape(points[:1], [-1]))
for item in tf.unstack(items): # error here
    # do stuff

错误是

ValueError: Cannot infer num from shape (?,)

当然,给出这些错误是因为在定义阶段我没有值,但我不知道如何解决它。

1 个答案:

答案 0 :(得分:1)

基于“在定义阶段我没有值”的假设,当您需要迭代数据集时,您正在会话中。

您可以使用tf.while_loop(cond, body, loop_vars)并将结束条件添加到每个张量的末尾。这应该是一些不会在数据集中显示的值。例如,如果您的数据集仅包含正实数,则可以使用-1作为结束条件。

基本上,tf.while_loop(cond, body, loop_vars)首先调用cond,这应该返回布尔张量。同时,如果cond不返回false,则body将被调用。请注意,您要在循环中使用的所有内容都应包含在loop_vars中。全局变量和loop_vars上下文之外的变量将不起作用。

因此,您可以使用loop_vars作为时间步长,并在每次通过t = tf.constant(0)时将其递增。然后,在body中,可以通过使用cond在每次迭代中获取值来检查数据集x是否等于最终条件。如果等于此结束条件,则说明您已对张量进行了完全迭代,并且x[t]应该返回false。否则,cond应该返回true,然后可以通过在cond中使用tx用作某些数据集x[t]的迭代器。

这就是我用来解决RNN中可变长度语句的迭代问题的方法。我以body作为结束条件,并且效果很好。我希望这也对您有用!