如何得到真实的'真实'基于轴的Tensorflow中的值?

时间:2017-08-24 03:03:11

标签: tensorflow

我试图提取' True'基于特定轴的张量流量张量。

预期值

  

[[真,假,真,假],

     

[假,假,真,假]]

  

[[0,2],

     

[2]]

我发现了

'tf.where'

与我的预期类似,但

当我使用这个功能时,结果是

  

[[0,0],          [0,2],          [1,2]],

有没有办法获得' True'的指数?根据具体轴的值?

1 个答案:

答案 0 :(得分:2)

这应该可以为您提供所需的内容 -

import tensorflow as tf
import numpy as np



inputs = tf.placeholder(dtype=tf.bool, shape=(2,4)) 
time_steps = tf.shape(inputs)[0]
initial_outputs = tf.TensorArray(dtype=tf.int32, size=time_steps)
initial_t = tf.placeholder(dtype='int32')

def cond(t, *args):
    return t < time_steps

def body(t, outputs_):
    sub = tf.gather(inputs, t)
    cur = tf.squeeze(tf.cast(tf.where(sub), tf.int32))
    outputs_ = outputs_.write(t, cur)
    return t + 1, outputs_

t, outputs = tf.while_loop(cond, body,[initial_t, initial_outputs])

outputs = outputs.stack()
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run([init])
    print(outputs)
    print(sess.run([outputs], feed_dict={inputs: np.asarray([[True,False,True,False],[True, True, False, False]]), initial_t:0}))

此编辑的灵感来自chrert对我的答案的先前版本的评论。该解决方案实际上并不需要指定轴,我已经修改它以反映这一点(但它确实取轴= 0,这可能被认为是限制性的)。他是对的,tf.while_loop和TensorArrays可以让你在不知道它的形状的情况下循环任何给定形状的张量;并且很高兴知道如何使用动态形状的张量!但是,对于bj1123的特定用例,在尝试堆叠结果时可能会失败。这是因为每个行或切片都可以(并且很可能会!!)具有不同的True和False值计数。这将引发错误“InvalidArgumentError(参见上面的回溯):TensorArray具有不一致的形状。”量化我想说的话 -

np.asarray([[True,True,True,False],[True, True, False, False]])

上面的代码运行正常,现在尝试用

替换占位符
{{1}}

现在您可以看到错误。似乎没有任何直接的方法将不规则形状的张量堆叠成单个张量。唯一的方法是在我的第一个版本中显示张量列表。 我也编辑了