我有一个稀疏的张量,我从一系列指数和价值观中积累起来。我试图实现一些代码来获取完整的行切片。虽然TensorFlow似乎没有直接支持此功能,但似乎有一些解决办法可以返回指定行的索引和值,如下所示:
def sparse_slice(self, indices, values, needed_row_ids):
needed_row_ids = tf.reshape(needed_row_ids, [1, -1])
num_rows = tf.shape(indices)[0]
partitions = tf.cast(tf.reduce_any(tf.equal(tf.reshape(indices[:, 0], [-1, 1]), needed_row_ids), 1),
tf.int32)
rows_to_gather = tf.dynamic_partition(tf.range(num_rows), partitions, 2)[1]
slice_indices = tf.gather(indices, rows_to_gather)
slice_values = tf.gather(values, rows_to_gather)
return slice_indices, slice_values
然后直接调用稀疏4x4矩阵,我有兴趣访问第3行中的所有元素:
with tf.Session().as_default():
indices = tf.constant([[0, 0], [1, 0], [2, 0], [2, 1], [3, 0], [3, 3]])
values = tf.constant([10, 19, 1, 1, 6, 5], dtype=tf.int64)
needed_row_ids = tf.constant([3])
slice_indices, slice_values = self.sparse_slice(indices, values, needed_row_ids)
print('indicies: {} and rows: {}'.format(slice_indices.eval(), slice_values.eval()))
其中输出以下内容:
indicies: [[3 0]
[3 3]] and rows: [6 5]
到目前为止一切顺利,我想我可以使用这些信息构建一个 1x4密集张量,索引值为0,缺失列为0。
dense_representation = tf.sparse_to_dense(sparse_values=slice_values, sparse_indices=slice_indices,
output_shape=(1,4))
然而,我在会话中运行张量的那一刻。
sess = tf.Session()
sess.run(dense_representation)
我收到以下例外:
InvalidArgumentError (see above for traceback): indices[0] = [3,0] is out of bounds: need 0 <= index < [1,4]
[[Node: SparseToDense = SparseToDense[T=DT_INT64, Tindices=DT_INT32, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](Gather_2, SparseToDense/output_shape, Gather_3, SparseToDense/default_value)]]
我不太确定我做错了什么,或者这与output_shape
没有正确形成有什么关系。基本上我想把这一切都装回1 x 4矢量。我还没有在网上找到任何好的例子来说明如何做到这一点。任何帮助,将不胜感激。