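In my model I need to maintain a very long 2-D variable tensor with many rows and columns whose dtype is string. At every training step I need to update a few rows of this tensor. 'tf.scatter_nd_update' does exactly what I need, but it does not support strings. Is there a workaround? Running it fails with: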
Traceback (most recent call last):
  File "tensorflow/python/client/session.py", line 1278, in _do_call
    return fn(*args)
  File "tensorflow/python/client/session.py", line 1261, in _run_fn
    self._extend_graph()
  File "tensorflow/python/client/session.py", line 1295, in _extend_graph
    tf_session.ExtendSession(self._session)
tensorflow.python.framework.errors_impl.InvalidArgumentError: No OpKernel was registered to support Op 'ScatterNdUpdate' with these attrs. Registered devices: [CPU], Registered kernels:
device='CPU'; T in [DT_COMPLEX128]; Tindices in [DT_INT64]
device='CPU'; T in [DT_COMPLEX128]; Tindices in [DT_INT32]
device='CPU'; T in [DT_COMPLEX64]; Tindices in [DT_INT64]
device='CPU'; T in [DT_COMPLEX64]; Tindices in [DT_INT32]
device='CPU'; T in [DT_DOUBLE]; Tindices in [DT_INT64]
device='CPU'; T in [DT_DOUBLE]; Tindices in [DT_INT32]
device='CPU'; T in [DT_FLOAT]; Tindices in [DT_INT64]
device='CPU'; T in [DT_FLOAT]; Tindices in [DT_INT32]
device='CPU'; T in [DT_BFLOAT16]; Tindices in [DT_INT64]
device='CPU'; T in [DT_BFLOAT16]; Tindices in [DT_INT32]
device='CPU'; T in [DT_HALF]; Tindices in [DT_INT64]
device='CPU'; T in [DT_HALF]; Tindices in [DT_INT32]
device='CPU'; T in [DT_INT8]; Tindices in [DT_INT64]
device='CPU'; T in [DT_INT8]; Tindices in [DT_INT32]
device='CPU'; T in [DT_UINT8]; Tindices in [DT_INT64]
device='CPU'; T in [DT_UINT8]; Tindices in [DT_INT32]
device='CPU'; T in [DT_INT16]; Tindices in [DT_INT64]
device='CPU'; T in [DT_INT16]; Tindices in [DT_INT32]
device='CPU'; T in [DT_UINT16]; Tindices in [DT_INT64]
device='CPU'; T in [DT_UINT16]; Tindices in [DT_INT32]
device='CPU'; T in [DT_INT32]; Tindices in [DT_INT64]
device='CPU'; T in [DT_INT32]; Tindices in [DT_INT32]
device='CPU'; T in [DT_INT64]; Tindices in [DT_INT64]
device='CPU'; T in [DT_INT64]; Tindices in [DT_INT32]
[[Node: ScatterNdUpdate = ScatterNdUpdate[T=DT_STRING, Tindices=DT_INT64, _class=["loc:@Variable"], use_locking=true](Variable, HashCollectiveAndUpdate, HashCollectiveAndUpdate:1)]]
Answer
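Surprisingly, tf.scatter_nd_update does not work with strings, especially since tf.scatter_nd does. You can reproduce the same behavior with the following function, which builds the new value with tf.scatter_nd and tf.where and then assigns it to the variable: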
import tensorflow as tf

def my_scatter_nd_update(ref, indices, updates, use_locking=True):
    indices = tf.convert_to_tensor(indices)
    updates = tf.convert_to_tensor(updates)
    # Make a mask for elements to replace
    m = tf.ones_like(updates, dtype=tf.bool)
    s = tf.shape(ref)
    mask = tf.scatter_nd(indices, m, s)
    # Make tensor of replacement values put in place
    upd_scatter = tf.scatter_nd(indices, updates, s)
    # Select replacement values for replaced positions
    new_value = tf.where(mask, upd_scatter, ref)
    # Do assignment
    return tf.assign(ref, new_value, use_locking=use_locking)

# Test
with tf.Graph().as_default(), tf.Session() as sess:
    var = tf.Variable([['a', 'b', 'c'], ['d', 'e', 'f']])
    var_upd = my_scatter_nd_update(var, [[0, 1], [1, 2]], ['g', 'h'])
    sess.run(var.initializer)
    print(sess.run(var_upd))
    # [[b'a' b'g' b'c']
    #  [b'd' b'e' b'h']]
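This does more work than a true in-place update (every call builds and assigns a full-size copy of the tensor), but the result should be the same.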