为什么这个简单模型在tensorflow中占用大量内存?

时间:2018-06-05 01:41:14

标签: tensorflow

我正在尝试实现一个简单的图像分割代码。我只拍了两张尺寸为50 x 512 x 512(深度,高度,宽度)的图像和两张与相应图像尺寸相同的地面实况。但是,当我训练一个只有两个图像和基本事实的简单模型时。 RAM内存使用量需要大约6GB。而且当我将小批量的大小设置为2到10时,它会出现完整的RAM内存(看起来像内存泄漏)。我不明白为什么这么简单的代码会有内存问题。

这是代码

import tensorflow as tf
import SimpleITK as sitk
import numpy as np
from scipy.ndimage import zoom

tf.logging.set_verbosity(tf.logging.INFO)

def tnet(inputs):
    conv1 = tf.layers.conv3d(inputs, 16, 5, padding='same')
    conv1 = tf.nn.relu(conv1)
    logits = tf.layers.conv3d(conv1, 1, 1, padding='same')
    logits = tf.reshape(logits, [-1, 1])
    logits = tf.nn.softmax(logits)
    logits = tf.reshape(logits, [-1, 64, 128, 128, 1])
    return logits

def dice_coef(logits, labels):
    logits = tf.reshape(logits, [-1, 64 * 128 * 128 * 1])
    labels = tf.reshape(labels, [-1, 64 * 128 * 128 * 1])

    x = 2 * tf.reduce_sum(tf.multiply(logits, labels), axis=-1)
    y = tf.reduce_sum(tf.multiply(logits, logits) + tf.multiply(labels, labels), axis=-1)
    z = tf.div(x, y)

    return tf.reduce_mean(z)

def dice_loss(logits, labels):
    return -dice_coef(logits, labels)

def loadTrainData():
    imageList = [
        '../data/train/Case00.mhd', '../data/train/Case01.mhd',
    ]
    GTList = [
        '../data/train/Case00_segmentation.mhd', '../data/train/Case01_segmentation.mhd',
    ]

    sitkImages = dict()

    rescalFilt = sitk.RescaleIntensityImageFilter()
    rescalFilt.SetOutputMaximum(1)
    rescalFilt.SetOutputMinimum(0)

    stats = sitk.StatisticsImageFilter()
    m = 0.

    for f in imageList:
        sitkImages[f] = rescalFilt.Execute(sitk.Cast(sitk.ReadImage(f), sitk.sitkFloat32))
        stats.Execute(sitkImages[f])
        m += stats.GetMean()

    sitkGT = dict()

    for f in GTList:
        sitkGT[f] = sitk.Cast(sitk.ReadImage(f), sitk.sitkFloat32)

    X_ = sorted(sitkImages.items())
    y_ = sorted(sitkGT.items())

    X_ = [sitk.GetArrayFromImage(d[1]) for d in X_]
    y_ = [sitk.GetArrayFromImage(l[1]) for l in y_]

    X = []
    y = []

    # SimpleITK.GetArrayFromImage() converts SimpleITK image to numpy
    for img in X_:
        X.append(zoom(img, (64 / img.shape[0], 128 / img.shape[1], 128 / img.shape[2])))
    for gt in y_:
        y.append(zoom(gt, (64 / gt.shape[0], 128 / gt.shape[1], 128 / gt.shape[2])))

    print("resized image shape : %s" % str(X[0].shape))
    print("resized gt shape : %s" % str(y[0].shape))

    return X, y

def preproc(images, labels):
    X = np.asarray(images, dtype=np.float32)
    y = np.asarray(labels, dtype=np.float32)

    print("all images shape : %s" % str(X.shape))
    print("all gts shape : %s" % str(y.shape))

    X = np.reshape(X, (-1, 64, 128, 128, 1))
    #    y = np.reshape(y, (-1, 64, 128, 128, 2))
    y = np.reshape(y, (-1, 64, 128, 128, 1))

    return X, y

def main(args):

    inputs = tf.placeholder(tf.float32, [None, 64, 128, 128, 1])
    labels = tf.placeholder(tf.float32, [None, 64, 128, 128, 1])

    logits = tnet(inputs)

    cost = dice_loss(logits=logits, labels=labels)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(cost)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        epochs = 10

        for epoch in range(epochs):
            print('yeah')

            # load and preprocess images
            X_, y_ = loadTrainData()
            X, y = preproc(X_, y_)

            _, cost_val = sess.run(
                [optimizer, cost],
                feed_dict={
                    inputs: X,
                    labels: y
                }
            )

            print('cost : ' + str(cost_val))

if __name__ == '__main__':
    tf.app.run()

1 个答案:

答案 0 :(得分:2)

问题在于卷积操作。如果您有一个32x32的图像作为输入,并且您使用16个输出通道conv1 = tf.layers.conv3d(inputs, 16, 5, padding='same'),这将生成32x32x16形状的输出,您创建的卷积越多,参数和数据将添加到您的网络中图表,使用以下配置声明您的会话,它将显示每个层消耗多少内存。

sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, 
                                        log_device_placement=True))