我的模型太大,无法使用普通的v2 TPU设备获得大于64的批处理。在troubleshooting网站上,提到即将到来的tensorflow版本将支持bfloat16。新支持的tf版本1.9-1.12现在是否可以使用bfloat16,如果可以,我可以使用的优化器数量有限吗?我没有找到任何进一步的文档,但是在tensor2tensor模型中看到了bfloat16的用法,所以我想一定有办法。
此外,我读了TPU v3 supports bigger models as well,但该模型只需要进行最小的更改,但是我找不到任何需要更改的文档。
我已经在使用Adafactor并尝试缩小图层,如果您还有其他缩小技巧,那也很好。我使用图片矩阵和字向量(截至目前为float32)作为输入。
答案 0 :(得分:1)
您可以将bfloat16
与TPU一起使用。有两件事要做:
以下是说明必要更改的代码段:
def input_fn():
def dataset_parser(self, value):
"""Parse an ImageNet record from a serialized string Tensor."""
image = self.image_preprocessing_fn(
image_bytes=image_bytes,
is_training=self.is_training,
)
if self.use_bfloat16:
image = tf.cast(image, tf.bfloat16)
return image, label
def resnet_model_fn(features, labels, mode, params):
"""The model_fn for ResNet to be used with TPUEstimator."""
# This nested function allows us to avoid duplicating the logic which
# builds the network, for different values of --precision.
def build_network():
network = resnet_model.resnet_v1(
resnet_depth=FLAGS.resnet_depth,
num_classes=LABEL_CLASSES,
data_format=FLAGS.data_format)
return network(
inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
if FLAGS.precision == 'bfloat16':
with bfloat16.bfloat16_scope():
logits = build_network()
logits = tf.cast(logits, tf.float32)
elif FLAGS.precision == 'float32':
logits = build_network()
您还可以看到this TPU model中所示的第二个条件。