有没有办法在Tensorflow中基于字符串张量执行字典查找?
在普通的Python中,我会做类似
的事情value = dictionary[key]
。现在我想在Tensorflow运行时做同样的事情,当我将key
作为String张量时。像
value_tensor = tf.dict_lookup(string_tensor)
会很好。
答案 0 :(得分:23)
您可能会发现tensorflow.contrib.lookup
有用:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lookup/lookup_ops.py
https://www.tensorflow.org/api_docs/python/tf/contrib/lookup/HashTable
特别是,您可以这样做:
table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys, values), -1
)
out = table.lookup(input_tensor)
table.init.run()
print out.eval()
答案 1 :(得分:1)
如果要使用新的TF 2.0代码运行此程序,并且默认情况下启用急切执行。下面是快速代码段。
import tensorflow as tf
# build a lookup table
table = tf.lookup.StaticHashTable(
initializer=tf.lookup.KeyValueTensorInitializer(
keys=tf.constant([0, 1, 2, 3]),
values=tf.constant([10, 11, 12, 13]),
),
default_value=tf.constant(-1),
name="class_weight"
)
# now let us do a lookup
input_tensor = tf.constant([0, 0, 1, 1, 2, 2, 3, 3])
out = table.lookup(input_tensor)
print(out)
输出:
tf.Tensor([10 10 11 11 12 12 13 13], shape=(8,), dtype=int32)
答案 2 :(得分:0)
tf.gather可以为您提供帮助,但只能获取list的值。您可以将字典转换为键和值列表,然后应用tf.gather。示例:
# Your dict
dict_ = {'a': 1.12, 'b': 5.86, 'c': 68.}
# concrete query
query_list = ['a', 'c']
# unpack key and value lists
key, value = list(zip(*dict_.items()))
# map query list to list -> [0, 2]
query_list = [i for i, s in enumerate(key) if s in query_list]
# query as tensor
query = tf.placeholder(tf.int32, shape=[None])
# convert value list to tensor
vl_tf = tf.constant(value)
# get value
my_vl = tf.gather(vl_tf, query)
# session run
sess = tf.InteractiveSession()
sess.run(my_vl, feed_dict={query:query_list})
答案 3 :(得分:-10)
TensorFlow是一种数据流语言,不支持张量以外的数据结构。没有地图或字典类型。但是,根据您的需要,当您使用Python包装器时,可以在驱动程序进程中维护一个字典,该字典在Python中执行,并使用它与TensorFlow图形执行进行交互。例如,您可以在会话中执行TensorFlow图的一个步骤,将字符串值返回给Python驱动程序,将其用作驱动程序中字典的键,并使用检索到的值来确定要请求的下一个计算来自会议。如果这些字典查找的速度对性能至关重要,这可能不是一个好的解决方案。