How do I fine-tune ResNet50 in Keras with float16?

Time: 2019-12-02 08:05:48

Tags: python keras

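I am trying to fine-tune ResNet50 in half-precision mode, without success. It seems that some parts of the model are not compatible with float16. Here is my code: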

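from keras import backend as K
from keras.applications.resnet50 import ResNet50
from keras.models import Sequential

# Make float16 the default dtype for all Keras layers and weights
dtype = 'float16'
K.set_floatx(dtype)
# Raise the fuzz factor, since the default is too small for float16
K.set_epsilon(1e-4)

model = Sequential()
model.add(ResNet50(weights='imagenet', include_top=False, pooling='avg'))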

I get this error:

Traceback (most recent call last):
  File "train_resnet.py", line 40, in <module>
    model.add(ResNet50(weights='imagenet', include_top=False, pooling='avg'))
  File "/usr/local/lib/python3.6/dist-packages/keras/applications/__init__.py", line 28, in wrapper
    return base_fun(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/keras/applications/resnet50.py", line 11, in ResNet50
    return resnet50.ResNet50(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/keras_applications/resnet50.py", line 231, in ResNet50
    x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
  File "/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py", line 457, in __call__
    output = self.call(inputs, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/keras/layers/normalization.py", line 185, in call
    epsilon=self.epsilon)
  File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 1864, in normalize_batch_in_training
    epsilon=epsilon)
  File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 1839, in _fused_normalize_batch_in_training
    data_format=tf_data_format)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/nn_impl.py", line 1329, in fused_batch_norm
    name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_nn_ops.py", line 4488, in fused_batch_norm_v2
    name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py", line 626, in _apply_op_helper
    param_name=input_name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'scale' has DataType float16 not in list of allowed values: float32

1 Answer:

Answer 0: (score: 1)

This is a reported bug. Upgrading to Keras==2.2.5 fixed the problem.
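
To avoid hitting the same TypeError on another machine, it may help to check the installed Keras version before switching to float16. The following is only a minimal sketch: the helper names (parse_version, keras_has_float16_fix, installed_keras_version) are hypothetical and not part of Keras, and the 2.2.5 threshold is taken from the answer above rather than verified independently.

```python
import re

# Version that the answer above reports as fixing the float16 error (assumption).
MIN_KERAS = "2.2.5"


def parse_version(version):
    """Turn a string such as '2.2.5' or '2.3.0rc1' into a tuple of three ints."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def keras_has_float16_fix(installed, minimum=MIN_KERAS):
    """Return True if the installed version is at least the minimum version."""
    return parse_version(installed) >= parse_version(minimum)


def installed_keras_version():
    """Return the installed Keras version string, or None if it cannot be found."""
    try:
        from importlib.metadata import PackageNotFoundError, version
    except ImportError:  # Python older than 3.8
        return None
    try:
        return version("keras")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_keras_version()
    if found is None:
        print("Keras does not appear to be installed")
    elif keras_has_float16_fix(found):
        print("Keras", found, "is at or above", MIN_KERAS)
    else:
        print("Keras", found, "is older than", MIN_KERAS, "- consider upgrading")
```

If the check reports an older version, the upgrade suggested in the answer would be pip install keras==2.2.5.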