我有一个形状为A
的2D张量[batch_size, D]
和一个形状为B
的1D张量[batch_size]
。对于B
的每一行,A
的每个元素都是A
的列索引,例如。 B[i] in [0,D)
。
tensorflow获取值A[B]
例如:
A = tf.constant([[0,1,2],
[3,4,5]])
B = tf.constant([2,1])
具有所需的输出:
some_slice_func(A, B) -> [2,4]
还有另一个限制因素。实际上,batch_size
实际上是None
。
提前致谢!
答案 0 :(得分:3)
我能够使用线性索引使其工作:
def vector_slice(A, B):
""" Returns values of rows i of A at column B[i]
where A is a 2D Tensor with shape [None, D]
and B is a 1D Tensor with shape [None]
with type int32 elements in [0,D)
Example:
A =[[1,2], B = [0,1], vector_slice(A,B) -> [1,4]
[3,4]]
"""
linear_index = (tf.shape(A)[1]
* tf.range(0,tf.shape(A)[0]))
linear_A = tf.reshape(A, [-1])
return tf.gather(linear_A, B + linear_index)
但这感觉有点哈哈。
如果有人知道更好(如更清楚或更快),也请留下答案! (我暂时不会接受自己的行为)
答案 1 :(得分:1)
@Eugene Brevdo所说的代码:
def vector_slice(A, B):
""" Returns values of rows i of A at column B[i]
where A is a 2D Tensor with shape [None, D]
and B is a 1D Tensor with shape [None]
with type int32 elements in [0,D)
Example:
A =[[1,2], B = [0,1], vector_slice(A,B) -> [1,4]
[3,4]]
"""
B = tf.expand_dims(B, 1)
range = tf.expand_dims(tf.range(tf.shape(B)[0]), 1)
ind = tf.concat([range, B], 1)
return tf.gather_nd(A, ind)
答案 2 :(得分:0)
最简单的方法可能是通过连接范围(batch_size)和B来构建适当的2d索引,以获得batch_size x 2矩阵。然后将其传递给tf.gather_nd。
答案 3 :(得分:0)
最简单的方法是:
CREATE TABLE ProjectCreationTasks (
Id text NOT NULL PRIMARY KEY,
ProjectName text,
ProjectCode text,
DenialReason text,
CONSTRAINT my_constraint CHECK
((ProjectName IS NULL AND ProjectCode IS NULL AND DenialReason IS NULL)
OR(ProjectName IS NOT NULL AND ProjectCode IS NOT NULL AND DenialReason IS NULL)
OR(DenialReason IS NOT NULL AND ProjectName IS NULL AND ProjectCode IS NULL))
);
答案 4 :(得分:0)
考虑使用tf.one_hot
、tf.math.multiply
和tf.reduce_sum
来解决。
例如
def vector_slice (inputs, inds, axis = None):
axis = axis if axis is not None else tf.rank(inds) - 1
inds = tf.one_hot(inds, inputs.shape[axis])
for i in tf.range(tf.rank(inputs) - tf.rank(inds)):
inds = tf.expand_dims(inds, axis = -1)
inds = tf.cast(inds, dtype = inputs.dtype)
x = tf.multiply(inputs, inds)
return tf.reduce_sum(x, axis = axis)