如何解决数据获取瓶颈以进行TPU推理?

时间:2020-07-28 21:23:31

标签: tensorflow google-compute-engine tpu google-cloud-tpu

这就是我的推理设置的样子

autotune = tf.data.experimental.AUTOTUNE

with strategy.scope():
    model = LoadModel()
    raw_dataset = tf.data.TFRecordDataset(tfRecordAddress)
    train_dataset = raw_dataset.map(_parse_example, num_parallel_calls=autotune)
    train_dataset = train_dataset.padded_batch(batch_size, padding_values=(1, 1, b'-'), padded_shapes=(512, 512, 1))
    # train_dataset = train_dataset.repeat()
    train_dataset = train_dataset.prefetch(autotune)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)

def per_core_inference_fn(inputIds,attnIds ):
    return model.inference((inputIds, attnIds))

@tf.function
def inference_fn(inputIds, attnIds):
    return strategy.run(per_core_inference_fn, args=(inputIds,attnIds))

results = []
for x in train_dataset:
    t0 = time.time()
    results.append(inference_fn(x[0], x[1]))
    t1 = time.time()
    print('time is :', t1-t0)

批处理大小很大,推理速度很快,大约为.0003秒。但是,下一批的提取要花费很长时间for x in train_dataset:,例如60-80秒。

据我所知,我正确地进行了推断,但是在某种程度上,TPU的CPU遇到了批量检索的巨大瓶颈。

在培训期间,我没有看到此瓶颈。因此,看来model.fit正在做我没有做的事情。

1 个答案:

答案 0 :(得分:1)

我觉得这个瓶颈是由for x in train_dataset引起的。批加载之间的60-80秒对我来说意味着预取未按预期工作。在自定义训练循环(CTL)代码中,我通常会看到整个循环都包裹在tf.function中,例如here中。

您可以类似地修改代码吗?您也可以尝试捕获TPU配置文件(https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_profile),而不是使用time.time()进行基准测试。