I trained a model on a TPU and everything worked. The trained model behaves as expected.
Now I want to reload the model to fine-tune it. I reload it like this:
self.model = tf.keras.models.load_model('detector/checkpoints/model.10',
custom_objects={'SSDLoss': SSDLoss,
'SmoothLoss': SmoothLoss})
If I do this on a GPU, the model reloads and training resumes without problems. If I try the same thing on a TPU, I get the following error:
File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
return method(self, *args, **kwargs)
File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1103, in fit
callbacks.on_train_batch_end(end_step, logs)
File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/keras/callbacks.py", line 440, in on_train_batch_end
self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/keras/callbacks.py", line 289, in _call_batch_hook
self._call_batch_end_hook(mode, batch, logs)
File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/keras/callbacks.py", line 309, in _call_batch_end_hook
self._call_batch_hook_helper(hook_name, batch, logs)
File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/keras/callbacks.py", line 345, in _call_batch_hook_helper
numpy_logs = tf_utils.to_numpy_or_python_type(logs)
File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/keras/utils/tf_utils.py", line 537, in to_numpy_or_python_type
return nest.map_structure(_to_single_numpy_or_python_type, tensors)
File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/util/nest.py", line 635, in map_structure
structure[0], [func(*x) for x in entries],
File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/util/nest.py", line 635, in <listcomp>
structure[0], [func(*x) for x in entries],
File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/keras/utils/tf_utils.py", line 533, in _to_single_numpy_or_python_type
x = t.numpy()
File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1063, in numpy
maybe_arr = self._numpy() # pylint: disable=protected-access
File "/home/chris/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1031, in _numpy
six.raise_from(core._status_to_exception(e.code, e.message), None) # pylint: disable=protected-access
File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: 9 root error(s) found.
(0) Invalid argument: {{function_node __inference_train_function_96086}} Compilation failure: Detected unsupported operations when trying to compile graph __inference_cond_true_679310_21187_rewritten[] on XLA_TPU_JIT: OptionalFromValue (No registered 'OptionalFromValue' OpKernel for XLA_TPU_JIT devices compatible with node {{node OptionalFromValue}}){{node OptionalFromValue}}
[[functional_1/perturb_layer/StatefulPartitionedCall/StatefulPartitionedCall/cond_0]]
TPU compilation failed
[[tpu_compile_succeeded_assert/_9040394868092548934/_15]]
[[tpu_compile_succeeded_assert/_9040394868092548934/_15/_425]]
Note: the same root error is repeated 9 times.
Does anyone know what this error means or what is causing it?
I'm using Python 3.6.9 and TF 2.3.0.
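For reference, here is a minimal sketch of an alternative reload path. It rests on assumptions that the question does not confirm. The traceback places the unsupported op (OptionalFromValue) inside a tf.cond in perturb_layer. Because custom_objects lists only the two losses, Keras presumably restores that layer from the functions traced into the SavedModel rather than from its Python class. The sketch instead rebuilds the architecture in code inside the distribution strategy's scope and restores only the weights with load_weights.

The names get_strategy, build_model and reload_for_finetuning are hypothetical. build_model stands in for whatever code originally built the detector. The TPU branch assumes the stable tf.distribute.TPUStrategy API (older 2.x releases expose it as tf.distribute.experimental.TPUStrategy).

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def get_strategy(tpu_name=None):
    """Return a TPUStrategy when a TPU name is given, else the default strategy."""
    if tpu_name is None:
        # Default strategy: runs on the local default device (CPU or GPU).
        return tf.distribute.get_strategy()
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.TPUStrategy(resolver)


def build_model():
    """Hypothetical stand-in for the code that originally built the detector."""
    inputs = tf.keras.Input(shape=(8,))
    x = tf.keras.layers.Dense(16, activation="relu")(inputs)
    outputs = tf.keras.layers.Dense(4)(x)
    return tf.keras.Model(inputs, outputs)


def reload_for_finetuning(weights_path, strategy):
    """Rebuild the model from Python code inside the strategy scope, then restore weights."""
    with strategy.scope():
        model = build_model()
        model.load_weights(weights_path)
        # The real code would pass SSDLoss / SmoothLoss instances here instead of "mse".
        model.compile(optimizer="adam", loss="mse")
    return model


if __name__ == "__main__":
    strategy = get_strategy()  # pass a TPU name here to run on a TPU
    original = build_model()
    # The ".weights.h5" suffix selects the HDF5 weights format.
    path = os.path.join(tempfile.mkdtemp(), "model.weights.h5")
    original.save_weights(path)
    restored = reload_for_finetuning(path, strategy)
    print(all(np.allclose(a, b) for a, b in zip(original.get_weights(), restored.get_weights())))
```

Called with no argument, get_strategy() falls back to the default strategy, so you can check the save/reload round trip on a CPU or GPU machine before running it with a TPU name. Whether this avoids the OptionalFromValue failure depends on how perturb_layer builds its tf.cond, and the question does not show that code.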