通过Tensorflow中的数据集API处理批处理时,在字典中执行索引查找的推荐方法是什么?

时间:2018-05-03 14:41:44

标签: tensorflow tensorflow-datasets

我目前正在将现有代码重构为更新的TF Dataset API。在我们当前的过程中,我们使用产品ID填充标准python字典以分类ID。

现在我已经将我们的图像/路径移动到TF数据集,然后使用tf.string_split我从文件名本身提取各种信息。其中一个是product_id。此时,product_id是一个tf张量,我无法使用以前的方法通过" 执行查找,如果product_to_class中的product_id "因为我现在有一个张量,我无法通过标准词典进行搜索。

所以我使用这个项目来学习如何提高性能。所以我想知道"最佳/推荐"在使用tf Dataset API批处理时,请采用此方法。我是否将product_id转换为字符串,只是通过上面的检查执行查找,或者我现在是否将products_to_class字典转换为另一个数据结构(如另一个数据集)并使用张量执行查找?任何建议都将不胜感激。

我目前的一个小例子是:

prod_to_class = {'12345': 0, '67890': 1}

#Below logic is in a mapped function used on a TF.Dataset
def _parse_fn(filename, label)
  core_file = tf.string_split([filename], '\\').values[-1]
  product_id = tf.string_split([core_file], ".").values[0]

  #unable to perform below because product_id is now a tensor and
  #products_to_class is a python dictionary
  if product_id in products_to_class:
    label = products_to_class[product_id]

1 个答案:

答案 0 :(得分:2)

用于执行此操作的内置TensorFlow机制是使用tf.contrib.lookup表。例如,如果您有要映射到密集整数的字符串键列表,则可以在_parse_fn()之外定义以下内容:

# This constructor creates a lookup table that implicitly maps each string in the
# argument to its index in the list (e.g. '67890' -> 1).
products_to_class = tf.contrib.lookup.index_table_from_tensor(['12345', '67890'])

...然后在_parse_fn()中使用products_to_class.lookup()

def _parse_fn(filename, label):
  core_file = tf.string_split([filename], '\\').values[-1]
  product_id = tf.string_split([core_file], ".").values[0]

  # Returns a `tf.Tensor` that corresponds to the value associated with 
  # `product_id` in the `products_to_class` table.
  label = products_to_class.lookup(product_id)

  # ...

请注意,这会对您的程序施加两个额外的限制:

  1. 您必须使用Dataset.make_initializable_iterator()代替Dataset.make_one_shot_iterator()
  2. 在开始使用输入管道中的元素之前,必须先调用sess.run(tf.tables_initializer())
  3. 如果您使用高级tf.estimator API并从tf.data.Dataset返回input_fn,则会处理这两项内容。