Problem with TensorFlow distributed training (early stopping)

Time: 2019-09-11 23:11:18

Tags: tensorflow distributed tensorflow-estimator

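During distributed training with TensorFlow's MultiWorkerMirroredStrategy, I get an error when early stopping is enabled.

I am running distributed training on 2 nodes, with an early stopping criterion passed to the TrainSpec. With the early stopping hook in place, every worker node fails immediately, all at the same time, with the error shown below.

If I remove the early stopping hook, the code runs to completion.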

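import tensorflow as tf

# Synchronous data-parallel training across both worker nodes.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

run_config = tf.estimator.RunConfig(
                    model_dir=model_output_dir,
                    save_checkpoints_steps=5000,
                    keep_checkpoint_max=1,
                    train_distribute=strategy)

# `estimator` is the tf.estimator.Estimator built with run_config (construction not shown).
# Request a stop when the eval metric 'mae' has not decreased for 20000 steps;
# never stop before global step 100.
early_stopping_mae = tf.estimator.experimental.stop_if_no_decrease_hook(
                estimator, metric_name='mae',
                max_steps_without_decrease=20000, min_steps=100)

train_spec = tf.estimator.TrainSpec(
                input_fn=lambda: csv_input_fn(
                    train_filepath,
                    hparams['batch_size'],
                    hparams['num_epochs']),
                hooks=[early_stopping_mae]
                )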
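For readers checking the hook's parameters, here is a plain-Python sketch of the stopping rule that `max_steps_without_decrease` and `min_steps` describe in the `stop_if_no_decrease_hook` documentation. It is an illustration under stated assumptions, not TensorFlow's code: the real hook reads evaluation results from the estimator's eval directory, whereas the hypothetical helper `should_stop_no_decrease` here takes a plain dict mapping global step to metric value.

```python
def should_stop_no_decrease(eval_results, max_steps_without_decrease, min_steps=0):
    """Hypothetical sketch of a "stop if no decrease" rule.

    eval_results: dict mapping global step -> evaluated metric value (e.g. 'mae').
    Returns True once the best (lowest) value is at least
    max_steps_without_decrease steps old. Evaluations at steps below
    min_steps are ignored entirely.
    """
    best_value = None
    best_step = None
    for step in sorted(eval_results):
        if step < min_steps:
            continue  # too early: these evaluations never count
        value = eval_results[step]
        if best_value is None or value < best_value:
            # New best (lower is better for a "decrease" metric like MAE).
            best_value, best_step = value, step
        if step - best_step >= max_steps_without_decrease:
            return True
    return False
```

With the question's settings (20000 steps, `min_steps=100`), evaluations of 0.9, 0.8, 0.85, 0.82, 0.81 at steps 5000, 10000, 15000, 25000, 30000 would request a stop at step 30000, because the best value (0.8 at step 10000) is then 20000 steps old.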

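With early stopping enabled, I get the following error. It is not clear to me whether early stopping is supported with the multi-worker strategy.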

  File "/home/ab981s/anaconda3/envs/py2tensorflow_nightly/lib/python2.7/site-packages/tensorflow/python/ops/collective_ops.py", line 133, in broadcast_send
    instance_key=instance_key)
  File "/home/ab981s/anaconda3/envs/py2tensorflow_nightly/lib/python2.7/site-packages/tensorflow/python/ops/gen_collective_ops.py", line 159, in collective_bcast_send
    shape=shape, name=name)
  File "/home/ab981s/anaconda3/envs/py2tensorflow_nightly/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 626, in _apply_op_helper
    param_name=input_name)
  File "/home/ab981s/anaconda3/envs/py2tensorflow_nightly/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'input' has DataType bool not in list of allowed values: float32, float16, float64, int32, int64
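The `TypeError` is raised while building a collective broadcast op (`collective_bcast_send`), which, according to the message, accepts only float32, float16, float64, int32 and int64 tensors, yet received a `bool` tensor. The sketch below only models that dtype check in plain Python using hypothetical helpers, and shows the general idea of carrying a boolean flag as a 0/1 integer instead. It does not patch TensorFlow and makes no claim about which tensor the early stopping hook actually creates.

```python
# Allowed dtypes, copied from the TypeError message in the traceback above.
ALLOWED_BROADCAST_DTYPES = ("float32", "float16", "float64", "int32", "int64")


def check_broadcast_dtype(dtype_name):
    """Hypothetical stand-in for the dtype check that rejected the bool tensor."""
    if dtype_name not in ALLOWED_BROADCAST_DTYPES:
        raise TypeError(
            "Value passed to parameter 'input' has DataType %s not in list of "
            "allowed values: %s" % (dtype_name, ", ".join(ALLOWED_BROADCAST_DTYPES)))


def encode_flag(flag):
    """Represent a boolean flag as a 0/1 integer, which would pass the check as int32."""
    return 1 if flag else 0


def decode_flag(value):
    """Recover the boolean flag from the integer value after it has been broadcast."""
    return value != 0
```

Here `check_broadcast_dtype("bool")` raises the same kind of `TypeError` as in the traceback, while `check_broadcast_dtype("int32")` passes.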
