Block-diagonal matrices in TensorFlow

Date: 2017-02-10 11:02:43

Tags: tensorflow

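Suppose I have several tensors A_i, each of shape [N_i, N_i], where the sizes N_i differ. Is it possible in TensorFlow to create a block-diagonal matrix with these matrices on its diagonal? The only way I can think of right now is to build it entirely by hand, by stacking the matrices and padding them with tf.zeros.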

2 answers:

Answer 0 (score: 6)

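I agree that it would be nice to have a C++ op that does this. In the meantime, this is what I do (getting the static shape information right is a bit fiddly):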

import tensorflow as tf

def block_diagonal(matrices, dtype=tf.float32):
  r"""Constructs block-diagonal matrices from a list of batched 2D tensors.

  Args:
    matrices: A list of Tensors with shape [..., N_i, M_i] (i.e. a list of
      matrices with the same batch dimension).
    dtype: Data type to use. The Tensors in `matrices` must match this dtype.
  Returns:
    A matrix with the input matrices stacked along its main diagonal, having
    shape [..., \sum_i N_i, \sum_i M_i].

  """
  matrices = [tf.convert_to_tensor(matrix, dtype=dtype) for matrix in matrices]
  # Accumulate the static (graph-construction-time) shape of the result.
  blocked_rows = tf.Dimension(0)
  blocked_cols = tf.Dimension(0)
  batch_shape = tf.TensorShape(None)
  for matrix in matrices:
    full_matrix_shape = matrix.get_shape().with_rank_at_least(2)
    batch_shape = batch_shape.merge_with(full_matrix_shape[:-2])
    blocked_rows += full_matrix_shape[-2]
    blocked_cols += full_matrix_shape[-1]
  # Compute the total number of columns dynamically, as a scalar Tensor.
  ret_columns_list = []
  for matrix in matrices:
    matrix_shape = tf.shape(matrix)
    ret_columns_list.append(matrix_shape[-1])
  ret_columns = tf.add_n(ret_columns_list)
  # Zero-pad the last dimension of each matrix on both sides so that every
  # padded block spans all `ret_columns` columns.
  row_blocks = []
  current_column = 0
  for matrix in matrices:
    matrix_shape = tf.shape(matrix)
    row_before_length = current_column
    current_column += matrix_shape[-1]
    row_after_length = ret_columns - current_column
    row_blocks.append(tf.pad(
        tensor=matrix,
        paddings=tf.concat(
            [tf.zeros([tf.rank(matrix) - 1, 2], dtype=tf.int32),
             [(row_before_length, row_after_length)]],
            axis=0)))
  # Concatenate the padded blocks along the row dimension, then restore the
  # static shape information that the dynamic padding discarded.
  blocked = tf.concat(row_blocks, -2)
  blocked.set_shape(batch_shape.concatenate((blocked_rows, blocked_cols)))
  return blocked

For example (this uses the TensorFlow 1.x session API):

blocked_tensor = block_diagonal(
    [tf.constant([[1.]]),
     tf.constant([[1., 2.], [3., 4.]])])

with tf.Session():
  print(blocked_tensor.eval())

This prints:

[[ 1.  0.  0.]
 [ 0.  1.  2.]
 [ 0.  3.  4.]]
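The padding scheme used by block_diagonal can be illustrated without TensorFlow. The following is only a minimal pure-Python sketch of the same idea, not the answer's implementation: each block's rows are zero-padded on the left and right to the total column count, and the padded rows are then concatenated. The helper name block_diagonal_lists is hypothetical, and the sketch assumes non-batched, non-empty, rectangular blocks given as nested lists.

```python
def block_diagonal_lists(blocks):
    """Builds a block-diagonal matrix from a list of 2-D nested lists.

    Hypothetical helper for illustration only; assumes every block is a
    non-empty rectangular list of rows.
    """
    # Total width of the result is the sum of the block widths.
    total_cols = sum(len(block[0]) for block in blocks)
    rows = []
    col_offset = 0  # Number of columns taken up by the preceding blocks.
    for block in blocks:
        width = len(block[0])
        zeros_after = total_cols - col_offset - width
        for row in block:
            # Pad with zeros before and after, like the `paddings` argument
            # applied to the last dimension in the TensorFlow version.
            rows.append([0.0] * col_offset + list(row) + [0.0] * zeros_after)
        col_offset += width
    return rows


result = block_diagonal_lists([[[1.0]], [[1.0, 2.0], [3.0, 4.0]]])
for row in result:
    print(row)
```

For the two blocks from the example above, this produces the same 3x3 layout as the printed tensor.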

Answer 1 (score: 2)

For anyone visiting this page now: TensorFlow now has tf.linalg.LinearOperatorBlockDiag. Following Allen's example above:

import tensorflow as tf

tfl = tf.linalg

blocks = [tf.constant([[1.0]]), tf.constant([[1.0, 2.0], [3.0, 4.0]])]

linop_blocks = [tfl.LinearOperatorFullMatrix(block) for block in blocks]
linop_block_diagonal = tfl.LinearOperatorBlockDiag(linop_blocks)

>>> print(linop_block_diagonal.to_dense())
tf.Tensor(
[[1. 0. 0.]
 [0. 1. 2.]
 [0. 3. 4.]], shape=(3, 3), dtype=float32)
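If the dense matrix is only needed for a matrix product, the operator can be applied directly through its matmul method, without calling to_dense() first. The following is a minimal sketch, assuming TensorFlow 2.x with eager execution; the vector x is made up for illustration.

```python
import tensorflow as tf

blocks = [tf.constant([[1.0]]), tf.constant([[1.0, 2.0], [3.0, 4.0]])]
operator = tf.linalg.LinearOperatorBlockDiag(
    [tf.linalg.LinearOperatorFullMatrix(block) for block in blocks])

# Illustrative right-hand side with one column; its row count must match
# the operator's column count (3 here).
x = tf.constant([[1.0], [1.0], [1.0]])

# Apply the block-diagonal operator to x without building the dense matrix.
y = operator.matmul(x)
print(y.numpy())
```

With the blocks above, y equals the dense 3x3 matrix multiplied by x, i.e. the row sums 1, 3 and 7.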