我为NLP任务运行LSTM模型,并具有附加功能来获取output_layer的特定元素(在连接它们和softmax之前),这些元素的位置/索引通过匹配input_tensor中的word_id(单词已嵌入到ID中)找到。搜索功能是tf.where()。
问题是“从training_loop记录的错误:编译失败:尝试编译图形时检测到不支持的操作...位置(与XLA_TPU_JIT设备兼容的未注册'Where'OpKernel”
我还在这里发现了类似的问题:https://github.com/tensorflow/tpu/issues/236
我的问题是“就我而言,有什么方法可以解决该问题?”
我的附加功能如下。感谢您的帮助。
def get_entity_tensor(input_ten, output_ten, sign=999):
sign = tf.constant([sign], dtype=tf.int32) # input_ids has tf.int32 type
p = tf.where(tf.equal(input_ten, sign)) # tf.where return tf.int64 type
ent = tf.gather_nd(output_ten, p)
return ent