Iterating over a tensor in TensorFlow

Time: 2019-08-19 10:34:07

Tags: python tensorflow

I am using TensorFlow 1.13, but I get an error saying that I cannot iterate over a tensor unless I am in eager mode. Is there a way to enter eager mode?

with tf.Session(config=config) as sess:
    context = tf.placeholder(tf.int32, [args.batch_size, None])
    mask = tf.placeholder(tf.int32, [args.batch_size, 2])
    output = model.model(hparams=hparams, X=context)



    for batch_index in range(args.batch_size):
        start = mask[batch_index][0]
        end   = mask[batch_index][1]

        for i in range(start, end+1):
            output['logits'][batch_index, i , context[batch_index,i]].assign(math.inf)

    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:],  logits=output['logits'][:, :-1]))

1 answer:

Answer 0 (score: 0)

Could you try using tf.while_loop? Try the following snippet (it may need a few small modifications) and see whether it works.

import math

import tensorflow as tf
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    context = tf.placeholder(tf.int32, [args.batch_size, None])
    mask = tf.placeholder(tf.int32, [args.batch_size, 2])
    output = model.model(hparams=hparams, X=context)



    for batch_index in [0, 1, 2, 3]:  # dummy list of batch indices, because a 'Dimension' cannot be iterated directly
        start = mask[batch_index][0]
        end   = mask[batch_index][1]

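        # Start the counter at `start` so the loop visits start..end inclusive,
        # matching range(start, end+1) in the question.
        i = start
        while_condition = lambda i: tf.less_equal(i, end)

        def body(i):
            # Note: sliced .assign() is only available on a tf.Variable; if
            # output['logits'] is a plain Tensor, this line will fail.
            update = output['logits'][batch_index, i, context[batch_index, i]].assign(math.inf)
            # The body must return the loop variable, so run the update first
            # and then return the incremented counter.
            with tf.control_dependencies([update]):
                return [tf.add(i, 1)]

        r = tf.while_loop(while_condition, body, [i])

        # for i in range(start, end+1):
        #     output['logits'][batch_index, i , context[batch_index,i]].assign(math.inf)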

    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:],  logits=output['logits'][:, :-1]))
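
If the loop turns out to be awkward in graph mode, one possible alternative is to avoid iteration altogether and build a boolean mask instead. The sketch below is not from the original answer. It assumes that the logits have shape [batch, time, vocab], that each row of mask holds an inclusive [start, end] pair, and that the goal is the same as in the question (overwrite the logit of the token in context at every position inside that span). The function name mask_target_logits and the parameter fill_value are hypothetical names chosen for illustration.

```python
import tensorflow as tf


def mask_target_logits(logits, context, mask, fill_value=float("inf")):
    """Return a copy of `logits` with selected entries replaced by `fill_value`.

    Assumed shapes (not confirmed by the original post):
      logits:  [batch, time, vocab] float32
      context: [batch, time] int32 token ids
      mask:    [batch, 2] int32, inclusive (start, end) positions per row
    """
    seq_len = tf.shape(context)[1]
    vocab_size = tf.shape(logits)[2]

    # positions has shape [1, time]; start and end have shape [batch, 1],
    # so the comparisons broadcast to [batch, time].
    positions = tf.range(seq_len)[tf.newaxis, :]
    start = mask[:, 0:1]
    end = mask[:, 1:2]
    in_span = tf.logical_and(positions >= start, positions <= end)

    # True only at the vocab index given by context, shape [batch, time, vocab].
    is_target = tf.one_hot(context, depth=vocab_size,
                           on_value=True, off_value=False, dtype=tf.bool)

    # Combine both conditions; in_span is expanded to [batch, time, 1].
    hit = tf.logical_and(in_span[:, :, tf.newaxis], is_target)

    # All three arguments have identical shapes, so no broadcasting
    # is required inside tf.where.
    filler = tf.fill(tf.shape(logits), fill_value)
    return tf.where(hit, filler, logits)
```

Under those assumptions, the result could be passed as the logits argument of the loss in place of output['logits'], for example masked = mask_target_logits(output['logits'], context, mask). Because only ordinary tensor ops are used, no tf.Variable and no in-place assign is needed.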