Tensorflow和散点更新为零

时间:2019-03-14 11:20:44

标签: python tensorflow

我正在尝试构建诸如遮罩处理之类的步进功能,例如
如果这是张量

+---+---+---+---+---+
| 1 | 2 | 3 | 4 | 1 |
+---+---+---+---+---+
| 2 | 2 | 2 | 3 | 3 |
+---+---+---+---+---+
| 2 | 2 | 1 | 5 | 3 |
+---+---+---+---+---+

例如,我有一个可学的tf.variable

+---+
| 3 |
+---+
| 3 |
+---+
| 1 |
+---+

我希望网络生成下表(即,结束变量为零)

+---+---+---+---+---+
| 1 | 2 | 3 | 0 | 0 |
+---+---+---+---+---+
| 2 | 2 | 2 | 0 | 0 |
+---+---+---+---+---+
| 2 | 0 | 0 | 0 | 0 |
+---+---+---+---+---+

我尝试使用scatter_nd_update,但不知道如何生成updates变量 rand_slice = -1 * np.random.randint(5,size = 5) 打印(rand_slice)

ref = tf.Variable(np.ones((5,6)),dtype=tf.int32)
indices = tf.Variable(rand_slice,dtype=tf.int32)
updates = tf.zeros((1),dtype=tf.int32)
update = tf.scatter_nd_update(ref, indices, updates)

问题是如何将张量定义为零

0 个答案:

没有答案