很抱歉我不得不问这个问题,因为它看起来很简单,但是我正在尝试在Tensorflow中找到一种专门的方法。
我有一个如下的张量矩阵:
[0 0 1 1]
X = [0 0 1 1]
[1 1 0 0]
[1 1 0 0]
我需要提取两个补丁:
[1,1] [1,1]
[1,1] & [1,1]
我还给出了指向右上方给出的子矩阵左元素的索引列表。例如。
[[0,2]
[2,0]]
我需要在Tensorflow中提取补丁。谢谢。
答案 0 :(得分:1)
好吧,如果您知道需要提取哪些子矩阵,则tf.slice()是最佳选择。
文档为here
对于您提供的示例,使用tf.slice()的解决方案是:
import tensorflow as tf
x = [[0, 0, 1, 1],
[0, 0, 1, 1],
[1, 1, 0, 0],
[1, 1, 0, 0]]
X = tf.Variable(x)
s1 = tf.slice(X, [2,0], [2,2])
s1 = tf.slice(X, [0,2], [2,2])
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run([s1, s1]))
此代码呈现以下结果:
[array([[1, 1], [1, 1]], dtype=int32),
array([[1, 1], [1, 1]], dtype=int32)]
编辑:
对于更自动,更省事的方法,您可以使用tensorflow中的getitem属性并将其切片,就像对npArray进行切片一样。
代码可能是这样的:
import tensorflow as tf
var = [[0, 0, 1, 1],
[0, 0, 1, 1],
[1, 1, 0, 0],
[1, 1, 0, 0]]
X = tf.Variable(var)
slices = [[0,2], [2,0]]
s = []
for sli in slices:
y = sli[0]
x = sli[1]
s.append(X[y:y+2, x:x+2])
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(s))
此代码呈现以下结果:
[array([[1, 1], [1, 1]], dtype=int32),
array([[1, 1], [1, 1]], dtype=int32)]
答案 1 :(得分:1)
您也可以使用tf.gather_nd进行此操作。下面的示例显示了所有工作位,以及您可以使用collect_nd执行的操作。您应该能够构造索引,以便只需要一个collect_nd op即可获取所需的所有子矩阵。我只是包含了变量索引,以表明您可以使用它从预先未知的张量中获取子矩阵。因此,例如,如果您在图形中计算一些东西,并希望基于该子矩阵来获取子矩阵。
import tensorflow as tf
import numpy as np
# build a tensor
x = np.arange(25)
x = np.reshape(x, [5, 5])
y = x + 4
three_d_array = np.stack([x, y], axis=2)
# just so you can see the shape its made of
print(np.all(x == three_d_array[:,:,0]))
print(np.all(y == three_d_array[:,:,1]))
# make it into a tf tensor
three_d_tensor = tf.constant(three_d_array)
# create a variable for tensor valued slice indices
row_0, col_0 = 0, 0
row_1, col_1 = 0, 1
row_2, col_2 = 1, 0
row_3, col_3 = 1, 1
slice_tensor = tf.constant([
[row_0, col_0],
[row_1, col_1],
[row_2, col_2],
[row_3, col_3]
])
slices = tf.Variable(initial_value=slice_tensor)
# op to get the sub matrices
gather_op = tf.gather_nd(three_d_tensor, slices)
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
submatrices = sess.run(gather_op)
print(submatrices[0,:] == three_d_array[row_0, col_0])
print(submatrices[1,:] == three_d_array[row_1, col_1])
print(submatrices[2,:] == three_d_array[row_2, col_2])
print(submatrices[3,:] == three_d_array[row_3, col_3])
# shift down 2 along 2
offset_top_left = tf.constant([2,2])
update_variable_op = tf.assign(slices, slices + offset_top_left[None,:])
sess.run(update_variable_op)
submatrices = sess.run(gather_op)
print(submatrices[0, :] == three_d_array[row_0 + 2, col_0 + 2])
print(submatrices[1, :] == three_d_array[row_1 + 2, col_1 + 2])
print(submatrices[2, :] == three_d_array[row_2 + 2, col_2 + 2])
print(submatrices[3, :] == three_d_array[row_3 + 2, col_3 + 2])