DECODE_RAW TensorSliceDataset

时间:2019-06-22 15:57:05

标签: python tensorflow nlp tensorflow2.0

我正在复制TTS模型Deep Voice 3。 数据集为LJSpeech-1.1。我找到了一个github存储库(https://github.com/Kyubyong/deepvoice3),但是它是用我使用TF 2.0的早期tensorflow版本编写的。 在数据处理中,我需要将encode_raw函数应用于TensorSliceDataset的输出。 但是,我无法将Decode_raw函数应用于输出。 所以,我的问题是如何将解码_原始应用于TensorSliceDataset的输出?

我已将文本转换为尺寸为(13066,)的张量。 在原始回购中,他使用了tf.train.slice_input_producer。 对于TF 2.0,我正在使用tf.data.Dataset.from_tensor_slices将该张量转换为TensorSliceDataset。 之后,我无法将Tensor_raw应用于TensorSliceDataset。下面是代码

# old TF code
texts, mels, dones, mags = tf.train.slice_input_producer([texts, mels, dones, mags], shuffle = True)
# TF 2.0 code
texts = tf.convert_to_tensor(texts)
texts = tf.data.Dataset.from_tensor_slices(texts)
texts = tf.io.decode_raw(texts, tf.int32) # (None,)

1 个答案:

答案 0 :(得分:0)

您需要将解析函数应用于数据集对象。 代替这一行

texts = tf.io.decode_raw(texts, tf.int32) # (None,)`

使用

texts = texts.map(lambda x: tf.io.decode_raw(x, tf.int32))