TensorFlow: MutableHashTable InvalidArgumentError, supported data types

Time: 2018-01-23 10:42:45

Tags: tensorflow hashtable lookup

I want to create a MutableHashTable with int64 keys and values:

import tensorflow as tf
with tf.Session() as sess:
    keys = tf.range(10,dtype=tf.int64)
    vals = tf.range(10,dtype=tf.int64)
    table = tf.contrib.lookup.MutableHashTable(key_dtype=tf.int64, value_dtype=tf.int64, default_value=-1)
    table.insert(keys, vals)
    print(sess.run(table.lookup(tf.range(20,dtype=tf.int64))))

But when I run it, I get the following error message:

InvalidArgumentError (see above for traceback): No OpKernel was registered to support Op 'MutableHashTableV2' with these attrs.  Registered devices: [CPU,GPU], Registered kernels:
  device='CPU'; key_dtype in [DT_STRING]; value_dtype in [DT_FLOAT]
  device='CPU'; key_dtype in [DT_STRING]; value_dtype in [DT_INT64]
  device='CPU'; key_dtype in [DT_INT64]; value_dtype in [DT_STRING]
  device='CPU'; key_dtype in [DT_STRING]; value_dtype in [DT_BOOL]
  device='CPU'; key_dtype in [DT_INT64]; value_dtype in [DT_FLOAT]

     [[Node: MutableHashTable_16 = MutableHashTableV2[container="", key_dtype=DT_INT64, shared_name="", use_node_name_sharing=true, value_dtype=DT_INT64]()]]

If I use HashTable instead, it works:

import tensorflow as tf
with tf.Session() as sess:
    keys = tf.range(10,dtype=tf.int64)
    vals = tf.range(10,dtype=tf.int64)
    table = tf.contrib.lookup.HashTable(tf.contrib.lookup.KeyValueTensorInitializer(keys,vals),-1)
    table.init.run()
    print(sess.run(table.lookup(tf.range(20,dtype=tf.int64))))

1 Answer:

Answer 0 (score: 1)

In the TensorFlow version you are running, the mutable hash table only has kernels registered for the following key/value type pairs:

key_dtype  -  value_dtype
tf.string  -  tf.float32
tf.string  -  tf.int64
tf.int64   -  tf.string
tf.string  -  tf.bool
tf.int64   -  tf.float32

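These are exactly the registered kernels listed in the error message you posted; (tf.int64, tf.int64) is not among them.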
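It may be tempting to use the supported tf.int64 / tf.float32 pair and cast the values. The plain-Python sketch below (no TensorFlow needed) assumes DT_FLOAT is an IEEE-754 single-precision float with a 24-bit significand, and shows why that cast is lossy: above 2**24 not every integer can be stored exactly. The helper name roundtrip_float32 is hypothetical.

```python
import struct


def roundtrip_float32(n: int) -> float:
    """Pack an integer into an IEEE-754 float32 and unpack it again (hypothetical helper)."""
    return struct.unpack("<f", struct.pack("<f", n))[0]


# float32 has a 24-bit significand, so integers up to 2**24 survive unchanged...
print(roundtrip_float32(2**24))      # 16777216.0
# ...but 2**24 + 1 is rounded to the nearest representable value, 2**24.
print(roundtrip_float32(2**24 + 1))  # 16777216.0
```

So the int64 -> float table would silently corrupt large int64 values, which is why a table with genuine int64 values is needed.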

One way to use int64 for both keys and values is to use MutableDenseHashTable instead.

Here is example code that does this:

import tensorflow as tf

with tf.Session() as sess:
    # Initialize keys and values.
    keys = tf.constant([1, 2, 3], dtype=tf.int64)
    vals = tf.constant([1, 2, 3], dtype=tf.int64)

    # Initialize the hash table. empty_key marks unused buckets internally,
    # so it must never be used as a real key (0 is safe for these keys).
    table = tf.contrib.lookup.MutableDenseHashTable(
        key_dtype=tf.int64, value_dtype=tf.int64, default_value=-1, empty_key=0)

    # Insert the key/value pairs into the hash table by running the insert op.
    insert_op = table.insert(keys, vals)
    sess.run(insert_op)

    # Print the lookup results.
    print(sess.run(table.lookup(keys)))
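The answer does not say why MutableDenseHashTable needs an empty_key. TensorFlow's documentation describes it as an open-addressing table with quadratic reprobing, where empty_key marks unused buckets, so that value cannot also be a real key. The toy pure-Python sketch below illustrates only that idea. The class ToyDenseHashTable and its methods are hypothetical; it has a fixed capacity with no deletion or resizing, and it is not TensorFlow's actual implementation.

```python
from typing import Iterator, List


class ToyDenseHashTable:
    """Toy open-addressing int -> int map that marks free buckets with empty_key.

    Illustration only: fixed capacity, no deletion, no resizing, and not
    TensorFlow's actual MutableDenseHashTable implementation.
    """

    def __init__(self, default_value: int, empty_key: int, capacity: int = 8) -> None:
        # A power-of-two capacity lets (index & mask) wrap around the bucket
        # array and makes triangular probing visit every bucket.
        if capacity <= 0 or capacity & (capacity - 1):
            raise ValueError("capacity must be a positive power of two")
        self.default_value = default_value
        self.empty_key = empty_key
        self.mask = capacity - 1
        # Every bucket starts out holding the sentinel key, i.e. "unused".
        self.keys: List[int] = [empty_key] * capacity
        self.values: List[int] = [default_value] * capacity

    def _probe(self, key: int) -> Iterator[int]:
        # Quadratic (triangular) probing: offsets 0, 1, 3, 6, 10, ...
        bucket = hash(key) & self.mask
        for step in range(1, self.mask + 2):
            yield bucket
            bucket = (bucket + step) & self.mask

    def _check_key(self, key: int) -> None:
        # The sentinel cannot double as a real key: a bucket holding it
        # would be indistinguishable from an unused one.
        if key == self.empty_key:
            raise ValueError("empty_key must not be used as a real key")

    def insert(self, key: int, value: int) -> None:
        self._check_key(key)
        for bucket in self._probe(key):
            # Overwrite an existing entry or claim the first unused bucket.
            if self.keys[bucket] in (key, self.empty_key):
                self.keys[bucket] = key
                self.values[bucket] = value
                return
        raise RuntimeError("table is full")

    def lookup(self, key: int) -> int:
        self._check_key(key)
        for bucket in self._probe(key):
            if self.keys[bucket] == key:
                return self.values[bucket]
            if self.keys[bucket] == self.empty_key:
                # An unused bucket ends the probe chain: the key is absent.
                break
        return self.default_value


table = ToyDenseHashTable(default_value=-1, empty_key=0)
for k, v in zip([1, 2, 3], [1, 2, 3]):
    table.insert(k, v)
print([table.lookup(k) for k in [1, 2, 3, 4]])  # [1, 2, 3, -1]
```

As in the TensorFlow example, a missing key (4 here) returns default_value, and choosing empty_key=0 works only because 0 never appears among the real keys.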