TensorFlow,数据集API和flat_map操作

时间:2017-08-07 13:47:58

标签: python tensorflow

我在使用tf.contrib.data.Dataset API时遇到了困难,并想知道是否有人可以提供帮助。我想将word2vec的整个skip-gram预处理转换为这个范例,以便稍微使用API​​,它涉及以下操作:

  1. 动态加载令牌序列(以避免一次加载内存中的所有数据集),比如说我们从Stream开始(要理解为Scala的方式,所有数据都不在内存中,但在加载时加载需要访问令牌序列:seq_tokens
  2. 从这些seq_tokens中的任何一个中,我们使用python函数提取skip-gram,该函数返回元组(token, context)的列表。
  3. token s列选择要素,并为context s列标注。
  4. 在伪代码中,为了使其更清晰,它将如上所示。我们应该利用框架并行系统不要自己加载数据,所以我会做一些事情,比如首先在内存中加载序列的索引,然后加载序列(在map内部,因此如果不是全部的话)这些行是同步处理的,数据是异步加载的,并且没有OOM可以担心),并对那些会产生需要展平的不同数量的skip-gram的令牌序列应用一个函数。最后,我会正式结束data形状(#lines=number of skip-grams generated, #columns=2)

    data = range(1:N)
      .map(i => load(i): Seq[String]) // load: Int -> Seq[String] loads dynamically a sequence of tokens (sequences have varying length)
      .flat_map(s => skip_gram(s)) // skip_gram: Seq[String] -> Seq[(String, String)] with output length
    features = data[0] // features
    lables = data[1] // labels
    

    我已经尝试用Dataset的API来天真地这样做,但是我被卡住了,我可以这样做:

    iterator = (
      tf.contrib.data.Dataset.range(N)
      .map(lambda i: tf.py_func(load_data, [i], [tf.int32, tf.int32])) // (1)
      .flat_map(?) // (2)
      .make_one_shot_iterator()
    )
    

    (1)TensorFlow在这里不高兴,因为加载的序列有不同的长度...

    (2)还没有管理到跳过部分...我实际上只是想调用一个python函数来计算skip-gram的序列(可变大小)并将其展平以便如果return类型是一个矩阵,那么每一行应该被理解为输出Dataset的新行。

    如果有人有任何想法,非常感谢,如果我忘记提及有用的精确度,请不要犹豫......

1 个答案:

答案 0 :(得分:2)

我只是在实施同样的事情;这是我如何解决它:

dataset = tf.data.TextLineDataset(filename)

if mode == ModeKeys.TRAIN:
  dataset = dataset.shuffle(buffer_size=batch_size * 100)

dataset = dataset.flat_map(lambda line: string_to_skip_gram(line))

dataset = dataset.batch(batch_size)

在我的数据集中,我将每一行都视为独立的,所以我并不担心跨越多行的上下文。

因此,我通过函数string_to_skip_gram对每一行进行平面映射,该函数返回Dataset长度,该长度取决于行中的标记数。

string_to_skip_gram将该行转换为一系列令牌,使用tokenize_str以ID(使用方法tf.py_func)表示:

def string_to_skip_gram(line):
  def handle_line(line):
    token_ids = tokenize_str(line)

    (features, labels) = skip_gram(token_ids)

    return np.array([features, labels], dtype=np.int64)

  res = tf.py_func(handle_line, [line], tf.int64)

  features = res[0]
  labels = res[1]

  return tf.data.Dataset.from_tensor_slices((features, labels))

最后,skip_gram返回所有可能的上下文单词和目标单词的列表:

def skip_gram(token_ids):
  skip_window = 1

  features = []
  labels = []

  context_range = [i for i in range(-skip_window, skip_window + 1) if i != 0]

  for word_index in range(skip_window, len(token_ids) - skip_window):
    for context_word_offset in context_range:
      features.append(token_ids[word_index])
      labels.append(token_ids[word_index + context_word_offset])

  return features, labels

请注意,我不是在这里采样上下文单词;只使用它们。