在张量流中为seq2seq解码器的csv中的单独列创建输入/输出序列

时间:2017-08-10 02:50:44

标签: python tensorflow sequence-to-sequence

我正在尝试使用张量流seq2seq 而且我很难想出一个很好的方法来添加" GO"," EOS" +" PAD"元素一系列标签。我正在使用tf.TextLineReader从.csv读取此数据,我创建的.csv有一个Text列,然后是每个顺序标签的4列。

以下是我创建sample_input.csv的示例csv: "这是我们想要的示例文本消息",Label1,Label2,Label3, "这是我们要加载的另一句话",Label10 ,,,

这是我的代码,它读取了这个csv:

import tensorflow as tf
tf.reset_default_graph()

_BATCH_SIZE = 2

sess = tf.InteractiveSession()

filename_queue = tf.train.string_input_producer(['sample_input.csv'])
reader = tf.TextLineReader(skip_header_lines=1)

_, rows = reader.read_up_to(filename_queue, num_records=_BATCH_SIZE)

row_columns = tf.expand_dims(rows, -1)
Text, Label1, Label2, Label3, Label4 = tf.decode_csv(
    row_columns, record_defaults=[[""],[""],[""],[""],[""]])

start = tf.constant(["START"] * _BATCH_SIZE)
start = tf.expand_dims(start, -1)
input_seq = tf.string_join(
    inputs = [start, Label1, Label2, Label3],separator = ", ")

output_seq = tf.string_join(
    inputs = [Label1, Label2, Label3, Label4],separator = ", ")
features = tf.stack([Text, input_seq, output_seq])

sess.run(tf.global_variables_initializer())
tf.train.start_queue_runners()
features.eval()

上面的示例将在features.eval():

上打印出以下内容
array([[[b'This is an example text message that we want to have'],
    [b'Here is another sentences that we want to load']],

    [[b'START, Label1, Label2, Label3'],[b'START, Label10, , ']],

    [[b'Label1, Label2, Label3, '], [b'Label10, , , ']]], dtype=object)

现在我知道这不是创建这些序列的正确位置,但我希望得到一些关于如何正确创建序列的建议。这些4个标签的序列长度不一,有些可能只有1个,而其他可能有4个。理想情况下,我的输入最终会被

单一标签:

  • decoder_input = [GO,Label1,PAD,PAD]
  • decoder_output = [Label1,END,PAD,PAD]

双重标签:

  • decoder_input = [GO,Label1,Label2,PAD]
  • decoder_output = [Label1,Label2,END,PAD]

三个标签:

  • decoder_input = [GO,Label1,Label2,Label3]
  • decoder_output = [Label1,Label2,Label3,END]

四个标签:

  • decoder_input = [GO,Label1,Label2,Label3]
  • decoder_output = [Label1,Label2,Label3,Label4] *注意:最后一个没有结束序列,因为它已经是4个元素长

有人能提出一个更好的方法来从csv中的四个独立列创建解码器输入/输出吗?

0 个答案:

没有答案