我正在尝试将doc2vec example从使用feed-dict转换为使用tf.data管道。我正在努力实现窗口序列生成。
我已经成功实现了text_helpers.py中所有数据方法的加载,但是我不知道如何解决generate_batch_data
方法。任何指针将不胜感激。
我尝试使用tf.string.split将句子拆分为单词,但是随后我无法获得结果的形状以用于设置要遍历ws
的范围。(请参见下文)错误,因为它是None
而不是整数。我假设append
变量window_sequences
的行也行不通。
更新:将我的代码添加到github仓库中 https://github.com/waldren/doc2vec-datapipeline
更新:我想我朝着正确的方向前进。我现在有一个tf.while_loop
def generate_window_sequence(self, sentence_ix, sentence, label):
ws = tf.strings.split([sentence], sep=' ').values
offset = tf.constant(int(self.window_size/2))
i = tf.constant(int(self.window_size/2))
def while_condition (i, window_sequences, ws, offset):
return tf.less(i, tf.math.subtract(tf.size(ws), offset))
def body(i, window_sequences, ws, offset):
#TODO - move within offset if it is > 1
before = tf.math.subtract(i, offset)
after = tf.math.add(i, offset)
#window_sequences.append((tf.gather(ws, [before, i])))
#window_sequences.append((tf.gather(ws, [after, i])))
return i+1, window_sequences, ws, offset
window_sequences = []
window_sequences.append(("before1","test1"))
i, window_sequences, ws, offset = tf.while_loop(while_condition, body, [i, window_sequences, ws, offset])
return sentence_ix, window_sequences, label
这将“运行”,但我不知道循环是否正在运行多次。但是,一旦取消注释body
函数中的两行(以window_sequences开头),就会出现ValueError: Number of inputs and outputs of body must match loop_vars: 5, 7
错误。