How can I dynamically insert entries into a `MutableHashTable` inside the computation graph?

Time: 2019-02-11 23:55:32

Tags: python tensorflow insert hashtable

Model Requirements

I am building a simple tensorflow model that needs to build and store an ID ==> INDEX mapping on the fly, so that it can be trained online on a dataset that is constructed in real time, i.e. no pre-built map exists for the dataset (otherwise I would just use it).

Of course, I could build the mapping outside of the tensorflow computation graph (and convert each ID to an INDEX before feeding the model its inputs), but we would really like the model to be self-contained, both to save time and to simplify future prediction/computation tasks.
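For context, the out-of-graph version I am trying to avoid is just a plain Python dictionary that hands out the next free index for every previously unseen ID before the batch is fed in (a rough sketch; raw_ids is a hypothetical batch of incoming IDs):

id_to_idx_py = {}

def to_indices(raw_ids):
    # assign the next free index to any ID we haven't seen before
    for raw_id in raw_ids:
        if raw_id not in id_to_idx_py:
            id_to_idx_py[raw_id] = len(id_to_idx_py)
    return [id_to_idx_py[raw_id] for raw_id in raw_ids]

# to_indices([34, 532, 5, 234, 6])  ->  [0, 1, 2, 3, 4]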

Approach

I am trying to use tensorflow's MutableHashTable for this, but I keep running into tensorflow gotchas. I have confirmed that I can insert new key/value entries into the hash table like this:
# defined in computation graph
id_to_idx = tf.contrib.lookup.MutableHashTable(tf.int32, tf.int32, MISSING_KEY_VALUE)

# .. with an open tf.Session

keys = [123, 234, 456, 567]
vals = [0,1,2,3]
sess.run(id_to_idx.insert(keys, vals))

sess.run(id_to_idx.lookup(keys))
# Returns [0,1,2,3]
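For reference, a key that has never been inserted just comes back as the table's default value, which is what the missing-key detection further down relies on (a quick sketch in the same session):

sess.run(id_to_idx.lookup([999]))
# Returns [-1], i.e. MISSING_KEY_VALUE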

What I would like to do is turn this into an operation on the graph, so that new keys are identified, assigned incrementing values, and then inserted into the hash table for use in the rest of the model.

I want something like this:

# defined in computation graph
id_inputs = tf.placeholder(tf.int32, shape=(None,))
update_hashtable = id_to_idx.insert(new_keys, new_values)
indices = id_to_idx.lookup(id_inputs)

# .. with open tf.Session()

batch_of_somenew_and_someold_ids = [13,563,673,23]
sess.run([update_hashtable, indices], feed_dict={id_inputs: batch_of_somenew_and_someold_ids})
# Updates hash table values then returns them with next op

Minimal Example

import tensorflow as tf

MISSING_KEY_VALUE = -1

id_inputs = tf.placeholder(tf.int32, shape=(None,))
id_to_idx = tf.contrib.lookup.MutableHashTable(tf.int32, tf.int32, MISSING_KEY_VALUE)
_next_idx = tf.Variable(0, dtype=tf.int32)

#
# Update Hash Table
#
_indices = id_to_idx.lookup(id_inputs)                         # initial lookup from the hash table
_not_seen_loc = tf.equal(_indices, MISSING_KEY_VALUE)          # positions in the batch whose IDs haven't been seen yet
_ids_not_seen_yet = tf.boolean_mask(id_inputs, _not_seen_loc)  # the IDs that haven't been seen yet
_end_idx = _next_idx + tf.size(_ids_not_seen_yet)              # next free index after assigning the new IDs
_new_indices = tf.range(_next_idx, _end_idx)                   # new hash table values for these keys

# Operation to update the idx iterator for next batch of inputs
update_iter = _next_idx.assign(_end_idx)

# *should* be an operation that makes the insert
update_hashtable = id_to_idx.insert(_ids_not_seen_yet, _new_indices)

# Operation to get mapped IDs for batch
indices = id_to_idx.lookup(id_inputs)

# .. Graph continues and does ML stuff
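As an aside, I suspect that even once fetching the insert op works, I would still need to force the ordering explicitly, since indices above has no data dependency on update_hashtable. A sketch of what I think that would look like using tf.control_dependencies (not something I have verified end to end):

# force the insert (and the iterator update) to run before the final lookup
with tf.control_dependencies([update_hashtable, update_iter]):
    indices = id_to_idx.lookup(id_inputs)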

Running this example:

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())

  # some mock input data pulled off a dataset obj
  batch_ids = [34, 532, 5, 234, 6]

  _, iter_val, batch_indices = sess.run([update_hashtable, update_iter, indices], feed_dict={id_inputs: batch_ids})

This produces the error:

Fetch argument <function update_hashtable at 0x1384401e0> has invalid type <class 'function'>, must be a string or Tensor. (Can not convert a function into a Tensor or Operation.)

This seems strange, because id_to_idx.insert(keys, vals) should return an operation (as described in the docs and test examples).
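As a sanity check (a small sketch, mirroring the snippet at the top), calling insert directly on the table does hand back an op, which is what makes the error above so confusing:

check_op = id_to_idx.insert([111], [9])
print(type(check_op))
# expected: <class 'tensorflow.python.framework.ops.Operation'>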

By printing the result every 16 batches, I have confirmed that the idx iterator increments correctly:

[16,32,48,64,80,96,112,128,144,160,176]

Questions

  1. Does MutableHashTable not work inside the computation graph? The docs say the inputs can be tensors.

  2. Why does printing the type of update_hashtable not show an operation (the way the lookup case returns a Tensor)? (See the sketch after this list.)

print(type(update_hashtable))
print(type(indices))

# <class 'function'>
# <class 'tensorflow.python.framework.ops.Tensor'>

  3. Am I approaching this problem incorrectly? Is there a known/better solution?
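One hypothesis for question 2 (a sketch of the failure mode, not code from my actual script): insert(...) should hand back an Operation, so the only way I can see a bare function ending up in the fetch list is if the name update_hashtable gets bound to something callable rather than to the op the call returns, e.g.:

# what sess.run() can fetch: the op returned by the call
good_fetch = id_to_idx.insert(_ids_not_seen_yet, _new_indices)

# what reproduces "has invalid type <class 'function'>": a callable that was never called
bad_fetch = lambda: id_to_idx.insert(_ids_not_seen_yet, _new_indices)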

0 Answers:

No answers yet.