我有一个形状为(300,)
的权重张量,它具有二进制数据,一些元素为1's
,另一些元素为0's
。我需要将除1's
的第一次出现以外的所有1
转换为0's
。我很想知道在Tensorflow中执行此操作的简单方法。
但是,这是我目前正在尝试实现的目标:
我使用以下代码行获得张量为1的所有索引:
indices = tf.squeeze(tf.where(tf.greater(weights, 0)))
然后,我构建一个张量,该张量将在相应的索引上更新:
updates = tf.constant(0., shape=indices[1:].eval(session=sess2).shape, dtype=tf.float32)
然后,我使用scatter_update
在相应的 indices 上更新 updates ,但是由于scatter_update
仅适用于变量,因此我创建了一个变量并将可更新张量分配给该变量,如下所示:
weights_var = tf.Variable(tf.zeros(weights.get_shape()), name="weights_var")
tf.assign(weights_var, weights)
然后,我打tf.scatter_update
:
tf.scatter_update(weights_subset, indices[1:], updates).eval(session=sess2)
这给了我以下错误:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-211-c863dff9ffc7> in <module>()
51 updates = tf.constant(0., shape=indices[1:].eval(session=sess2).shape, dtype=tf.float32)
52
---> 53 tf.scatter_update(weights_subset, indices[1:], updates).eval(session=sess2)
54
55 # print(final_weights.eval(session=sess2))
~/anaconda2/envs/py36/lib/python3.6/site-packages/tensorflow/python/ops/state_ops.py in scatter_update(ref, indices, updates, use_locking, name)
290 to use the updated values after the update is done.
291 """
--> 292 if ref.dtype._is_ref_dtype:
293 return gen_state_ops.scatter_update(ref, indices, updates,
294 use_locking=use_locking, name=name)
AttributeError: 'numpy.dtype' object has no attribute '_is_ref_dtype'
我很想知道这个问题的解决方案,如果可能的话,可以在Tensorflow中使用一种更简单的矢量化单线。谢谢:-)
答案 0 :(得分:2)
如果我了解您的问题,那么此代码流应根据您问题的第一段工作。不知道它是否可以进一步缩短。
mask = tf.Variable([0, 1, 1, 0, 1, 1, 1, 1])
indices = tf.squeeze(tf.where(tf.greater(mask, 0)))
sess.run(tf.global_variables_initializer())
valuesofindices = np.delete(indices.eval(session=sess),
0)
update = tf.scatter_update(mask,
valuesofindices,
tf.tile(tf.constant([0],
tf.int32),
valuesofindices.shape))
sess.run(tf.global_variables_initializer())
print(update.eval(session=sess))
输出:
[0 1 0 0 0 0 0 0]