谷歌ml-engine:永远填补队列

时间:2017-11-30 22:45:35

标签: google-cloud-ml

我创建了存储在Google存储分区上的tf records个文件。我在ml-engine上运行了一个代码,使用这些tf records

中的数据训练模型

每个tf记录文件包含一批20个示例,大小约为8Mb( Mega 字节)。存储桶中有数千个文件。

我的问题是,开始训练需要花费很长时间。我必须在加载软件包的时刻和培训实际开始的那一刻之间等待大约40分钟。我猜这是下载数据和填充队列所需的时间?

代码是(为了简洁而略微简化):

    # Create a queue which will produce tf record names
    filename_queue = tf.train.string_input_producer(files, num_epochs=num_epochs, capacity=100)

    # Read the record
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Map for decoding the serialized example
    features = tf.parse_single_example(
        serialized_example,
        features={
            'data': tf.FixedLenFeature([], tf.float32),
            'label': tf.FixedLenFeature([], tf.int64)
        })

    train_tensors = tf.train.shuffle_batch(
        [features['data'], features['label']],
        batch_size=30,
        capacity=600,
        min_after_dequeue=400,
        allow_smaller_final_batch=True
        enqueue_many=True)

我已经检查过我的存储桶和我的工作共享相同的region参数。

我不明白花了这么长时间:它应该只是下载几百Mbs(几十个tf记录文件应该足以拥有超过min_after_dequeue个元素队列)。

知道我错过了什么,或问题可能在哪里?

由于

1 个答案:

答案 0 :(得分:1)

对不起,我的不好。我使用自定义函数:

  1. 验证作为tf记录传递的每个文件是否确实存在。
  2. 展开外卡字符(如果有)
  3. 当在gs://

    上处理数千个文件时,这是一个非常坏主意

    我已经删除了这个“理智”检查,现在工作正常。