从出列的一批值,索引,形状创建SparseTensor

时间:2017-02-09 21:38:00

标签: python tensorflow

我正在尝试将RAM中的模型(其中一些是稀疏的)提供给模型。我已经创建了一个PaddingFIFOQueue,我将稀疏张量的索引,值和形状分别排队,假设稀疏值不能通过其他方法从RAM进行批处理(如果不是这样,请告诉我)。它们需要填充,因为序列的长度都不同。

我出了以下几个......

indices = [batch size, None, 2]
values = [batch size, None]
shapes = [batch size, 2]

我试图使用这些值来创建SparseTensor,但收到以下错误。

ValueError: Shape (512, ?, 2) must have rank 2

代码的主要部分如下......

indices, values, shapes = self.queue.dequeue_many(batch_size)
sp_tensor = tf.SparseTensor(indices, values, shapes)

我认为这是因为SparseTensor期望排名2张量而不是一批排名2张量(如错误消息所示),但我不确定如何转换批次。

1 个答案:

答案 0 :(得分:2)

这可以通过一些平铺和重塑来实现:

import tensorflow as tf

def sparse_tensor_merge(indices, values, shape):
  """Creates a SparseTensor from batched indices, values, and shapes.

  Args:
    indices: A [batch_size, N, D] integer Tensor.
    values: A [batch_size, N] Tensor of any dtype.
    shape: A [batch_size, D] Integer Tensor.
  Returns:
    A SparseTensor of dimension D + 1 with batch_size as its first dimension.
  """
  merged_shape = tf.reduce_max(shape, axis=0)
  batch_size, elements, shape_dim = tf.unstack(tf.shape(indices))
  index_range_tiled = tf.tile(tf.range(batch_size)[..., None],
                              tf.stack([1, elements]))[..., None]
  merged_indices = tf.reshape(
      tf.concat([tf.cast(index_range_tiled, tf.int64), indices], axis=2),
      [-1, 1 + tf.size(merged_shape)])
  merged_values = tf.reshape(values, [-1])
  return tf.SparseTensor(
      merged_indices, merged_values,
      tf.concat([[tf.cast(batch_size, tf.int64)], merged_shape], axis=0))

例如:

batch_indices = tf.constant(
    [[[0, 0], [0, 1]],
     [[0, 0], [1, 1]]], dtype=tf.int64)
batch_values = tf.constant(
    [[0.1, 0.2],
     [0.3, 0.4]])
batch_shapes = tf.constant(
    [[2, 2],
     [3, 2]], dtype=tf.int64)

merged = sparse_tensor_merge(batch_indices, batch_values, batch_shapes)

with tf.Session():
  print(merged.eval())

打印:

SparseTensorValue(indices=array([[0, 0, 0],
       [0, 0, 1],
       [1, 0, 0],
       [1, 1, 1]]), 
  values=array([ 0.1       ,  0.2       ,  0.30000001,  0.40000001],
      dtype=float32), 
  dense_shape=array([2, 3, 2]))

请注意,组合的SparseTensor的形状是原始批次维度,后面是每个其他维度的批次最大值。