让我们从一个例子开始。我有一个[3, 3, 3]
这样的张量
input = tf.constant([[[1.2, -2.3, 4.5], [1.03, 2.89, -1.2], [2.1, 1.9, -1.5]],
[[1.3, -2.4, 4.6], [1.04, 2.88, -1.3], [2.2, 1.8, -1.6]],
[[1.4, -2.5, 4.7], [1.05, 2.87, -1.4], [2.3, 1.7, -1.7]]],
dtype=tf.float32)
例如,我只想将tf.square
应用于低于2
的元素。
我正在做的是这个
indices = tf.where(input <= 2)
base = tf.zeros_like(input)
ones = tf.constant(1, shape=[18])
mask = tf.scatter_nd(indices, ones, tf.cast(tf.shape(base), indices.dtype))
mask = tf.cast(mask, tf.float32)
masked = tf.square(tf.multiply(input, mask))
neg_mask = 1 - mask
neg_masked = tf.multiply(input, neg_mask)
output = tf.add(masked, neg_masked)
,并且有效。最终的输出是这个
output = tf.constant([[[1.44, 5.29, 4.5], [1.0609, 2.89, 1.44], [2.1, 3.61, 2.25]],
[[1.69, 5.76, 4.6], [1.0816, 2.88, 1.69], [2.2, 3.24, 2.56]],
[[1.96, 6.25, 4.7], [1.1025, 2.87, 1.96], [2.3, 2.89, 2.89]]],
dtype=tf.float32)
问题在于,这很棘手,因为这是一个玩具示例,但在我的情况下,张量具有数千个元素的形状。而且,正如您在此行ones = tf.constant(1, shape=[18])
中看到的那样,我已经对18
进行了编码,因为我知道它们为18,并且如果尝试使用例如ones = tf.constant(1, shape=[indices.get_shape()[0]])
,则会收到此错误TypeError: long returned non-long (type NoneType)
所以我的两个问题是:
预先感谢
修改
问题1解决了此问题ones = tf.ones(shape=[tf.shape(indices)[0]])
答案 0 :(得分:1)
最简单的方法是对所有值求平方,然后选择正确的值:
import tensorflow as tf
input = tf.constant([[[1.2, -2.3, 4.5], [1.03, 2.89, -1.2], [2.1, 1.9, -1.5]],
[[1.3, -2.4, 4.6], [1.04, 2.88, -1.3], [2.2, 1.8, -1.6]],
[[1.4, -2.5, 4.7], [1.05, 2.87, -1.4], [2.3, 1.7, -1.7]]],
dtype=tf.float32)
output = tf.where(input < 2, tf.square(input), input)
with tf.Session() as sess:
print(sess.run(output))
输出:
[[[1.44 5.29 4.5 ]
[1.0609 2.89 1.44 ]
[2.1 3.61 2.25 ]]
[[1.6899998 5.76 4.6 ]
[1.0816 2.88 1.6899998]
[2.2 3.2399998 2.5600002]]
[[1.9599999 6.25 4.7 ]
[1.1024998 2.87 1.9599999]
[2.3 2.89 2.89 ]]]
如果您有一个非常大的张量,其中只有几个值将被平方,那么您可以考虑仅对必要的值进行平方。但是我不确定这实际上会更快,因为它需要一些额外的工作和中间值,但是您可以对其进行基准测试。如果您要执行的操作特别昂贵,而不仅仅是平方,那么我想这可能会有所作为。这与您所做的大致相同,但更加简单:
import tensorflow as tf
input = tf.constant([[[1.2, -2.3, 4.5], [1.03, 2.89, -1.2], [2.1, 1.9, -1.5]],
[[1.3, -2.4, 4.6], [1.04, 2.88, -1.3], [2.2, 1.8, -1.6]],
[[1.4, -2.5, 4.7], [1.05, 2.87, -1.4], [2.3, 1.7, -1.7]]],
dtype=tf.float32)
m = input < 2
v = tf.boolean_mask(input, m)
v2 = tf.square(v)
v2_scatter = tf.scatter_nd(tf.where(m), v2, tf.cast(tf.shape(input), tf.int64))
output = input * tf.cast(~m, input.dtype) + v2_scatter
with tf.Session() as sess:
print(sess.run(output))
# Output is the same as before
答案 1 :(得分:1)
这是一个更容易实现但不一定更快的实现方式:
input = tf.constant([[[1.2, -2.3, 4.5], [1.03, 2.89, -1.2], [2.1, 1.9, -1.5]],
[[1.3, -2.4, 4.6], [1.04, 2.88, -1.3], [2.2, 1.8, -1.6]],
[[1.4, -2.5, 4.7], [1.05, 2.87, -1.4], [2.3, 1.7, -1.7]]],
dtype=tf.float32)
original_shape = input.get_shape()
input = tf.reshape(input, shape=[-1])
output = tf.map_fn(lambda e:tf.cond(e < 2, lambda:tf.square(e), lambda:e), input)
output = tf.reshape(output, shape=original_shape)
with tf.Session() as sess:
print(sess.run(output))
输出:
[[[1.44 5.29 4.5 ]
[1.0609 2.89 1.44 ]
[2.1 3.61 2.25 ]]
[[1.6899998 5.76 4.6 ]
[1.0816 2.88 1.6899998]
[2.2 3.2399998 2.5600002]]
[[1.9599999 6.25 4.7 ]
[1.1024998 2.87 1.9599999]
[2.3 2.89 2.89 ]]]
tf.map_fn(fn, elems, ...)
函数将N维输入elems
沿第一维解压缩为多个N-1维子张量,并将fn
应用于每个子张量。因此,我将输入重塑为一维张量,对每个元素应用函数,然后将输出重塑为原始形状。