将函数(tf.square())应用于Tensor的某些值-TensorFlow

时间:2018-11-06 10:26:33

标签: python tensorflow

让我们从一个例子开始。我有一个[3, 3, 3]这样的张量

input = tf.constant([[[1.2, -2.3, 4.5], [1.03, 2.89, -1.2], [2.1, 1.9, -1.5]],
                     [[1.3, -2.4, 4.6], [1.04, 2.88, -1.3], [2.2, 1.8, -1.6]],
                     [[1.4, -2.5, 4.7], [1.05, 2.87, -1.4], [2.3, 1.7, -1.7]]],
                    dtype=tf.float32)

例如,我只想将tf.square应用于低于2的元素。 我正在做的是这个

indices = tf.where(input <= 2)
base = tf.zeros_like(input)
ones = tf.constant(1, shape=[18])
mask = tf.scatter_nd(indices, ones, tf.cast(tf.shape(base), indices.dtype))
mask = tf.cast(mask, tf.float32)
masked = tf.square(tf.multiply(input, mask))
neg_mask = 1 - mask
neg_masked = tf.multiply(input, neg_mask)
output = tf.add(masked, neg_masked)

,并且有效。最终的输出是这个

output = tf.constant([[[1.44, 5.29, 4.5], [1.0609, 2.89, 1.44], [2.1, 3.61, 2.25]],
                      [[1.69, 5.76, 4.6], [1.0816, 2.88, 1.69], [2.2, 3.24, 2.56]],
                      [[1.96, 6.25, 4.7], [1.1025, 2.87, 1.96], [2.3, 2.89, 2.89]]],
                     dtype=tf.float32)

问题在于,这很棘手,因为这是一个玩具示例,但在我的情况下,张量具有数千个元素的形状。而且,正如您在此行ones = tf.constant(1, shape=[18])中看到的那样,我已经对18进行了编码,因为我知道它们为18,并且如果尝试使用例如ones = tf.constant(1, shape=[indices.get_shape()[0]]),则会收到此错误TypeError: long returned non-long (type NoneType)

所以我的两个问题是:

  1. 我该如何解决形状问题?
  2. 是否有更快,更有效的方法来做到这一点?

预先感谢


修改

问题1解决了此问题ones = tf.ones(shape=[tf.shape(indices)[0]])

2 个答案:

答案 0 :(得分:1)

最简单的方法是对所有值求平方,然后选择正确的值:

import tensorflow as tf

input = tf.constant([[[1.2, -2.3, 4.5], [1.03, 2.89, -1.2], [2.1, 1.9, -1.5]],
                     [[1.3, -2.4, 4.6], [1.04, 2.88, -1.3], [2.2, 1.8, -1.6]],
                     [[1.4, -2.5, 4.7], [1.05, 2.87, -1.4], [2.3, 1.7, -1.7]]],
                    dtype=tf.float32)
output = tf.where(input < 2, tf.square(input), input)
with tf.Session() as sess:
    print(sess.run(output))

输出:

[[[1.44      5.29      4.5      ]
  [1.0609    2.89      1.44     ]
  [2.1       3.61      2.25     ]]

 [[1.6899998 5.76      4.6      ]
  [1.0816    2.88      1.6899998]
  [2.2       3.2399998 2.5600002]]

 [[1.9599999 6.25      4.7      ]
  [1.1024998 2.87      1.9599999]
  [2.3       2.89      2.89     ]]]

如果您有一个非常大的张量,其中只有几个值将被平方,那么您可以考虑仅对必要的值进行平方。但是我不确定这实际上会更快,因为它需要一些额外的工作和中间值,但是您可以对其进行基准测试。如果您要执行的操作特别昂贵,而不仅仅是平方,那么我想这可能会有所作为。这与您所做的大致相同,但更加简单:

import tensorflow as tf

input = tf.constant([[[1.2, -2.3, 4.5], [1.03, 2.89, -1.2], [2.1, 1.9, -1.5]],
                     [[1.3, -2.4, 4.6], [1.04, 2.88, -1.3], [2.2, 1.8, -1.6]],
                     [[1.4, -2.5, 4.7], [1.05, 2.87, -1.4], [2.3, 1.7, -1.7]]],
                    dtype=tf.float32)
m = input < 2
v = tf.boolean_mask(input, m)
v2 = tf.square(v)
v2_scatter = tf.scatter_nd(tf.where(m), v2, tf.cast(tf.shape(input), tf.int64))
output = input * tf.cast(~m, input.dtype) + v2_scatter
with tf.Session() as sess:
    print(sess.run(output))
    # Output is the same as before

答案 1 :(得分:1)

这是一个更容易实现但不一定更快的实现方式:

input = tf.constant([[[1.2, -2.3, 4.5], [1.03, 2.89, -1.2], [2.1, 1.9, -1.5]],
                     [[1.3, -2.4, 4.6], [1.04, 2.88, -1.3], [2.2, 1.8, -1.6]],
                     [[1.4, -2.5, 4.7], [1.05, 2.87, -1.4], [2.3, 1.7, -1.7]]],
                    dtype=tf.float32)

original_shape = input.get_shape()
input = tf.reshape(input, shape=[-1])
output = tf.map_fn(lambda e:tf.cond(e < 2, lambda:tf.square(e), lambda:e), input)
output = tf.reshape(output, shape=original_shape)

with tf.Session() as sess:
    print(sess.run(output))

输出:

[[[1.44      5.29      4.5      ]
  [1.0609    2.89      1.44     ]
  [2.1       3.61      2.25     ]]

 [[1.6899998 5.76      4.6      ]
  [1.0816    2.88      1.6899998]
  [2.2       3.2399998 2.5600002]]

 [[1.9599999 6.25      4.7      ]
  [1.1024998 2.87      1.9599999]
  [2.3       2.89      2.89     ]]]

tf.map_fn(fn, elems, ...)函数将N维输入elems沿第一维解压缩为多个N-1维子张量,并将fn应用于每个子张量。因此,我将输入重塑为一维张量,对每个元素应用函数,然后将输出重塑为原始形状。