TensorFlow - TF records too large to be loaded into an np array at once

Date: 2017-12-09 19:15:05

Tags: tensorflow tensorflow-datasets tfrecord

I am trying to train an AlexNet CNN model by following the steps in the tutorial from the TensorFlow guide site. However, the tutorial loads its training data with the following code:

mnist = tf.contrib.learn.datasets.load_dataset("mnist")
train_data = mnist.train.images # Returns np.array
train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
eval_data = mnist.test.images # Returns np.array
eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)

In my case, I wrote a script that serializes my dataset examples into a TFRecord file, and during training I try to read those records back and feed them to the AlexNet network. See the code below:

import math
import numpy as np
import tensorflow as tf

#FUNCTION TO GET ALL DATASET DATA
def _read_multiple_images(filenames, perform_shuffle=False, repeat_count=1,
                          batch_size=1, available_record=39209, num_of_epochs=1):
    def _read_one_image(serialized):
        #Specify the features you want to extract
        features = {'image/shape': tf.FixedLenFeature([], tf.string),
                    'image/class/label': tf.FixedLenFeature([], tf.int64),
                    'image/class/text': tf.FixedLenFeature([], tf.string),
                    'image/filename': tf.FixedLenFeature([], tf.string),
                    'image/encoded': tf.FixedLenFeature([], tf.string)}
        parsed_example = tf.parse_single_example(serialized, features=features)

        #Decode and reshape the extracted data
        image_raw = tf.decode_raw(parsed_example['image/encoded'], tf.uint8)
        shape = tf.decode_raw(parsed_example['image/shape'], tf.int32)
        label = tf.cast(parsed_example['image/class/label'], dtype=tf.int32)
        reshaped_img = tf.reshape(image_raw, shape)
        casted_img = tf.cast(reshaped_img, tf.float32)
        label_tensor = [label]
        image_tensor = [casted_img]
        return label_tensor, image_tensor

    complete_labels = np.array([])
    complete_images = np.array([])

    dataset = tf.data.TFRecordDataset(filenames=filenames)
    dataset = dataset.map(_read_one_image)
    dataset = dataset.repeat(repeat_count)      #Repeats dataset this # times
    dataset = dataset.batch(batch_size)         #Batch size to use
    iterator = dataset.make_initializable_iterator()
    labels_tensor, images_tensor = iterator.get_next() #Get batch data
    no_of_rounds = int(math.ceil(available_record / batch_size))

    #Create tf session, get next set of batches, and evaluate them in batches
    sess = tf.Session()
    count = 1
    for _ in range(num_of_epochs):
        sess.run(iterator.initializer)

        while True:
            try:
                evaluated_label, evaluated_image = sess.run([labels_tensor,
                                                             images_tensor])

                #convert evaluated tensors to np arrays
                label_np_array = np.asarray(evaluated_label, dtype=np.uint8)
                image_np_array = np.asarray(evaluated_image, dtype=np.uint8)

                #squeeze np arrays to make dimensions appropriate
                squeezed_label_np_array = label_np_array.squeeze()
                squeezed_image_np_array = image_np_array.squeeze()

                #add current batch to running totals
                complete_labels = np.append(complete_labels, squeezed_label_np_array)
                complete_images = np.append(complete_images, squeezed_image_np_array)
            except tf.errors.OutOfRangeError:
                print("End of Dataset Reached")
                break
            count = count + 1

    sess.close()
    return complete_labels, complete_images

My main problem is that, in order to feed my TF Estimator, I first restore all 39209 images in my dataset (227x227x3) into np arrays, and my computer runs out of memory:

train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": complete_images}, y=complete_labels,
    batch_size=100, num_epochs=1, shuffle=True)
dataset_classifier.train(input_fn=train_input_fn,
                         hooks=[logging_hook])
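
(For scale, a rough sanity check on the memory use: 39209 images * 227 * 227 * 3 channels is about 6.1 billion values, i.e. roughly 6 GB held as uint8 and around 24 GB once cast to float32, so a single np array holding the whole dataset can easily exceed typical desktop RAM.)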

Is there a way to get my images and labels out of my TF records in batches and feed them to my tf.estimator in batches, without having to load everything into an np array first?

1 Answer:

Answer 0 (score: 3)

If you can access your data as a tf.data.Dataset, there is no need to convert it to a NumPy array before passing it to the Estimator. You can build the Dataset directly inside your input function, along these lines:

def train_input_fn():
  dataset = tf.data.TFRecordDataset(filenames=filenames)
  dataset = dataset.map(_read_one_image)
  dataset = dataset.repeat(1)  # Because `num_epochs=1`.
  dataset = dataset.batch(100)  # Because `batch_size=100`.

  dataset = dataset.prefetch(1)  # To improve performance by overlapping execution.

  iterator = dataset.make_one_shot_iterator()  # NOTE: Use a "one-shot" iterator.
  labels_tensor, images_tensor = iterator.get_next()

  return {"x": images_tensor}, labels_tensor

dataset_classifier.train(
    input_fn=train_input_fn, hooks=[logging_hook])

This should be more efficient than building the NumPy arrays, because it avoids having to materialize the entire dataset in memory at once. You can also improve training speed with performance features such as Dataset.prefetch() and the parallel version of Dataset.map(), sketched below.
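
For example, here is a minimal sketch combining both tweaks, reusing `filenames` and `_read_one_image` from the question. Note the `num_parallel_calls` argument to `Dataset.map()` requires TensorFlow 1.4+, and the value 4 is just an assumed thread count to tune for your machine:

def train_input_fn():
  dataset = tf.data.TFRecordDataset(filenames=filenames)
  # Parse several records concurrently; 4 is an assumed value to tune.
  dataset = dataset.map(_read_one_image, num_parallel_calls=4)
  dataset = dataset.repeat(1)
  dataset = dataset.batch(100)
  # Prepare the next batch in the background while the current
  # batch is being consumed by the training step.
  dataset = dataset.prefetch(1)

  iterator = dataset.make_one_shot_iterator()
  labels_tensor, images_tensor = iterator.get_next()
  return {"x": images_tensor}, labels_tensor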