tf.control依赖项无法按预期运行

时间:2019-01-11 18:20:22

标签: python-2.7 tensorflow

我有一个代码段,可用于在Tensorflow中创建高斯热图。但是,热图的数量是动态的,并且取决于运行时的某些值。问题是我在同一张地图上看到多个高斯,这意味着未执行Assign op OP。有人可以告诉我为什么吗?

        def cond_2(num_rois_jt_ind, tf_hm):
            # gts_x are of shape 17
            return tf.less(num_rois_jt_ind, tf.shape(scores)[0] * self.params["num_masks"])

        def body_2(num_rois_jt_ind, tf_hm):

            def gen_gaussian(num_rois, jt_ind):
                local_hm = tf.Variable(lambda: tf.zeros([1, self.params['output_res'], self.params['output_res']]))
                op = local_hm.assign(tf.zeros([1, self.params['output_res'], self.params['output_res']]))
                with tf.control_dependencies([op]):
                    local_hm_ass = tf.cond(tf.logical_and(tf.logical_and(tf.greater_equal(gt_x[num_rois, jt_ind], 0.), tf.less_equal(gt_x[num_rois, jt_ind], self.params['output_res']-1)),
                    tf.logical_and(tf.greater_equal(gt_y[num_rois, jt_ind], 0.), tf.less_equal(gt_y[num_rois, jt_ind], self.params['output_res']-1)))  ,lambda : local_hm[0, tf_aa[num_rois, jt_ind]:tf_bb[num_rois, jt_ind],
                               tf_cc[num_rois, jt_ind]:tf_dd[num_rois, jt_ind]].assign(
                        tf_g[tf_a[num_rois, jt_ind]:tf_b[num_rois, jt_ind],
                        tf_c[num_rois, jt_ind]:tf_d[num_rois, jt_ind]]), lambda : local_hm )
                    local_tf_hm = tf.identity(local_hm_ass)
                    return local_tf_hm

            jt_ind = tf.mod(num_rois_jt_ind, self.params["num_masks"])
            num_rois = tf.floordiv(num_rois_jt_ind, self.params["num_masks"])

            with tf.variable_scope('local_hm', reuse=False):
                hm_t = gen_gaussian(num_rois,
                                        jt_ind)  # tf.cond(to_use[num_rois, jt_ind], lambda: gen_gaussian(num_rois, jt_ind),  no_gauss)
                tf_hm = tf.concat([tf_hm, hm_t], axis=0)
                return tf.add(num_rois_jt_ind, 1), tf_hm

        rois, tf_hm = tf.while_loop(cond_2, body_2, [tf.Variable(0, dtype=tf.int32), tf_hm],
                                    shape_invariants=[tf.Variable(0, dtype=tf.int32).get_shape(), tf.TensorShape(
                                        [None, self.params['output_res'], self.params['output_res']])])

0 个答案:

没有答案