如何在pytorch中加载tfrecord?

时间:2019-04-29 01:48:29

标签: tensorflow pytorch tfrecord

如何在pytorch中使用tfrecord?

我已经下载了具有视频级功能的“ Youtube8M”数据集,但它存储在tfrecord中。 我试图从这些文件中读取一些示例,以将其转换为numpy,然后加载到pytorch中。但是失败了。

    reader = YT8MAggregatedFeatureReader()
    files = tf.gfile.Glob("/Data/youtube8m/train*.tfrecord")
    filename_queue = tf.train.string_input_producer(
        files, num_epochs=5, shuffle=True)
    training_data = [
        reader.prepare_reader(filename_queue) for _ in range(1)
    ]

    unused_video_id, model_input_raw, labels_batch, num_frames = tf.train.shuffle_batch_join(
        training_data,
        batch_size=1024,
        capacity=1024 * 5,
        min_after_dequeue=1024,
        allow_smaller_final_batch=True  ,
        enqueue_many=True)

    with tf.Session() as sess:
        label_numpy = labels_batch.eval()
        print(type(label_numpy))

但是此步骤没有结果,只是停留了很长时间而没有任何响应。

4 个答案:

答案 0 :(得分:0)

也许这可以帮助您:TFRecord reader for PyTorch

答案 1 :(得分:0)

您可以使用DALI库直接在PyTorch代码中加载tfrecord。

您可以了解如何in their documentation

答案 2 :(得分:0)

一种解决方法是使用tensorflow 1.1 *急切模式或tensorflow 2+遍历数据集(因此您可以使用var len功能,使用buckets窗口),然后 torch.as_tensor(val.numpy()).to(device)用于火炬。

答案 3 :(得分:0)

我做了这个:

class LiTS(torch.utils.data.Dataset):

    def __init__(self, filenames):
        self.filenames = filenames

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        volume, segmentation = None, None
        if idx >= len(self):
            raise IndexError()
        ds = tf.data.TFRecordDataset(filenames[idx:idx+1])
        for x, y in ds.map(read_tfrecord):
            volume = torch.from_numpy(x.numpy())
            segmentation = torch.from_numpy(y.numpy())

        return volume, segmentation