如何加速在tensorflow_datasets中将张量转换为numpy数组的代码?

时间:2019-10-08 09:52:17

标签: python numpy tensorflow tensor tensorflow-datasets

尽管我想将tensorflow_datasets中的张量转换为numpy数组,但是我的代码却逐渐变慢。 现在,我使用lsun / bedroom数据集,其中包含超过300万张图像。 如何加速我的代码?

我的代码保存了每100,000张图像中包含numpy数组的元组。

train_tf = tfds.load("lsun/bedroom", data_dir="{$my_directory}", download=False)
train_tf = train_tf["train"]
for data in train_tf:
    if d_cnt==0 and d_cnt%100001==0:
        train = (tfds.as_numpy(data["image"]), )
    else:
        train += (tfds.as_numpy(data["image"]), )

    if d_cnt%100000==0 and d_cnt!=0:
        with open("{$my_directory}/lsun.pickle%d"%(d_cnt), "wb") as f:
            pickle.dump(train, f)

    d_cnt += 1

1 个答案:

答案 0 :(得分:1)

您的if条件永远不会在第一遍之后执行,因此您的train变量会不断累积。

我认为您希望具备以下条件:

if d_cnt!=0 and d_cnt%100001==0:
    train = (tfds.as_numpy(data["image"]), )