How should I append an element to each sequence with tf.data.Dataset?

Time: 2018-05-22 09:36:42

Tags: tensorflow tensorflow-datasets

I want to use tf.data.Dataset to append char2int['EOS'] to the end of each sequence. The code I wrote is as follows:

import tensorflow as tf 

def _get_generator(list_of_text, char2int):
    def gen():
        for text in list_of_text:
            yield [char2int[x] for x in text] # transform char to int
    return gen

def get_dataset(list_of_text, char2int):
    gen = _get_generator(list_of_text, char2int)
    dataset = tf.data.Dataset.from_generator(gen, (tf.int32), tf.TensorShape([None]))

    dataset = dataset.map(lambda seq: seq+[char2int['EOS']])  # append EOS to the end of line

    data_iter = dataset.make_initializable_iterator()

    return dataset, data_iter

char2int = {'EOS':1, 'a':2, 'b':3, 'c':4}
list_of_text = ['aaa', 'abc'] # the sequence data

with tf.Graph().as_default():
    dataset, data_iter = get_dataset(list_of_text, char2int)
    with tf.Session() as sess:
        sess.run(data_iter.initializer)
        tt1 = sess.run(data_iter.get_next())
        tt2 = sess.run(data_iter.get_next())
        print(tt1)  # got [3 3 3] but I want [2 2 2 1]
        print(tt2)  # got [3 4 5] but I want [2 3 4 1]

But I can't get what I want: instead of appending the EOS id, it adds it to every element of each sequence. What should I do? Thanks.
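The [3 3 3] result comes from broadcasting: arithmetic on a tensor follows NumPy-style broadcasting, so adding a length-1 list adds that single value to every element rather than appending it. A minimal illustration, using NumPy here only as a stand-in for TensorFlow's element-wise behavior:

```python
import numpy as np

seq = np.array([2, 2, 2], dtype=np.int32)  # 'aaa' encoded with char2int
eos = [1]                                  # [char2int['EOS']]

added = seq + eos                     # broadcasting: 1 is added to every element
joined = np.concatenate([seq, eos])   # concatenation: 1 is appended at the end

print(added.tolist())   # [3, 3, 3]
print(joined.tolist())  # [2, 2, 2, 1]
```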

2 answers:

Answer 0 (score: 0)

In your map function you are adding 1 to each value instead of concatenating it. You can change _get_generator to:

def _get_generator(list_of_text, char2int):
    def gen():
        for text in list_of_text:
            yield [char2int[x] for x in text] + [char2int['EOS']]  # transform chars to ints, then append EOS
    return gen

and remove the dataset.map call.
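This works because the + in the generator operates on plain Python lists, where it means concatenation, before anything becomes a tf.Tensor. A minimal check, assuming the same char2int and list_of_text as in the question, calls the generator directly (no TensorFlow needed) to show what from_generator would receive:

```python
char2int = {'EOS': 1, 'a': 2, 'b': 3, 'c': 4}
list_of_text = ['aaa', 'abc']

def _get_generator(list_of_text, char2int):
    def gen():
        for text in list_of_text:
            # Encode each character, then append the EOS id as one extra element.
            yield [char2int[x] for x in text] + [char2int['EOS']]
    return gen

# Materialize the generator to inspect the sequences it produces.
sequences = list(_get_generator(list_of_text, char2int)())
print(sequences)  # [[2, 2, 2, 1], [2, 3, 4, 1]]
```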

Answer 1 (score: 0)

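As Vijay pointed out in his answer, the + operator on a tf.Tensor of type tf.int32 performs addition rather than concatenation. To concatenate an additional symbol onto the end of each sequence, use tf.concat() inside Dataset.map().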
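A minimal sketch of the tf.concat() approach, assuming TensorFlow 2.4 or later (for eager iteration and the output_signature argument of from_generator) rather than the TF1 session API used in the question. The EOS id is wrapped in a one-element int32 tensor so that tf.concat joins two rank-1 tensors along axis 0:

```python
import tensorflow as tf

char2int = {'EOS': 1, 'a': 2, 'b': 3, 'c': 4}
list_of_text = ['aaa', 'abc']
eos_id = char2int['EOS']

def gen():
    for text in list_of_text:
        yield [char2int[x] for x in text]  # chars to ints, no EOS yet

dataset = tf.data.Dataset.from_generator(
    gen,
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32),
)

# Append EOS by concatenation: [n] + [1] -> [n + 1] elements.
dataset = dataset.map(
    lambda seq: tf.concat([seq, tf.constant([eos_id], dtype=tf.int32)], axis=0)
)

result = [seq.numpy().tolist() for seq in dataset]
print(result)  # [[2, 2, 2, 1], [2, 3, 4, 1]]
```

In the question's TF1 code, the corresponding change would be limited to the map line in get_dataset; the iterator and session code can stay as they are.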