I want to use `tf.data.Dataset` to get sequence data with `char2int['EOS']` appended to each sequence. The code I wrote is as follows:
```python
import tensorflow as tf

def _get_generator(list_of_text, char2int):
    def gen():
        for text in list_of_text:
            yield [char2int[x] for x in text]  # transform char to int
    return gen

def get_dataset(list_of_text, char2int):
    gen = _get_generator(list_of_text, char2int)
    dataset = tf.data.Dataset.from_generator(gen, (tf.int32), tf.TensorShape([None]))
    dataset = dataset.map(lambda seq: seq + [char2int['EOS']])  # append EOS to the end of line
    data_iter = dataset.make_initializable_iterator()
    return dataset, data_iter

char2int = {'EOS': 1, 'a': 2, 'b': 3, 'c': 4}
list_of_text = ['aaa', 'abc']  # the sequence data

with tf.Graph().as_default():
    dataset, data_iter = get_dataset(list_of_text, char2int)
    with tf.Session() as sess:
        sess.run(data_iter.initializer)
        tt1 = sess.run(data_iter.get_next())
        tt2 = sess.run(data_iter.get_next())
        print(tt1)  # got [3 3 3] but I want [2 2 2 1]
        print(tt2)  # got [3 4 5] but I want [2 3 4 1]
```
But I can't get what I want: instead of appending, it adds the value to every element. What should I do? Thanks.
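The `[3 3 3]` output is broadcasting at work: inside `Dataset.map()`, `seq` is a `tf.int32` tensor, and `+` on tensors follows NumPy-style broadcasting, so `seq + [1]` adds 1 to every element rather than appending it. A minimal sketch, using NumPy arrays as a stand-in for the tensor (an assumption made so it runs without TensorFlow; the variable names are illustrative only):

```python
import numpy as np

# Stand-ins (assumption) for the int32 tensors Dataset.map() receives for 'aaa' and 'abc'.
seq_aaa = np.array([2, 2, 2], dtype=np.int32)
seq_abc = np.array([2, 3, 4], dtype=np.int32)
eos = [1]  # i.e. [char2int['EOS']]

# `+` broadcasts the one-element list across the array: element-wise addition.
added_aaa = seq_aaa + eos
added_abc = seq_abc + eos

# Concatenation is what the question actually wants.
joined_aaa = np.concatenate([seq_aaa, eos])
joined_abc = np.concatenate([seq_abc, eos])

print(added_aaa, added_abc)    # [3 3 3] [3 4 5]
print(joined_aaa, joined_abc)  # [2 2 2 1] [2 3 4 1]
```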
Answer 0 (score: 0)
In the `map` function you are adding 1 to every value instead of concatenating. You can change `_get_generator` to:
```python
def _get_generator(list_of_text, char2int):
    def gen():
        for text in list_of_text:
            # transform chars to ints, then append the EOS id (plain Python list concatenation)
            yield [char2int[x] for x in text] + [char2int['EOS']]
    return gen
```
and remove the `dataset.map` call.
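This works because the EOS id is appended to plain Python lists before TensorFlow sees the data, and for Python lists `+` means concatenation. A minimal sketch that checks the generator on its own, without TensorFlow (the `_get_generator` body is the one from this answer; `sequences` is an illustrative name):

```python
def _get_generator(list_of_text, char2int):
    def gen():
        for text in list_of_text:
            # Python list + list concatenates, so EOS lands at the end of each sequence
            yield [char2int[x] for x in text] + [char2int['EOS']]
    return gen

char2int = {'EOS': 1, 'a': 2, 'b': 3, 'c': 4}
list_of_text = ['aaa', 'abc']

# Call the generator function directly and collect everything it yields.
sequences = list(_get_generator(list_of_text, char2int)())
print(sequences)  # [[2, 2, 2, 1], [2, 3, 4, 1]]
```

The same `gen` can then be passed to `tf.data.Dataset.from_generator` as in the question, with the `dataset.map` line removed.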
Answer 1 (score: 0)
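As Vijay pointed out in his answer, the `+` operator on a `tf.Tensor` of type `tf.int32` performs addition rather than concatenation. To concatenate an additional symbol onto the end of the sequence, use `tf.concat()` inside `Dataset.map()` instead of `+`.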
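A minimal sketch of the `tf.concat()` approach, assuming TensorFlow 2.4 or later running eagerly (the question uses the TF 1.x `Session` API; there the same `map` lambda can replace the original one unchanged). Here `output_signature` takes the place of the positional type/shape arguments used in the question, and the names `eos` and `result` are illustrative:

```python
import tensorflow as tf  # assumption: TF 2.4+ (output_signature), eager execution

char2int = {'EOS': 1, 'a': 2, 'b': 3, 'c': 4}
list_of_text = ['aaa', 'abc']

def gen():
    for text in list_of_text:
        yield [char2int[x] for x in text]  # chars -> ints, no EOS yet

dataset = tf.data.Dataset.from_generator(
    gen,
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32),
)

eos = char2int['EOS']
# tf.concat joins along axis 0; the list [eos] is converted to a 1-element int32 tensor.
dataset = dataset.map(lambda seq: tf.concat([seq, [eos]], axis=0))

result = [seq.tolist() for seq in dataset.as_numpy_iterator()]
print(result)  # [[2, 2, 2, 1], [2, 3, 4, 1]]
```

Compared with appending in the generator, this keeps the raw sequences unchanged in `gen` and adds EOS as a separate pipeline step, which is convenient if the same generator feeds other pipelines.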