TF数据集是否有tf.gather()等效项?

时间:2019-01-17 02:14:50

标签: python tensorflow machine-learning h5py

我目前正在尝试使用从h5py文件加载的预先计算的词嵌入。嵌入是针对数据集中的每个示例预先计算的,因此我正尝试通过其示例/序列ID检索嵌入。但是,嵌入内容很大,因此我遇到了这样的问题:我无法直接在嵌入内容上运行tf.gather()来获取我想要的内容,因为从TF中检索出来后, t产生大于2GB的张量。结果,我正在尝试使用以下代码:

  # precompute_ds is just the tensor of word embeddings
  precompute_ds = h5py.File(kwargs['precompute_path'], 'r')['precomputed']
  precompute_place = tf.placeholder(precompute_ds.dtype, 
                                    shape=precompute_ds.shape)
  word_emb = tf.gather(precompute_place, sequence_ids)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(word_emb, feed_dict={precompute_place: precompute_ds})

return word_emb

但是,由于precompute_ds是一个h5py数据集,所以我不确定如何为其初始化迭代器并得到以下错误:

FailedPreconditionError (see above for traceback): GetNext() failed
because the iterator has not been initialized. Ensure that you have run
the initializer operation for this iterator before getting the next element.
     [[{{node IteratorGetNext}} = IteratorGetNext[output_shapes=
[[?], [?], [?]], output_types=[DT_INT64, DT_INT64, DT_INT64], 
_device="/job:localhost/replica:0/task:0/device:CPU:0"](IteratorV2)]]

因此,我也尝试在TF网站上的this example之后使用以下代码:

  precompute_ds = h5py.File(kwargs['precompute_path'], 'r')['precomputed']
  precompute_place = tf.placeholder(precompute_ds.dtype, 
                                    shape=precompute_ds.shape)
  ds = tf.data.Dataset.from_tensor_slices(precompute_place)
  word_emb = tf.gather(ds, sequence_ids)
  it = ds.make_initializable_iterator()
  with tf.Session() as sess:
    sess.run(it.initializer, feed_dict={precompute_place: precompute_ds})

return word_emb

但是,这有两个问题:首先,我很确定,即使tf.gather在TF数据集上工作了,也无法正确填充word_emb。我现在想的是,我可以使用第二种方法正确地填充ds,但是我不知道如何准确地获得该特定批次的sequence_ids。对于这两种方法中的任一种,是否有任何建议可以实现这一目的?

谢谢!

0 个答案:

没有答案