TPU TensorFlow training: reloading a model to resume training fails on TPU but works on GPU

Time: 2021-03-31 10:24:22

Tags: tensorflow keras tensorflow2.0 tpu

I have trained a model on a TPU and everything went fine. The trained model works as expected.

I now want to reload the model for fine-tuning. I reload it like this:

self.model = tf.keras.models.load_model(
    'detector/checkpoints/model.10',
    custom_objects={'SSDLoss': SSDLoss, 'SmoothLoss': SmoothLoss})

If I do this on a GPU, reloading the model and resuming training works, but if I try the same thing on a TPU I get the following error:

  File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1103, in fit
    callbacks.on_train_batch_end(end_step, logs)
  File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/keras/callbacks.py", line 440, in on_train_batch_end
    self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
  File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/keras/callbacks.py", line 289, in _call_batch_hook
    self._call_batch_end_hook(mode, batch, logs)
  File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/keras/callbacks.py", line 309, in _call_batch_end_hook
    self._call_batch_hook_helper(hook_name, batch, logs)
  File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/keras/callbacks.py", line 345, in _call_batch_hook_helper
    numpy_logs = tf_utils.to_numpy_or_python_type(logs)
  File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/keras/utils/tf_utils.py", line 537, in to_numpy_or_python_type
    return nest.map_structure(_to_single_numpy_or_python_type, tensors)
  File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/util/nest.py", line 635, in map_structure
    structure[0], [func(*x) for x in entries],
  File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/util/nest.py", line 635, in <listcomp>
    structure[0], [func(*x) for x in entries],
  File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/keras/utils/tf_utils.py", line 533, in _to_single_numpy_or_python_type
    x = t.numpy()
  File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1063, in numpy
    maybe_arr = self._numpy()  # pylint: disable=protected-access
  File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1031, in _numpy
    six.raise_from(core._status_to_exception(e.code, e.message), None)  # pylint: disable=protected-access
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: 9 root error(s) found.
  (0) Invalid argument: {{function_node __inference_train_function_96086}} Compilation failure: Detected unsupported operations when trying to compile graph __inference_cond_true_679310_21187_rewritten[] on XLA_TPU_JIT: OptionalFromValue (No registered 'OptionalFromValue' OpKernel for XLA_TPU_JIT devices compatible with node {{node OptionalFromValue}}){{node OptionalFromValue}}
         [[functional_1/perturb_layer/StatefulPartitionedCall/StatefulPartitionedCall/cond_0]]
        TPU compilation failed
         [[tpu_compile_succeeded_assert/_9040394868092548934/_15]]
         [[tpu_compile_succeeded_assert/_9040394868092548934/_15/_425]]

Note: the same root error is repeated 9 times.
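
To make the repeated root errors easier to read, here is a minimal sketch that pulls the unsupported op name and device out of a message like the one above and counts how often it appears. The helper name `unsupported_ops` and the regular expression are my own assumptions based on the wording of this particular message; they are not part of TensorFlow, and the sample string is a shortened excerpt of the error above.

```python
import re
from collections import Counter

# Assumed pattern, based on the wording of the message above:
# "No registered '<Op>' OpKernel for <DEVICE> devices".
_UNSUPPORTED_OP = re.compile(r"No registered '([^']+)' OpKernel for (\S+) devices")


def unsupported_ops(error_text):
    """Count (op, device) pairs reported as unsupported in an error message."""
    return Counter(_UNSUPPORTED_OP.findall(error_text))


# Shortened excerpt of the root error above, for illustration only.
sample = (
    "(0) Invalid argument: Compilation failure: Detected unsupported operations "
    "when trying to compile graph on XLA_TPU_JIT: OptionalFromValue "
    "(No registered 'OptionalFromValue' OpKernel for XLA_TPU_JIT devices "
    "compatible with node {{node OptionalFromValue}})"
)

# Simulate the nine identical root errors.
ops = unsupported_ops("\n".join([sample] * 9))
for (op, device), count in ops.items():
    print(f"{op} on {device}: reported {count} time(s)")
```

In practice the text would presumably come from `str(e)` on the caught `tf.errors.InvalidArgumentError`. Here it shows that all nine root errors point at the same op, `OptionalFromValue`, on `XLA_TPU_JIT`.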

Any idea what this error means, or what is causing it?

I am using Python 3.6.9 and tf 2.3.0.

0 answers:

No answers