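I'm getting an error from early stopping during distributed training with TensorFlow's MultiWorkerMirroredStrategy.
I'm running distributed training on 2 nodes, with an early-stopping criterion passed to the TrainSpec as a hook. With the early-stopping hook attached, every worker fails immediately and at the same time with the error below.
If I remove the early-stopping hook, the code runs to completion.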
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
run_config = tf.estimator.RunConfig(
    model_dir=model_output_dir,
    save_checkpoints_steps=5000,
    keep_checkpoint_max=1,
    train_distribute=strategy)
early_stopping_mae = tf.estimator.experimental.stop_if_no_decrease_hook(
    estimator, metric_name='mae',
    max_steps_without_decrease=20000, min_steps=100)
train_spec = tf.estimator.TrainSpec(
    input_fn=lambda: csv_input_fn(
        train_filepath,
        hparams['batch_size'],
        hparams['num_epochs']),
    hooks=[early_stopping_mae]
)
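To make the hook's parameters concrete (`max_steps_without_decrease=20000`, `min_steps=100`), here is a pure-Python sketch of the stopping rule as I understand it from the `stop_if_no_decrease_hook` documentation. It is a simplified stand-in, not TensorFlow's implementation: the function name `should_stop_no_decrease` and the `eval_results` layout (global step mapped to a dict of eval metrics) are hypothetical.

```python
from typing import Dict, Optional


def should_stop_no_decrease(
    eval_results: Dict[int, Dict[str, float]],
    metric_name: str,
    max_steps_without_decrease: int,
    min_steps: int = 0,
) -> bool:
    """Hypothetical sketch: return True once `metric_name` has not reached a
    new minimum for at least `max_steps_without_decrease` training steps.
    Evaluations recorded before `min_steps` are ignored."""
    best_val: Optional[float] = None
    best_step = 0
    for step in sorted(eval_results):
        if step < min_steps:
            continue  # too early in training to judge progress
        val = eval_results[step][metric_name]
        if best_val is None or val < best_val:
            best_val, best_step = val, step  # new best (lowest) value
        if step - best_step >= max_steps_without_decrease:
            return True  # no decrease for too long: request a stop
    return False
```

With evaluations every 5000 steps and the best MAE reached at step 10000, this sketch would signal a stop at the first evaluation at or after step 30000.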
With early stopping enabled, I get the error below. It is unclear whether MultiWorkerMirroredStrategy supports early stopping at all.
  File "/home/ab981s/anaconda3/envs/py2tensorflow_nightly/lib/python2.7/site-packages/tensorflow/python/ops/collective_ops.py", line 133, in broadcast_send
    instance_key=instance_key)
  File "/home/ab981s/anaconda3/envs/py2tensorflow_nightly/lib/python2.7/site-packages/tensorflow/python/ops/gen_collective_ops.py", line 159, in collective_bcast_send
    shape=shape, name=name)
  File "/home/ab981s/anaconda3/envs/py2tensorflow_nightly/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 626, in _apply_op_helper
    param_name=input_name)
  File "/home/ab981s/anaconda3/envs/py2tensorflow_nightly/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'input' has DataType bool not in list of allowed values: float32, float16, float64, int32, int64
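One plausible reading of this traceback, which is an assumption and not something confirmed above: the early-stopping hook keeps its stop signal in a boolean variable, and MultiWorkerMirroredStrategy tries to broadcast that variable's initial value to the other workers with `collective_bcast_send`, which accepts only the dtypes listed in the TypeError. The pure-Python sketch below mimics that type check and shows the common workaround of carrying the flag as a 0/1 integer. The names `ALLOWED_BCAST_DTYPES`, `check_bcast_dtype`, `encode_stop_flag` and `decode_stop_flag` are hypothetical and are not TensorFlow APIs.

```python
# Dtypes reported as allowed in the TypeError above (copied from the message).
ALLOWED_BCAST_DTYPES = ("float32", "float16", "float64", "int32", "int64")


def check_bcast_dtype(dtype_name: str) -> None:
    """Hypothetical mimic of the dtype check that rejects the bool input."""
    if dtype_name not in ALLOWED_BCAST_DTYPES:
        raise TypeError(
            "Value passed to parameter 'input' has DataType %s not in list of "
            "allowed values: %s" % (dtype_name, ", ".join(ALLOWED_BCAST_DTYPES)))


def encode_stop_flag(stop: bool) -> int:
    """Carry a boolean stop signal as an int32-compatible 0/1 value."""
    return 1 if stop else 0


def decode_stop_flag(value: int) -> bool:
    """Turn the broadcast 0/1 value back into a boolean."""
    return value != 0


# A bool stop variable fails the check; its integer encoding passes.
try:
    check_bcast_dtype("bool")
except TypeError as err:
    print(err)
check_bcast_dtype("int32")
```

In TensorFlow terms this would correspond to holding the stop signal in an integer variable (for example `tf.int32`) instead of a `tf.bool` one; whether that can be done without modifying the hook itself is left open here.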