I'm building a simple TensorFlow model that needs to store/build an ID ==> INDEX mapping on the fly, so it can be trained online on a dataset constructed in real time, i.e. a dataset for which no pre-built map exists (otherwise I would use one).
Of course, I could build the mapping outside the TensorFlow computation graph (and convert each ID to an INDEX before feeding the model its inputs), but we would really like the model to be self-contained at save time, to simplify future prediction/serving tasks.
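For contrast, the outside-the-graph preprocessing I'm trying to avoid would look roughly like this (plain Python; the function name is just illustrative):

```python
# Hypothetical outside-the-graph mapping: each unseen ID gets the next
# free index, and every ID is converted before being fed to the model.
def ids_to_indices(batch_ids, id_to_idx):
    indices = []
    for id_ in batch_ids:
        if id_ not in id_to_idx:
            id_to_idx[id_] = len(id_to_idx)  # assign incremental index
        indices.append(id_to_idx[id_])
    return indices

mapping = {}
print(ids_to_indices([123, 234, 456, 567], mapping))  # [0, 1, 2, 3]
print(ids_to_indices([234, 999], mapping))            # [1, 4]
```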
I'm trying to use TensorFlow's MutableHashTable
for this, but keep running into "gotchas" from TensorFlow. I've confirmed that basic inserts and lookups work via
# defined in computation graph
id_to_idx = tf.contrib.lookup.MutableHashTable(tf.int32, tf.int32, MISSING_KEY_VALUE)
# .. with an open tf.Session
keys = [123, 234, 456, 567]
vals = [0,1,2,3]
sess.run(id_to_idx.insert(keys, vals))
sess.run(id_to_idx.lookup(keys))
# Returns [0,1,2,3]
What I want to do is make this an operation on the graph, so that new keys are identified, assigned incremental values, and inserted into the hash table for use by the rest of the model.
I want something like this,
# defined in computation graph
id_inputs = tf.placeholder(tf.int32, shape=(None,))
update_hashtable = id_to_idx.insert(new_keys, new_values)
indices = id_to_idx.lookup(id_inputs)
# .. with open tf.Session()
batch_of_somenew_and_someold_ids = [13,563,673,23]
sess.run([update_hashtable, indices], feed_dict={id_inputs: batch_of_somenew_and_someold_ids})
# Updates hash table values then returns them with next op
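For reference, here are the per-batch semantics I'm after, mirrored in plain Python (a sketch with hypothetical names; it assumes no duplicate new IDs within a batch):

```python
# Reference semantics for the per-batch update: find unseen keys, assign
# them a contiguous range of new indices, insert, then look everything up.
MISSING_KEY_VALUE = -1

def update_and_lookup(batch_ids, table, next_idx):
    indices = [table.get(i, MISSING_KEY_VALUE) for i in batch_ids]  # initial lookup
    unseen = [i for i, idx in zip(batch_ids, indices) if idx == MISSING_KEY_VALUE]
    new_indices = range(next_idx, next_idx + len(unseen))           # incremental values
    table.update(zip(unseen, new_indices))                          # the "insert" op
    return [table[i] for i in batch_ids], next_idx + len(unseen)

table, next_idx = {}, 0
indices, next_idx = update_and_lookup([13, 563, 673, 23], table, next_idx)
print(indices)   # [0, 1, 2, 3]
indices, next_idx = update_and_lookup([563, 7, 13], table, next_idx)
print(indices)   # [1, 4, 0]
```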
import tensorflow as tf
MISSING_KEY_VALUE = -1
id_inputs = tf.placeholder(tf.int32, shape=(None,))
id_to_idx = tf.contrib.lookup.MutableHashTable(tf.int32, tf.int32, MISSING_KEY_VALUE)
_next_idx = tf.Variable(0, dtype=tf.int32)
#
# Update Hash Table
#
_indices = id_to_idx.lookup(id_inputs) # initial lookup from hash
_not_seen_loc = tf.not_equal(_indices, MISSING_KEY_VALUE) # location in batch of IDs that haven't been seen
_ids_not_seen_yet = tf.where(_not_seen_loc, id_inputs, id_inputs) # list of the IDs that haven't been seen
_end_idx = _next_idx + tf.size(_not_seen_loc)
_new_indices = tf.range(_next_idx, _end_idx) # new hash tables values for these keys
# Operation to update the idx iterator for next batch of inputs
update_iter = _next_idx.assign(_end_idx)
# *should* be an operation that makes the insert
update_hashtable = id_to_idx.insert(_ids_not_seen_yet, _new_indices)
# Operation to get mapped IDs for batch
indices = id_to_idx.lookup(id_inputs)
# .. Graph continues and does ML stuff
Running this example,
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # some mock input data pulled off a dataset obj
    batch_ids = [34, 532, 5, 234, 6]
    _, iter_val, batch_indices = sess.run([update_hashtable, update_iter, indices], feed_dict={id_inputs: batch_ids})
produces the error
Fetch argument <function update_hashtable at 0x1384401e0> has invalid type <class 'function'>, must be a string or Tensor. (Can not convert a function into a Tensor or Operation.)
This seems strange, since id_to_idx.insert(keys, vals)
should return an operation (as stated in the docs and the test examples).
By printing the result every 16 batches, I've confirmed that the idx iterator increments correctly:
[16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176]
Does MutableHashTable
simply not work inside the computation graph? The docs say the inputs can be tensors.
And when I print the type of update_hashtable
, why isn't it an operation (the way the lookup result is a tensor)?
print(type(update_hashtable))
print(type(indices))
# <class 'function'>
# <class 'tensorflow.python.framework.ops.Tensor'>