如何在Tensorflow中将不同的变量传递到while循环的主体和条件中?

时间:2018-08-13 13:50:20

标签: python tensorflow

我尝试在 Tensorflow 中使用 while循环

代码为:

import tensorflow as tf

sess=tf.Session()


rois_boxes = tf.concat([tf.ones([12,5]),tf.zeros([12,5]) ], axis=0)

img_ids = tf.unique(rois_boxes[:,0])
img_ids = tf.cast(img_ids[0], tf.int32)



regions_features=tf.constant(55, dtype=tf.int32)

def body(regions_features, img_ids):
        regions_features = img_ids[0]
        img_ids = img_ids[1:]
        return regions_features


def condition(regions_features, img_ids):
        return tf.not_equal(tf.size(img_ids), 0)


x = tf.Variable(tf.constant(0, shape=[2, 2]))

regions_features = tf.while_loop(condition, body, [regions_features, img_ids])

此代码返回此错误::

  

回溯(最近一次通话最后一次):文件“”,第1行,在      文件   “ /home/ashwaq/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py”,   第2775行,在while_loop中       结果= context.BuildLoop(cond,body,loop_vars,shape_invariants)文件   “ /home/ashwaq/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py”,   BuildLoop中的第2604行       pred,body,original_loop_vars,loop_vars,shape_invariants)文件   “ /home/ashwaq/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py”,   _BuildLoop中的2561行       nest.assert_same_structure(list(packed_vars_for_body),list(body_result))文件   “ /home/ashwaq/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/util/nest.py”,   第199行,在assert_same_structure中       %(len_nest1,nest1,len_nest2,nest2))   ValueError:两个结构没有相同数量的元素。

     

第一个结构(2个元素):[<tf.Tensor 'while/Identity:0' shape=() dtype=int32>, <tf.Tensor 'while/Identity_1:0' shape=(?,) dtype=int32>]

     

第二个结构(1个元素):[<tf.Tensor 'while/strided_slice_1:0' shape=() dtype=int32>]

为什么会发生此问题?以及如何将不同的变量传递到while循环的主体和条件中而没有任何问题?

1 个答案:

答案 0 :(得分:0)

这对我有用,发现2 ^ 5。

i = tf.get_variable('i', initializer=tf.constant(1))
pow_2 = tf.get_variable('X', initializer=tf.constant(2))
def cond(tensor):
    return tensor[0] < 5
def body(tensor):
    return tf.stack([tensor[0] + 1, tensor[1] * 2])

T = tf.while_loop(cond, body, loop_vars=[tf.stack([i, pow_2])])