Why does 'tf.scatter_nd_update' not support strings?

Time: 2019-07-11 07:15:05

Tags: python tensorflow

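In my model, I need to maintain a very long 2-D variable tensor with many rows and many columns, whose dtype is string. At each training step, I need to update several rows of this tensor. 'tf.scatter_nd_update' does exactly what I need, except that it does not support strings. Is there a workaround?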

Traceback (most recent call last):
  File "tensorflow/python/client/session.py", line 1278, in _do_call
    return fn(*args)
  File "tensorflow/python/client/session.py", line 1261, in _run_fn
    self._extend_graph()
  File "tensorflow/python/client/session.py", line 1295, in _extend_graph
    tf_session.ExtendSession(self._session)
tensorflow.python.framework.errors_impl.InvalidArgumentError: No OpKernel was registered to support Op 'ScatterNdUpdate' with these attrs. Registered devices: [CPU], Registered kernels:
  device='CPU'; T in [DT_COMPLEX128]; Tindices in [DT_INT64]
  device='CPU'; T in [DT_COMPLEX128]; Tindices in [DT_INT32]
  device='CPU'; T in [DT_COMPLEX64]; Tindices in [DT_INT64]
  device='CPU'; T in [DT_COMPLEX64]; Tindices in [DT_INT32]
  device='CPU'; T in [DT_DOUBLE]; Tindices in [DT_INT64]
  device='CPU'; T in [DT_DOUBLE]; Tindices in [DT_INT32]
  device='CPU'; T in [DT_FLOAT]; Tindices in [DT_INT64]
  device='CPU'; T in [DT_FLOAT]; Tindices in [DT_INT32]
  device='CPU'; T in [DT_BFLOAT16]; Tindices in [DT_INT64]
  device='CPU'; T in [DT_BFLOAT16]; Tindices in [DT_INT32]
  device='CPU'; T in [DT_HALF]; Tindices in [DT_INT64]
  device='CPU'; T in [DT_HALF]; Tindices in [DT_INT32]
  device='CPU'; T in [DT_INT8]; Tindices in [DT_INT64]
  device='CPU'; T in [DT_INT8]; Tindices in [DT_INT32]
  device='CPU'; T in [DT_UINT8]; Tindices in [DT_INT64]
  device='CPU'; T in [DT_UINT8]; Tindices in [DT_INT32]
  device='CPU'; T in [DT_INT16]; Tindices in [DT_INT64]
  device='CPU'; T in [DT_INT16]; Tindices in [DT_INT32]
  device='CPU'; T in [DT_UINT16]; Tindices in [DT_INT64]
  device='CPU'; T in [DT_UINT16]; Tindices in [DT_INT32]
  device='CPU'; T in [DT_INT32]; Tindices in [DT_INT64]
  device='CPU'; T in [DT_INT32]; Tindices in [DT_INT32]
  device='CPU'; T in [DT_INT64]; Tindices in [DT_INT64]
  device='CPU'; T in [DT_INT64]; Tindices in [DT_INT32]

         [[Node: ScatterNdUpdate = ScatterNdUpdate[T=DT_STRING, Tindices=DT_INT64, _class=["loc:@Variable"], use_locking=true](Variable, HashCollectiveAndUpdate, HashCollectiveAndUpdate:1)]]

1 answer:

Answer 0: (score: 0)

It is surprising that tf.scatter_nd_update does not work with strings, especially since tf.scatter_nd does. You can reproduce the same behavior with a function like this:

import tensorflow as tf

def my_scatter_nd_update(ref, indices, updates, use_locking=True):
    indices = tf.convert_to_tensor(indices)
    updates = tf.convert_to_tensor(updates)
    # Make a mask for elements to replace
    m = tf.ones_like(updates, dtype=tf.bool)
    s = tf.shape(ref)
    mask = tf.scatter_nd(indices, m, s)
    # Make tensor of replacement values put in place
    upd_scatter = tf.scatter_nd(indices, updates, s)
    # Select replacement values for replaced positions
    new_value = tf.where(mask, upd_scatter, ref)
    # Do assignment
    return tf.assign(ref, new_value, use_locking=use_locking)

# Test
with tf.Graph().as_default(), tf.Session() as sess:
    var = tf.Variable([['a', 'b', 'c'], ['d', 'e', 'f']])
    var_upd = my_scatter_nd_update(var, [[0, 1], [1, 2]], ['g', 'h'])
    sess.run(var.initializer)
    print(sess.run(var_upd))
    # [[b'a' b'g' b'c']
    #  [b'd' b'e' b'h']]

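This does more work than a proper update op would, since it builds full-size mask and replacement tensors and reassigns the whole variable instead of touching only the updated positions, but the result should be the same.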
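To make the mask-and-select steps easier to inspect outside a TensorFlow graph, here is a rough NumPy sketch of the same idea. It is only an illustration: the helper name scatter_nd_update_np is made up for this example, it assumes the indices contain no duplicates, and it uses object arrays so that strings of different lengths are not truncated. It also shows the row-wise case from the question, where each index selects a whole row.

```python
import numpy as np

def scatter_nd_update_np(ref, indices, updates):
    """Hypothetical NumPy illustration of the mask-and-select workaround."""
    ref = np.asarray(ref, dtype=object)
    indices = np.asarray(indices)
    updates = np.asarray(updates, dtype=object)
    # Each row of `indices` addresses one position (or one slice) of `ref`;
    # transpose it into one index array per indexed axis.
    idx = tuple(indices.T)
    # Boolean mask marking the positions to replace.
    mask = np.zeros(ref.shape, dtype=bool)
    mask[idx] = True
    # Replacement values placed into an otherwise empty array of the same shape.
    upd_scatter = np.full(ref.shape, "", dtype=object)
    upd_scatter[idx] = updates
    # Take the replacement where the mask is set, the old value elsewhere.
    return np.where(mask, upd_scatter, ref)

table = [["a", "b", "c"], ["d", "e", "f"]]

# Element-wise update, same inputs as the test in the answer.
elem_result = scatter_nd_update_np(table, [[0, 1], [1, 2]], ["g", "h"])
print(elem_result.tolist())  # [['a', 'g', 'c'], ['d', 'e', 'h']]

# Row-wise update: each index picks a whole row.
row_result = scatter_nd_update_np(table, [[1]], [["x", "y", "z"]])
print(row_result.tolist())  # [['a', 'b', 'c'], ['x', 'y', 'z']]
```

As in the TensorFlow version, the mask and the scattered replacements both have the full shape of the original table, which is where the extra work comes from.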