How to temporally stack images from a TensorFlow dataset and then batch them

Time: 2019-12-11 23:44:00

Tags: tensorflow tensorflow-datasets

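I have a TensorFlow dataset containing images, labels, and video IDs. I want to process the images spatiotemporally by stacking consecutive frames along the depth dimension. It seems to me that tf.data.Dataset.window is the right tool for this. However, the window method yields objects of type _VariantDataset, and I don't know how to proceed from there. This is what I tried: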

    import tensorflow as tf

    num_frames = 5
    strides = 3

    # Define the generator here (im_list, vid_ids and labels_list are defined elsewhere)
    def gen():
        for fname, vid_id, label in zip(im_list, vid_ids, labels_list):
            yield (fname, vid_id, label)

    dataset = tf.data.Dataset.from_generator(
        gen,
        (tf.string, tf.int64, tf.float64),
        (tf.TensorShape([]), tf.TensorShape([None]), tf.TensorShape([1])))

    def _parse_example(fname, vid_id, label):
        # Read and decode one frame, then resize and convert it to grayscale
        image = tf.io.read_file(fname)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, [112, 112])
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.image.rgb_to_grayscale(image)
        return image, vid_id, label

    dataset = dataset.map(_parse_example)

    # Create a sliding window here
    dataset_temporal = dataset.window(num_frames, shift=strides, drop_remainder=True)

    def _stack_in_depth_fn(ims, vid_ids, labels):
        stacked_ims = tf.expand_dims(tf.stack(ims, axis=-1), axis=0)
        # The label is always the same, but video_id will be used earlier to filter out non-consecutive frames
        return stacked_ims, vid_ids[0], labels[0]

    # This map call is where the error below is raised
    dataset_temporal = dataset_temporal.map(_stack_in_depth_fn)

Running the code above produces the following error:

    train.py:574 _stack_in_depth_fn  *
        stacked_ims = tf.expand_dims(tf.stack(ims, axis=-1), axis=0)
    /home/diaa/tensorflow-gpu-env-2/lib/python3.7/site-packages/tensorflow_core/python/util/dispatch.py:180 wrapper
        return target(*args, **kwargs)
    /home/diaa/tensorflow-gpu-env-2/lib/python3.7/site-packages/tensorflow_core/python/ops/array_ops.py:1158 stack
        value_shape = ops.convert_to_tensor(values[0], name=name)._shape_tuple()  # pylint: disable=protected-access

    TypeError: '_VariantDataset' object is not subscriptable

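From each sliding window of 5 frames (each frame has shape [112, 112, 1]), I need to build a single image of shape [112, 112, 5]. I then want to create batches of shape [None, 112, 112, 5] and feed them to the model. Any help would be appreciated.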
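For illustration, here is a minimal sketch of one way the goal described above might be reached. It relies on the fact that each element produced by window is itself a nested dataset, which is why indexing or stacking it directly fails; flattening each window with flat_map and batch turns it back into ordinary tensors. The synthetic zero-valued frames, the constant names, the helper function names, and the batch size of 4 are all assumptions made for this example and are not part of the original code.

```python
import numpy as np
import tensorflow as tf

NUM_FRAMES = 5  # frames stacked into one sample
SHIFT = 3       # step between the starts of consecutive windows

# Hypothetical stand-in for the decoded frames: 20 grayscale images of
# shape [112, 112, 1], all from a single video (id 0) with label 1.0.
n = 20
images = np.zeros((n, 112, 112, 1), dtype=np.float32)
vid_ids = np.zeros((n,), dtype=np.int64)
labels = np.ones((n, 1), dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices((images, vid_ids, labels))

# Each element of `windows` is a tuple of three small datasets.
windows = dataset.window(NUM_FRAMES, shift=SHIFT, drop_remainder=True)

def _window_to_tensors(ims, ids, lbls):
    # Batch every per-window dataset into a single tensor of NUM_FRAMES items.
    return tf.data.Dataset.zip((
        ims.batch(NUM_FRAMES, drop_remainder=True),
        ids.batch(NUM_FRAMES, drop_remainder=True),
        lbls.batch(NUM_FRAMES, drop_remainder=True),
    ))

def _stack_in_depth(ims, ids, lbls):
    # ims: [NUM_FRAMES, 112, 112, 1] -> [112, 112, NUM_FRAMES]
    stacked = tf.transpose(tf.squeeze(ims, axis=-1), [1, 2, 0])
    return stacked, ids, lbls[0]

def _same_video(stacked, ids, label):
    # Keep only windows whose frames all share one video id.
    return tf.reduce_all(tf.equal(ids, ids[0]))

dataset_temporal = (
    windows.flat_map(_window_to_tensors)
    .map(_stack_in_depth)
    .filter(_same_video)
    .map(lambda stacked, ids, label: (stacked, label))
    .batch(4)
)

for batch_ims, batch_labels in dataset_temporal.take(1):
    print(batch_ims.shape, batch_labels.shape)
```

In this sketch the video-id filter runs after stacking rather than before; windows that straddle two videos are simply dropped. Whether that matches the intended handling of non-consecutive frames depends on how the real frame list is ordered.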

0 answers:

No answers