我目前正在将现有代码重构为更新的TF Dataset API。在我们当前的过程中,我们使用产品ID填充标准python字典以分类ID。
现在我已经将我们的图像/路径移动到TF数据集,然后使用tf.string_split我从文件名本身提取各种信息。其中一个是product_id。此时,product_id是一个tf张量,我无法使用以前的方法通过" 执行查找,如果product_to_class中的product_id "因为我现在有一个张量,我无法通过标准词典进行搜索。
所以我使用这个项目来学习如何提高性能。所以我想知道"最佳/推荐"在使用tf Dataset API批处理时,请采用此方法。我是否将product_id转换为字符串,只是通过上面的检查执行查找,或者我现在是否将products_to_class字典转换为另一个数据结构(如另一个数据集)并使用张量执行查找?任何建议都将不胜感激。
我目前的一个小例子是:
prod_to_class = {'12345': 0, '67890': 1}
#Below logic is in a mapped function used on a TF.Dataset
def _parse_fn(filename, label)
core_file = tf.string_split([filename], '\\').values[-1]
product_id = tf.string_split([core_file], ".").values[0]
#unable to perform below because product_id is now a tensor and
#products_to_class is a python dictionary
if product_id in products_to_class:
label = products_to_class[product_id]
答案 0 :(得分:2)
用于执行此操作的内置TensorFlow机制是使用tf.contrib.lookup
表。例如,如果您有要映射到密集整数的字符串键列表,则可以在_parse_fn()
之外定义以下内容:
# This constructor creates a lookup table that implicitly maps each string in the
# argument to its index in the list (e.g. '67890' -> 1).
products_to_class = tf.contrib.lookup.index_table_from_tensor(['12345', '67890'])
...然后在_parse_fn()
中使用products_to_class.lookup()
。
def _parse_fn(filename, label):
core_file = tf.string_split([filename], '\\').values[-1]
product_id = tf.string_split([core_file], ".").values[0]
# Returns a `tf.Tensor` that corresponds to the value associated with
# `product_id` in the `products_to_class` table.
label = products_to_class.lookup(product_id)
# ...
请注意,这会对您的程序施加两个额外的限制:
Dataset.make_initializable_iterator()
代替Dataset.make_one_shot_iterator()
。sess.run(tf.tables_initializer())
。如果您使用高级tf.estimator
API并从tf.data.Dataset
返回input_fn
,则会处理这两项内容。