Extracting multiple sub-matrices from a tensor

Time: 2018-10-15 12:21:26

Tags: python python-3.x tensorflow

Sorry to ask this, since it seems simple, but I am trying to find a dedicated way to do it in TensorFlow.

I have the following tensor matrix:

    [0 0 1 1]
X = [0 0 1 1]
    [1 1 0 0]
    [1 1 0 0]

I need to extract two patches:

  [1,1]     [1,1]
  [1,1] &   [1,1]

I am also given a list of indices pointing to the top-left element of each sub-matrix, for example:

[[0,2]
 [2,0]]
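For reference, the intended result can be stated in plain NumPy. This is a minimal sketch, assuming every patch is 2x2 (PATCH is a name introduced here) and each index pair is [row, col] of the patch's top-left element:

```python
import numpy as np

X = np.array([[0, 0, 1, 1],
              [0, 0, 1, 1],
              [1, 1, 0, 0],
              [1, 1, 0, 0]])
corners = [[0, 2], [2, 0]]  # top-left [row, col] of each patch
PATCH = 2  # assumed patch height and width

# Slice each patch and stack them along a new leading axis
patches = np.stack([X[r:r + PATCH, c:c + PATCH] for r, c in corners])
print(patches.shape)  # (2, 2, 2)
```

Any TensorFlow version should produce the same (2, 2, 2) array of ones for this input.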

I need to extract these patches in TensorFlow. Thanks.

2 answers:

Answer 0 (score: 1)

Well, if you know which sub-matrices you need to extract, tf.slice() is the best choice.

The documentation is here.

For the example you gave, a solution using tf.slice() is:

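import tensorflow as tf  # TensorFlow 1.x graph API (tf.Session)

x = [[0, 0, 1, 1],
     [0, 0, 1, 1],
     [1, 1, 0, 0],
     [1, 1, 0, 0]]
X = tf.Variable(x)
# tf.slice(input, begin, size): begin is the top-left [row, col], size is [rows, cols]
s1 = tf.slice(X, [0, 2], [2, 2])  # top-right patch
s2 = tf.slice(X, [2, 0], [2, 2])  # bottom-left patch


with tf.Session() as sess:
  init = tf.global_variables_initializer()
  sess.run(init)
  print(sess.run([s1, s2]))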

This code produces the following output:

[array([[1, 1], [1, 1]], dtype=int32), 
array([[1, 1], [1, 1]], dtype=int32)]
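If the corner indices arrive as a tensor rather than as Python literals, tf.slice still applies, because its begin argument can itself be a tensor. A minimal sketch, assuming TensorFlow 2.x with eager execution and swapping in tf.map_fn (not used in the answer) to run one slice per corner; PATCH_SIZE is a name introduced here:

```python
import tensorflow as tf

X = tf.constant([[0, 0, 1, 1],
                 [0, 0, 1, 1],
                 [1, 1, 0, 0],
                 [1, 1, 0, 0]])
# Top-left [row, col] of each patch, held in a tensor instead of Python literals
corners = tf.constant([[0, 2], [2, 0]])
PATCH_SIZE = [2, 2]  # assumed [rows, cols] of every patch

# tf.slice accepts a tensor for `begin`, so apply it to each corner row;
# map_fn stacks the per-corner results into one tensor of shape [N, 2, 2]
patches = tf.map_fn(lambda begin: tf.slice(X, begin, PATCH_SIZE), corners)
print(patches.numpy())
```

All patches must have the same size, since map_fn stacks the results into a single tensor.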

Edit:

For a more automatic, less tedious approach, you can use TensorFlow's __getitem__ support and slice the tensor just as you would slice a NumPy array.

The code could look like this:

import tensorflow as tf

var = [[0, 0, 1, 1],
       [0, 0, 1, 1],
       [1, 1, 0, 0],
       [1, 1, 0, 0]]
X = tf.Variable(var)
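slices = [[0, 2], [2, 0]]  # top-left [row, col] of each 2x2 patch

s = []
for sli in slices:
  y = sli[0]  # row of the top-left element
  x = sli[1]  # column of the top-left element
  # NumPy-style slicing on a tensor
  s.append(X[y:y+2, x:x+2])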


with tf.Session() as sess:
  init = tf.global_variables_initializer()
  sess.run(init)
  print(sess.run(s))

This code produces the following output:

[array([[1, 1], [1, 1]], dtype=int32), 
array([[1, 1], [1, 1]], dtype=int32)]
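Under TensorFlow 2.x, eager execution replaces tf.Session, so the same loop runs directly, and tf.stack can join the list into one tensor. A minimal sketch, assuming TensorFlow 2.x and a fixed patch size PATCH (a name introduced here):

```python
import tensorflow as tf

X = tf.constant([[0, 0, 1, 1],
                 [0, 0, 1, 1],
                 [1, 1, 0, 0],
                 [1, 1, 0, 0]])
slices = [[0, 2], [2, 0]]  # top-left [row, col] of each patch
PATCH = 2  # assumed patch height and width

# NumPy-style slicing on an eager tensor; tf.stack adds a new leading axis,
# giving a tensor of shape [len(slices), PATCH, PATCH]
patches = tf.stack([X[y:y + PATCH, x:x + PATCH] for y, x in slices])
print(patches.numpy())
```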

Answer 1 (score: 1)

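You can also do this with tf.gather_nd. The example below shows all the working parts and what you can do with gather_nd. You should be able to construct the indices so that a single gather_nd op fetches all the sub-matrices you need. I only made the indices a variable to show that they can come from a tensor whose value is not known in advance, for example if you compute something in your graph and want to extract sub-matrices based on that result.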

import tensorflow as tf
import numpy as np

# build a tensor
x = np.arange(25)
x = np.reshape(x, [5, 5])
y = x + 4
three_d_array = np.stack([x, y], axis=2)
# check that the two depth slices of the 3-D array are x and y
print(np.all(x == three_d_array[:,:,0]))
print(np.all(y == three_d_array[:,:,1]))
# make it into a tf tensor
three_d_tensor = tf.constant(three_d_array)

# store the [row, col] indices in a Variable so they can change at run time
row_0, col_0 = 0, 0
row_1, col_1 = 0, 1
row_2, col_2 = 1, 0
row_3, col_3 = 1, 1
slice_tensor = tf.constant([
    [row_0, col_0],
    [row_1, col_1],
    [row_2, col_2],
    [row_3, col_3]
])
slices = tf.Variable(initial_value=slice_tensor)

# op to gather the depth vectors at those four positions (the top-left 2x2 block), shape [4, 2]
gather_op = tf.gather_nd(three_d_tensor, slices)

with tf.Session() as sess:
  init = tf.global_variables_initializer()
  sess.run(init)

  submatrices = sess.run(gather_op)
  print(submatrices[0,:] == three_d_array[row_0, col_0])
  print(submatrices[1,:] == three_d_array[row_1, col_1])
  print(submatrices[2,:] == three_d_array[row_2, col_2])
  print(submatrices[3,:] == three_d_array[row_3, col_3])

# shift every index down 2 rows and right 2 columns
  offset_top_left = tf.constant([2,2])
  update_variable_op = tf.assign(slices, slices + offset_top_left[None,:])
  sess.run(update_variable_op)
  submatrices = sess.run(gather_op)
  print(submatrices[0, :] == three_d_array[row_0 + 2, col_0 + 2])
  print(submatrices[1, :] == three_d_array[row_1 + 2, col_1 + 2])
  print(submatrices[2, :] == three_d_array[row_2 + 2, col_2 + 2])
  print(submatrices[3, :] == three_d_array[row_3 + 2, col_3 + 2])
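Applied to the question's 4x4 matrix, the indices can be built so that one gather_nd call returns every patch. A minimal sketch, assuming TensorFlow 2.x eager execution, square patches of a fixed size PATCH (a name introduced here), and corners given as [row, col]; tf.meshgrid plus broadcasting produce an index tensor of shape [N, PATCH, PATCH, 2]:

```python
import tensorflow as tf

X = tf.constant([[0, 0, 1, 1],
                 [0, 0, 1, 1],
                 [1, 1, 0, 0],
                 [1, 1, 0, 0]])
corners = tf.constant([[0, 2], [2, 0]])  # top-left [row, col] of each patch
PATCH = 2  # assumed patch height and width

# Offsets [i, j] for i, j in range(PATCH), shape [PATCH, PATCH, 2]
ii, jj = tf.meshgrid(tf.range(PATCH), tf.range(PATCH), indexing="ij")
offsets = tf.stack([ii, jj], axis=-1)

# Broadcast [N, 1, 1, 2] + [PATCH, PATCH, 2] -> [N, PATCH, PATCH, 2]
indices = corners[:, None, None, :] + offsets

# A single gather_nd returns all patches, shape [N, PATCH, PATCH]
patches = tf.gather_nd(X, indices)
print(patches.numpy())
```

Because corners is an ordinary tensor, it could just as well be computed earlier in the graph, which is the point the answer makes with its Variable.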