使用Scatter_Update更新张量

时间:2018-08-05 11:32:31

标签: python-3.x tensorflow deep-learning

我有一个形状为(300,)的权重张量,它具有二进制数据,一些元素为1's,另一些元素为0's。我需要将除1's的第一次出现以外的所有1转换为0's。我很想知道在Tensorflow中执行此操作的简单方法。

但是,这是我目前正在尝试实现的目标:

我使用以下代码行获得张量为1的所有索引:

indices = tf.squeeze(tf.where(tf.greater(weights, 0)))

然后,我构建一个张量,该张量将在相应的索引上更新:

updates = tf.constant(0., shape=indices[1:].eval(session=sess2).shape, dtype=tf.float32)

然后,我使用scatter_update在相应的 indices 上更新 updates ,但是由于scatter_update仅适用于变量,因此我创建了一个变量并将可更新张量分配给该变量,如下所示:

 weights_var = tf.Variable(tf.zeros(weights.get_shape()), name="weights_var")
tf.assign(weights_var, weights)

然后,我打tf.scatter_update

tf.scatter_update(weights_subset, indices[1:], updates).eval(session=sess2)

这给了我以下错误:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-211-c863dff9ffc7> in <module>()
     51 updates = tf.constant(0., shape=indices[1:].eval(session=sess2).shape, dtype=tf.float32)
     52 
---> 53 tf.scatter_update(weights_subset, indices[1:], updates).eval(session=sess2)
     54 
     55 # print(final_weights.eval(session=sess2))

~/anaconda2/envs/py36/lib/python3.6/site-packages/tensorflow/python/ops/state_ops.py in scatter_update(ref, indices, updates, use_locking, name)
    290     to use the updated values after the update is done.
    291   """
--> 292   if ref.dtype._is_ref_dtype:
    293     return gen_state_ops.scatter_update(ref, indices, updates,
    294                                         use_locking=use_locking, name=name)

AttributeError: 'numpy.dtype' object has no attribute '_is_ref_dtype'

我很想知道这个问题的解决方案,如果可能的话,可以在Tensorflow中使用一种更简单的矢量化单线。谢谢:-)

1 个答案:

答案 0 :(得分:2)

如果我了解您的问题,那么此代码流应根据您问题的第一段工作。不知道它是否可以进一步缩短。

mask = tf.Variable([0, 1, 1, 0, 1, 1, 1, 1])

indices = tf.squeeze(tf.where(tf.greater(mask, 0)))
sess.run(tf.global_variables_initializer())

valuesofindices = np.delete(indices.eval(session=sess),
                            0)

update = tf.scatter_update(mask,
                           valuesofindices,
                           tf.tile(tf.constant([0],
                                   tf.int32),
                           valuesofindices.shape))

sess.run(tf.global_variables_initializer())
print(update.eval(session=sess))

输出:

  

[0 1 0 0 0 0 0 0]