取一个在另一个张量内的张量元素

时间:2018-10-17 08:13:22

标签: python tensorflow

我有两个张量,并且我必须迭代第一个张量,以仅采用另一个张量内的元素。 t2中只有一个元素也位于t1内部。这是一个例子

t1 = tf.where(values > 0) # I get some indices example [6, 0], [3, 0]
t2 = tf.where(values2 > 0) # I get [4, 0], [3, 0]

t3 = .... # [3, 0]

我尝试使用.eval()进行评估和迭代,并使用运算符t2检查t1的元素是否在in中,但没有工作。 TensorFlow有功能可以做到这一点吗?

修改

for index in xrange(max_indices):
    indices = tf.where(tf.equal(values, (index + 1))).eval() # indices: [[1 0]\n [4 0]\n [9 0]]
    cent_indices = tf.where(centers > 0).eval() # cent_indices: [[6 0]\n [9 0]]
    indices_list.append(indices)
    for cent in cent_indices:
        if cent in indices:
           centers_list.append(cent)
           break

第一次迭代cent的值为[6 0],但它进入了if条件。

answer

for index in xrange(max_indices):
    indices = tf.where(tf.equal(values, (index + 1))).eval()
    cent_indices = tf.where(centers > 0).eval()
    indices_list.append(indices)
    for cent in cent_indices:
        # batch_item is an iterator from an outer loop
        if values[batch_item, cent[0]].eval() == (index + 1):
           centers_list.append(tf.constant(cent))
           break

解决方案与我的任务有关,但是如果您正在寻找一维张量的解决方案,我建议您看看tf.sets.set_intersection

1 个答案:

答案 0 :(得分:1)

这就是您想要的吗?我只用了这两个测试用例。

x = tf.constant([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 1]])
y = tf.constant([[1, 2, 3, 4, 3, 6], [1, 2, 3, 4, 5, 1]])
# x = tf.constant([[1, 2], [4, 5], [7, 7]])
# y = tf.constant([[7, 7], [3, 5]])

def match(xiterations, yiterations, yvalues, xvalues ):
    for i in range(xiterations):
        for j in range(yiterations):
            if (np.array_equal(yvalues[j], xvalues[i])):
                print( yvalues[j])

with tf.Session() as sess:
    xindex = tf.where( x > 4 )
    yindex = tf.where( y > 4 )

    xvalues = xindex.eval()
    yvalues = yindex.eval()

    xiterations =  tf.shape(xvalues)[0].eval()
    yiterations =  tf.shape(yvalues)[0].eval()

    print(tf.shape(xvalues)[0].eval())
    print(tf.shape(yvalues)[0].eval())

    if tf.shape(xvalues)[0].eval() >= tf.shape(yvalues)[0].eval():
        match( xiterations, yiterations, yvalues, xvalues)
    else:
        match( yiterations, xiterations, xvalues, yvalues)