鉴于a, b, c, d
类型tf.int
和形状[1]
的{{1}} 4个张量,获得张量X
的最简单方法是:
X
的形状为[h, w]
,X
除0
行和列a < b
之间的c < d
之外,其他地方都是1
。答案 0 :(得分:1)
您可以使用tf.meshgrid
创建行索引和列索引数组。然后在索引数组上应用logical operations以获取应该在哪里的掩码。最后,tf.where
可用于构建请求的张量X
。
示例:
import tensorflow as tf
h = 5
w = 6
a = 1
b = 3
c = 2
d = 4
cols, rows = tf.meshgrid(tf.range(w), tf.range(h))
mask_rows = tf.logical_and( tf.less(rows, b), tf.greater_equal(rows, a))
mask_cols = tf.logical_and( tf.less(cols, d), tf.greater_equal(cols, c))
mask = tf.logical_and(mask_rows, mask_cols)
X = tf.where(mask, tf.ones([h,w], tf.float32), tf.zeros([h,w], tf.float32))
验证输出:
sess = tf.Session()
print(sess.run(cols))
print(sess.run(rows))
print(sess.run(X))
cols
的输出:
[[0 1 2 3 4 5]
[0 1 2 3 4 5]
[0 1 2 3 4 5]
[0 1 2 3 4 5]
[0 1 2 3 4 5]]
rows
[[0 0 0 0 0 0]
[1 1 1 1 1 1]
[2 2 2 2 2 2]
[3 3 3 3 3 3]
[4 4 4 4 4 4]]
X
[[0. 0. 0. 0. 0. 0.]
[0. 0. 1. 1. 0. 0.]
[0. 0. 1. 1. 0. 0.]
[0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0.]]