tensorflow在python中嵌套while循环

时间:2017-08-14 11:16:48

标签: python tensorflow while-loop

def body_4(i, indices_set):
    c_j = lambda j, indices_set_j : tf.less(j, len_src_sent)
    j = tf.constant(0)
    indices_set_j = indices_set
    def body_j(j, indices_set_j):
        align_middle_ = ALIGNMENT_SIZE / 2
        align_start_ = 0 - j
        align_end_ = len_src_sent - j
        c_k = lambda k, indices_set_k: tf.less(k, len_src_sent)
        k = tf.constant(1)
        indices_set_k = indices_set_j
        def body_k(k, indices_set_k):
            indices_set_k = tf.concat([indices_set_k, tf.stack([tf.cast([i*len_src_sent+j, align_middle_ + align_start_ + k], tf.int64)])], 0)
            k = tf.add(k, 1)
            return k, indices_set_k
        [index, indices_set_k] = tf.while_loop(c_k, body_k, loop_vars=[k, indices_set_k], shape_invariants=[k.get_shape(), tf.TensorShape([None, None])])
        j = tf.add(j,1)
        indices_set_j = indices_set_k
        return j, indices_set_j
    [index, indices_set] = tf.while_loop(c_j, body_j, loop_vars=[j, indices_set_j], shape_invariants=[j.get_shape(), tf.TensorShape([None, None])])
    i = tf.add(i, 1)
    return i, indices_set

c = lambda i, indices_set: tf.less(i, len_trg_sent-1)
i = tf.constant(0)
indices_set = tf.cast([0, ALIGNMENT_SIZE/2], tf.int64)
indices_set = tf.stack([indices_set])
[index, indices_set] = tf.while_loop(c, body_4, loop_vars=[i, indices_set], shape_invariants=[i.get_shape(), tf.TensorShape([None, None])])

我想创建一个张量流图来输出一些indices_set供以后使用sparsetensor。 indice元素应该看起来像[i * len_trg_sent + j,align_middle_ + align_start_ + k]其中i是第一轴的索引,j第二轴和具有形状的张量的第三个(len_trg_sent-1,len_src_sent,len_src_sent)

但上面的代码似乎只是一个死循环。我对tensorflow中的嵌套while循环感到困惑,因此如果有人能帮助我,我会很感激。

0 个答案:

没有答案