从Tensorflow

时间:2017-11-23 03:40:12

标签: python tensorflow

我正在寻找一种简单的方法来从Tensorflow中的当前张量中删除一组张量,并且我有一个难以合理的解决方案。

例如,假设我有以下当前张量:

a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')

我想要删除这个张量中的两个项目(2.0和5.0)。

创建后将张量转换为[1.0,3.0,4.0,6.0]的最佳方法是什么?

非常感谢提前。

2 个答案:

答案 0 :(得分:2)

您可以致电tf.unstack以获取子张量的列表。然后,您可以修改列表并调用tf.stack从列表中构造张量。例如,以下代码从:

中删除[2.0,5.0]列
a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
a_vecs = tf.unstack(a, axis=1)
del a_vecs[1]
a_new = tf.stack(a_vecs, 1)

答案 1 :(得分:1)

另一种方法是使用分割或切片功能。如果张量很大,这将非常有用。

方法1:使用split功能。

a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
split1, split2, split3 = tf.split(a, [1, 1, 1], 1)
a_new = tf.concat([split1, split3], 1)

方法2 :使用slice函数。

slice1 = tf.slice(a, [0, 0], [2, 1])
slice2 = tf.slice(a, [0, 2], [2, 1])
a_new = tf.concat([slice1, slice2], 1)

在这两种情况下,a_new都有

[[ 1.  3.]
 [ 4.  6.]]