SyncReplicasOptimizer挂钩的MonitoredTrainingSession无法使用占位符初始化

时间:2018-07-11 14:14:10

标签: python tensorflow

  1. 我将tf.keras.Input作为输入层来构建我的网络。

      

    input_image = tf.keras.Input(shape =(None,None,3),name ='input_image')

  2. 按照网络定义,我这样定义优化器: enter image description here

  3. 然后,我将钩子传递给MonitoredTrainingSession enter image description here

  4. 最后,在运行以创建MonitoredTrainingSession时,会出现占位符错误:

      

    tensorflow.python.framework.errors_impl.InvalidArgumentError:您必须使用dtype float和形状[?,?,?,3]输入占位符张量'input_image'的值   由ready_value = sess.run(op)调用的session_manager.py中的hook.after_create_session(self.tf_sess, self.coord)引发了此异常

关于带有占位符的SyncReplicasOptimizer的任何想法都令人赞赏。

1 个答案:

答案 0 :(得分:0)

最终找到了引起问题的位置: SyncReplicaOptimizer和批处理规范化更新之间的代码中存在冲突。

with tf.control_dependencies(update_ops):
    self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step)

此样式更新将引发异常,但在分离控件依赖项之后,改为同时运行train_op和update_ops可以避免此问题。

但是无论如何,我还没有找到解决问题的见解原因或下降的方式。