我在使用tf.contrib.data.Dataset
API时遇到了困难,并想知道是否有人可以提供帮助。我想将word2vec的整个skip-gram
预处理转换为这个范例,以便稍微使用API,它涉及以下操作:
Stream
开始(要理解为Scala的方式,所有数据都不在内存中,但在加载时加载需要访问令牌序列:seq_tokens
。seq_tokens
中的任何一个中,我们使用python函数提取skip-gram,该函数返回元组(token, context)
的列表。token
s列选择要素,并为context
s列标注。在伪代码中,为了使其更清晰,它将如上所示。我们应该利用框架并行系统不要自己加载数据,所以我会做一些事情,比如首先在内存中加载序列的索引,然后加载序列(在map
内部,因此如果不是全部的话)这些行是同步处理的,数据是异步加载的,并且没有OOM可以担心),并对那些会产生需要展平的不同数量的skip-gram的令牌序列应用一个函数。最后,我会正式结束data
形状(#lines=number of skip-grams generated, #columns=2)
。
data = range(1:N)
.map(i => load(i): Seq[String]) // load: Int -> Seq[String] loads dynamically a sequence of tokens (sequences have varying length)
.flat_map(s => skip_gram(s)) // skip_gram: Seq[String] -> Seq[(String, String)] with output length
features = data[0] // features
lables = data[1] // labels
我已经尝试用Dataset
的API来天真地这样做,但是我被卡住了,我可以这样做:
iterator = (
tf.contrib.data.Dataset.range(N)
.map(lambda i: tf.py_func(load_data, [i], [tf.int32, tf.int32])) // (1)
.flat_map(?) // (2)
.make_one_shot_iterator()
)
(1)TensorFlow在这里不高兴,因为加载的序列有不同的长度...
(2)还没有管理到跳过部分...我实际上只是想调用一个python函数来计算skip-gram的序列(可变大小)并将其展平以便如果return类型是一个矩阵,那么每一行应该被理解为输出Dataset
的新行。
如果有人有任何想法,非常感谢,如果我忘记提及有用的精确度,请不要犹豫......
答案 0 :(得分:2)
我只是在实施同样的事情;这是我如何解决它:
dataset = tf.data.TextLineDataset(filename)
if mode == ModeKeys.TRAIN:
dataset = dataset.shuffle(buffer_size=batch_size * 100)
dataset = dataset.flat_map(lambda line: string_to_skip_gram(line))
dataset = dataset.batch(batch_size)
在我的数据集中,我将每一行都视为独立的,所以我并不担心跨越多行的上下文。
因此,我通过函数string_to_skip_gram
对每一行进行平面映射,该函数返回Dataset
长度,该长度取决于行中的标记数。
string_to_skip_gram
将该行转换为一系列令牌,使用tokenize_str
以ID(使用方法tf.py_func
)表示:
def string_to_skip_gram(line):
def handle_line(line):
token_ids = tokenize_str(line)
(features, labels) = skip_gram(token_ids)
return np.array([features, labels], dtype=np.int64)
res = tf.py_func(handle_line, [line], tf.int64)
features = res[0]
labels = res[1]
return tf.data.Dataset.from_tensor_slices((features, labels))
最后,skip_gram
返回所有可能的上下文单词和目标单词的列表:
def skip_gram(token_ids):
skip_window = 1
features = []
labels = []
context_range = [i for i in range(-skip_window, skip_window + 1) if i != 0]
for word_index in range(skip_window, len(token_ids) - skip_window):
for context_word_offset in context_range:
features.append(token_ids[word_index])
labels.append(token_ids[word_index + context_word_offset])
return features, labels
请注意,我不是在这里采样上下文单词;只使用它们。