如何阻止tensorflow数据生成器消耗所有内存?

时间:2019-04-17 17:51:53

标签: keras tensorflow-datasets

我有一个Keras模型,该模型吸收一系列图像并进行预测。

#Input shape: (#samples,seq_length,image_height,image_width,1)
#Output shape: (#samples,seq_length,image_height,image_width,1)

我已将样本数量写入文件服务器。它占用的总内存约为1TB。

以下是我如何解决此问题的方法:

def gen_x():
  for filename in os.listdir("data_X"):
      with open ("data_X/"+filename, 'rb') as fp:
          X = pickle.load(fp)
      yield X

def gen_y():
  for filename in os.listdir("data_Y"):
      with open ("data_Y/"+filename, 'rb') as fp:
          Y = pickle.load(fp)
      yield Y      

ds_x=tf.data.Dataset.from_generator(gen_x, (tf.float32), (tf.TensorShape([50,128,64,1]))).batch(128)
ds_y=tf.data.Dataset.from_generator(gen_y, (tf.float32), (tf.TensorShape([5,128,64,1]))).batch(128)

it_x = ds_x.make_one_shot_iterator()
it_y = ds_y.make_one_shot_iterator()

但是,模型在开始训练时会消耗所有可用的内存,而不是仅为批处理分配内存。请注意,GPU内存仍未得到充分利用,但RAM已完全耗尽。以下是该模型的代码:

input_frames=Input(tensor=it_x.get_next())
x=ConvLSTM2D(filters=40, kernel_size=(3, 3), padding='same', return_sequences=False)(input_frames)
x=BatchNormalization()(x)
x=Conv2D(filters=5,kernel_size=(3,3),padding="same")(x)
x=Reshape((5,128,64,1))(x)
model=Model(inputs=input_frames,outputs=x)

model.compile(loss='binary_crossentropy', optimizer='adadelta',target_tensors=[it_y.get_next()])    
model.fit(steps_per_epoch=#Samples // 128, epochs=5, verbose=1)

0 个答案:

没有答案