当切片本身是张量流中的张量时如何进行切片分配

时间:2019-06-19 18:09:13

标签: python tensorflow slice variable-assignment

我想在张量流中进行切片分配。我知道我可以使用:

my_var = my_var[4:8].assign(tf.zeros(4))

基于此link

正如您在my_var[4:8]中所看到的,我们在此处有特定的索引4、8,用于切片和赋值。

我的情况有所不同,我想基于张量进行切片,然后进行赋值。

out = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))

 rows_tf = tf.constant (
[[1, 2, 5],
 [1, 2, 5],
 [1, 2, 5],
 [1, 4, 6],
 [1, 4, 6],
 [2, 3, 6],
 [2, 3, 6],
 [2, 4, 7]])

columns_tf = tf.constant(
[[1],
 [2],
 [3],
 [2],
 [3],
 [2],
 [3],
 [2]])

changed_tensor = [[8.3356,    0.,        8.457685 ],
                  [0.,        6.103182,  8.602337 ],
                  [8.8974,    7.330564,  0.       ],
                  [0.,        3.8914037, 5.826657 ],
                  [8.8974,    0.,        8.283971 ],
                  [6.103182,  3.0614321, 5.826657 ],
                  [7.330564,  0.,        8.283971 ],
                  [6.103182,  3.8914037, 0.       ]]

这也是sparse_indices张量,它是rows_tfcolumns_tf的结合,使得整个索引都需要更新(以防万一:)

sparse_indices = tf.constant(
[[1 1]
 [2 1]
 [5 1]
 [1 2]
 [2 2]
 [5 2]
 [1 3]
 [2 3]
 [5 3]
 [1 2]
 [4 2]
 [6 2]
 [1 3]
 [4 3]
 [6 3]
 [2 2]
 [3 2]
 [6 2]
 [2 3]
 [3 3]
 [6 3]
 [2 2]
 [4 2]
 [4 2]])

我想要做的是完成以下简单的任务:

out[rows_tf, columns_tf] = changed_tensor

为此,我正在这样做:

out[rows_tf:column_tf].assign(changed_tensor)

但是,我收到此错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected begin, end, and strides to be 1D equal size tensors, but got shapes [1,8,3], [1,8,1], and [1] instead. [Op:StridedSlice] name: strided_slice/

这是预期的输出:

[[0.        0.        0.        0.       ]
 [0.        8.3356    0.        8.8974   ]
 [0.        0.        6.103182  7.330564 ]
 [0.        0.        3.0614321 0.       ]
 [0.        0.        3.8914037 0.       ]
 [0.        8.457685  8.602337  0.       ]
 [0.        0.        5.826657  8.283971 ]
 [0.        0.        0.        0.       ]]

任何想法我该如何完成任务?

提前谢谢:)

1 个答案:

答案 0 :(得分:2)

此示例(从tf文档tf.scatter_nd_update here扩展而来)应该有帮助。

您要首先将row_indices和column_indices合并为2d索引列表,该列表是indices的{​​{1}}参数。然后,您输入了一个期望值列表,即tf.scatter_nd_update

updates
ref = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))
indices = tf.constant([[0, 2], [2, 2]])
updates = tf.constant([1.0, 2.0])

update = tf.scatter_nd_update(ref, indices, updates)
with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  print sess.run(update)

专门针对您的数据

Result:

[[ 0.  0.  1.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  2.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]]