I have a TensorFlow dataset containing images, labels, and video IDs. I want to process the images spatio-temporally by stacking consecutive frames along the depth dimension. tf.data.Dataset.window seems like the right approach to me. However, since the window method returns objects of type _VariantDataset, I don't know how to proceed. Here is what I tried:
num_frames = 5
strides = 3
#here define the generator
def gen():
    for fname, vid_id, label in zip(im_list, vid_ids, labels_list):
        yield (fname, vid_id, label)
dataset = tf.data.Dataset.from_generator(gen, (tf.string, tf.int64, tf.float64), (tf.TensorShape([]), tf.TensorShape([None]), tf.TensorShape([1])))
def _parse_example(fname, vid_id, label):
    image = tf.io.read_file(fname)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [112, 112])
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.rgb_to_grayscale(image)
    return image, vid_id, label
dataset = dataset.map(_parse_example)
#here create a sliding window
dataset_temporal = dataset.window(num_frames, shift=strides, drop_remainder=True)
def _stack_in_depth_fn(ims, vid_ids, labels):
    stacked_ims = tf.expand_dims(tf.stack(ims, axis=-1), axis=0)
    #the label is always the same but video_id will be used earlier to filter out non-consecutive frames
    return stacked_ims, vid_ids[0], labels[0]
dataset_temporal = dataset_temporal.map(_stack_in_depth_fn)
When I run the above, I get the following error:
train.py:574 _stack_in_depth_fn *
stacked_ims = tf.expand_dims(tf.stack(ims, axis=-1), axis=0)
/home/diaa/tensorflow-gpu-env-2/lib/python3.7/site-packages/tensorflow_core/python/util/dispatch.py:180 wrapper
return target(*args, **kwargs)
/home/diaa/tensorflow-gpu-env-2/lib/python3.7/site-packages/tensorflow_core/python/ops/array_ops.py:1158 stack
value_shape = ops.convert_to_tensor(values[0], name=name)._shape_tuple() # pylint: disable=protected-access
TypeError: '_VariantDataset' object is not subscriptable
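From reading the tf.data docs I suspect each window has to be flattened back into plain tensors with flat_map plus a nested batch before map can stack anything, but I haven't managed to wire this into my three-component pipeline. Here is a minimal single-component sketch of what I mean, using toy integers in place of images:

```python
import tensorflow as tf

num_frames = 5
strides = 3

# toy single-component dataset standing in for the parsed image stream
ds = tf.data.Dataset.range(10)

# window() yields a dataset of _VariantDatasets; flat_map with a nested
# batch(num_frames) turns each window back into one dense tensor
windows = ds.window(num_frames, shift=strides, drop_remainder=True)
windows = windows.flat_map(lambda w: w.batch(num_frames))

for w in windows:
    print(w.numpy())  # two windows: [0 1 2 3 4] and [3 4 5 6 7]
```

For my dataset of (image, vid_id, label) tuples I assume the flat_map lambda receives three _VariantDatasets and one would zip their batched versions, e.g. tf.data.Dataset.zip((ims.batch(num_frames), ids.batch(num_frames), lbls.batch(num_frames))), but I haven't verified that.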
From each sliding window of 5 frames I need to build an image of shape [112, 112, 5] (each individual frame has shape [112, 112, 1]). I then want to create batches of shape [None, 112, 112, 5] and feed them into the model.
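For the depth stacking itself, assuming the five frames arrive as a single tensor of shape [5, 112, 112, 1], I believe a squeeze plus transpose gives the layout I want (untested in the full pipeline):

```python
import tensorflow as tf

# five grayscale frames as produced by _parse_example, gathered into one tensor
frames = tf.random.uniform([5, 112, 112, 1])

# drop the channel axis, then move the frame axis to the back:
# [5, 112, 112] -> [112, 112, 5]
stacked = tf.transpose(tf.squeeze(frames, axis=-1), perm=[1, 2, 0])
print(stacked.shape)  # (112, 112, 5)
```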
Any help would be appreciated.