具有随机张量的压缩张量流数据集的怪异行为

时间:2019-12-20 14:57:10

标签: tensorflow random tensorflow-datasets tensorflow2.0

在下面的示例(Tensorflow 2.0)中,我们有一个包含三个元素的虚拟tensorflow数据集。我们在其上映射一个函数(replace_with_float),该函数以两个副本的形式返回随机生成的值。正如我们期望的那样,当我们从数据集中获取元素时,第一个和第二个坐标具有相同的值。

现在,我们分别从第一个坐标和第二个坐标创建两个“切片”数据集,然后将两个数据集压缩在一起。切片和压缩操作似乎是相反的,因此我希望结果数据集等同于上一个数据集。但是,正如我们所看到的,现在第一和第二坐标是随机生成的不同值。

如果我们将“相同”数据集自身压缩到以下位置,可能会更有趣 df = tf.data.Dataset.zip((df.map(lambda x, y: x), df.map(lambda x, y: x))),两个坐标也将具有不同的值。

如何解释这种行为?也许要为要压缩的两个数据集构造两个不同的图,并且它们是独立运行的?

import tensorflow as tf

def replace_with_float(element):
    rand = tf.random.uniform([])
    return (rand, rand)

df = tf.data.Dataset.from_tensor_slices([0, 0, 0])
df = df.map(replace_with_float)
print('Before zipping: ')
for x in df:
    print(x[0].numpy(), x[1].numpy())

df = tf.data.Dataset.zip((df.map(lambda x, y: x), df.map(lambda x, y: y)))

print('After zipping: ')
for x in df:
    print(x[0].numpy(), x[1].numpy())

示例输出:

Before zipping: 
0.08801079 0.08801079
0.638958 0.638958
0.800568 0.800568
After zipping: 
0.9676769 0.23045003
0.91056764 0.6551999
0.4647777 0.6758332

1 个答案:

答案 0 :(得分:1)

简单的答案是,除非您使用df.cache()明确请求,否则数据集不会在完整迭代之间缓存中间值,并且它们也不会对常见输入进行重复数据删除。

因此,在第二个循环中,整个管道再次运行。 同样,在第二种情况下,两次df.map调用导致df运行两次。

添加tf.print有助于说明会发生什么情况:

def replace_with_float(element):
    rand = tf.random.uniform([])
    tf.print('replacing', element, 'with', rand)
    return (rand, rand)

我也将lambda拉到了不同的行,以避免签名警告:

first = lambda x, y: x
second = lambda x, y: y

df = tf.data.Dataset.zip((df.map(first), df.map(second)))
Before zipping: 
replacing 0 with 0.624579549
0.62457955 0.62457955
replacing 0 with 0.471772075
0.47177207 0.47177207
replacing 0 with 0.394005418
0.39400542 0.39400542

After zipping: 
replacing 0 with 0.537954807
replacing 0 with 0.558757305
0.5379548 0.5587573
replacing 0 with 0.839109302
replacing 0 with 0.878996611
0.8391093 0.8789966
replacing 0 with 0.0165234804
replacing 0 with 0.534951568
0.01652348 0.53495157

为避免重复输入问题,您可以使用单个map调用:

swap = lambda x, y: (y, x)
df = df.map(swap)

或者您可以使用df = df.cache()来避免两种影响:

df = df.map(replace_with_float)
df = df.cache()
Before zipping: 
replacing 0 with 0.728474379
0.7284744 0.7284744
replacing 0 with 0.419658661
0.41965866 0.41965866
replacing 0 with 0.911524653
0.91152465 0.91152465

After zipping: 
0.7284744 0.7284744
0.41965866 0.41965866
0.91152465 0.91152465