除了在动态矩形中,Tensor等于零

时间:2018-03-11 15:28:01

标签: python tensorflow

鉴于a, b, c, d类型tf.int和形状[1]的{​​{1}} 4个张量,获得张量X的最简单方法是:

  • X的形状为[h, w]
  • X0行和列a < b之间的c < d之外,其他地方都是1

1 个答案:

答案 0 :(得分:1)

您可以使用tf.meshgrid创建行索引和列索引数组。然后在索引数组上应用logical operations以获取应该在哪里的掩码。最后,tf.where可用于构建请求的张量X

示例:

import tensorflow as tf
h = 5
w = 6

a = 1
b = 3
c = 2
d = 4

cols, rows = tf.meshgrid(tf.range(w), tf.range(h))
mask_rows = tf.logical_and( tf.less(rows, b), tf.greater_equal(rows, a))
mask_cols = tf.logical_and( tf.less(cols, d), tf.greater_equal(cols, c))
mask = tf.logical_and(mask_rows, mask_cols)

X = tf.where(mask, tf.ones([h,w], tf.float32), tf.zeros([h,w], tf.float32))

验证输出:

sess = tf.Session()
print(sess.run(cols))
print(sess.run(rows))
print(sess.run(X))

cols的输出:

[[0 1 2 3 4 5]
 [0 1 2 3 4 5]
 [0 1 2 3 4 5]
 [0 1 2 3 4 5]
 [0 1 2 3 4 5]]

rows

的输出
[[0 0 0 0 0 0]
 [1 1 1 1 1 1]
 [2 2 2 2 2 2]
 [3 3 3 3 3 3]
 [4 4 4 4 4 4]]

X

的输出
[[0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 1. 0. 0.]
 [0. 0. 1. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]