减少内存的Tensorflow TPU v2 / v3 bfloat16

时间:2018-11-24 13:52:17

标签: python tensorflow google-compute-engine google-cloud-tpu

我的模型太大,无法使用普通的v2 TPU设备获得大于64的批处理。在troubleshooting网站上,提到即将到来的tensorflow版本将支持bfloat16。新支持的tf版本1.9-1.12现在是否可以使用bfloat16,如果可以,我可以使用的优化器数量有限吗?我没有找到任何进一步的文档,但是在tensor2tensor模型中看到了bfloat16的用法,所以我想一定有办法。

此外,我读了TPU v3 supports bigger models as well,但该模型只需要进行最小的更改,但是我找不到任何需要更改的文档。

我已经在使用Adafactor并尝试缩小图层,如果您还有其他缩小技巧,那也很好。我使用图片矩阵和字向量(截至目前为float32)作为输入。

1 个答案:

答案 0 :(得分:1)

您可以将bfloat16与TPU一起使用。有两件事要做:

  1. 在输入管道中将输入投射到bfloat16
  2. 在bfloat16范围内环绕网络,并将输出转换为F32以进行进一步的计算。

以下是说明必要更改的代码段:

def input_fn():

  def dataset_parser(self, value):
    """Parse an ImageNet record from a serialized string Tensor."""
    image = self.image_preprocessing_fn(
        image_bytes=image_bytes,
        is_training=self.is_training,
    )

    if self.use_bfloat16:
      image = tf.cast(image, tf.bfloat16)

    return image, label


def resnet_model_fn(features, labels, mode, params):
  """The model_fn for ResNet to be used with TPUEstimator."""

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    network = resnet_model.resnet_v1(
        resnet_depth=FLAGS.resnet_depth,
        num_classes=LABEL_CLASSES,
        data_format=FLAGS.data_format)
    return network(
        inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

  if FLAGS.precision == 'bfloat16':
    with bfloat16.bfloat16_scope():
      logits = build_network()
    logits = tf.cast(logits, tf.float32)
  elif FLAGS.precision == 'float32':
    logits = build_network()

您还可以看到this TPU model中所示的第二个条件。