减去Tensorflow张量中的列

时间:2017-01-16 22:48:38

标签: tensorflow

在Tensorflow中,我想从另一列中减去2D张量中的一列。我查看了使用tf.split()tf.slice()分成两个张量然后减去,但这似乎不必要地复杂。我目前的方法是将一列乘以-1,然后将reduce_sum列乘以:

input = tf.constant(
        [[5.8, 3.0],
         [4.0, 6.0],
         [7.0, 9.0]])
oneMinusOne = tf.constant([1., -1.])
temp = tf.mul(input, oneMinusOne)
delta = tf.reduce_sum(temp, 1)

似乎仍然不必要地复杂。有没有更简单的方法呢?

1 个答案:

答案 0 :(得分:1)

许多numpy的数组索引在TensorFlow中按预期工作。以下作品:

input = tf.constant(
    [[5.8, 3.0],
     [4.0, 6.0],
     [7.0, 9.0]])
sess = tf.InteractiveSession()
ans = input[:, :1] - input[:, 1:]
print(ans.eval())

array([[ 2.80000019],
   [-2.        ],
   [-2.        ]], dtype=float32)