张量流中是否有一种方法可以每次加载成批数据?

时间:2020-02-10 01:53:54

标签: python tensorflow tensorflow2.0 tensorflow-datasets

所以我在google colab中运行tensorflow 2+ python。

我的每个数据文件都是一个形状为[563,563,563,1]的3d图像,因此加载它们都将引发资源耗尽错误。

我花了几天和时间寻找一种方法来仅将我的数据集的一部分加载为张量,并在每次迭代时卸载/加载新的批次。我猜可能有一种使用tf.data.Dataset.list_files的方法,但是我找不到确切的方法。

是否有任何好的建议或建议阅读的文档?我已经从tensorflow中读取了tf.data文档,但是找不到我需要的信息。

谢谢!

编辑

所以这是我要用来加载图像的功能

def load_image(ind):
    file_brain = "/content/drive/My Drive/brain/" + str(ind) + ".mgz"
    file_mask = "/content/drive/My Drive/mask/" + str(ind) + ".mgz"
    data_brain, affine = load_nifti(file_brain)
    data_mask, affine = load_nifti(file_mask)
    data_brain = affine_transform(data_brain, affine)
    data_mask = affine_transform(data_mask, affine)
    data_brain = normalize(data_brain)
    data_brain = zoom(data_brain, (563/256, 563/256, 563/256))
    data_brain = tf.expand_dims(data_brain, axis=-1)
    data_mask = tf.expand_dims(data_mask, axis=-1)
    return data_brain, data_mask

这就是我之前加载数据集的方式,这耗尽了资源;

def create_dataset():
    train_data = []
    train_label = []
    test_data = []
    test_label = []
    test_n = np.random.randint(1, 10, 1)
    for i in range(1, 10):
        data_brain, data_mask = load_image(i)
        if i in test_n:
            test_data.append(data_brain)
            test_label.append(data_mask)
            continue
        train_data.append(data_brain)
        train_label.append(data_mask)
        shifted_data = data_brain + tf.random.uniform(shape=(), minval=-0.05, maxval=0.05)
        scaled_data = data_brain * tf.random.uniform(shape=(), minval=0.85, maxval=1.3)
        train_data.append(shifted_data)
        train_label.append(data_mask)
        train_data.append(scaled_data)
        train_label.append(data_mask)
"""
train_data = tf.data.Dataset.from_tensor_slices(train_data)
train_label = tf.data.Dataset.from_tensor_slices(train_label)
test_data = tf.data.Dataset.from_tensor_slices(test_data)
test_label = tf.data.Dataset.from_tensor_slices(test_label)
return train_data, train_label, test_data, test_label
"""

0 个答案:

没有答案