使用Estimator API对大型字符串进行RNN训练

时间:2018-08-19 15:15:51

标签: python tensorflow rnn tensorflow-estimator

我们希望估算器像智能手机中的键盘一样预测单词或下一个作品。我们想在一些文本文件上对其进行训练。

所以我们继续研究tensorflow API,发现

estimator = RNNEstimator(
    head=tf.contrib.estimator.regression_head(),
    sequence_feature_columns=[token_emb],
    rnn_cell_fn=rnn_cell_fn)

这似乎是为RNN创建估算器的便捷方法。现在,我们在功能列方面面临问题。我们正在像这样设置它们

token_sequence = sequence_categorical_column_with_hash_bucket(
    key="text", hash_bucket_size=num_of_categories, dtype=tf.string)
token_emb = embedding_column(categorical_column=token_sequence, 
    dimension=8)

在我们的输入函数中定义了'text'

train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"text": features},
    y=labels,
    batch_size=batch_size,
    num_epochs=None,
    shuffle=True)

其中features只是从原始文本中抽取的一长串40个字符的序列。

问题

  1. 是否可以在字符串输入上使用要素列?该文档并没有真正释放太多。
  2. 如何处理标签?目前,我们收到一个错误,因为它们从未转换为整数
  3. 即使在引入任意整数作为标签时,调用estimator.train(input_fn=train_input_fn, steps=100)时也会出现错误

      

    '给定类型:{}'。format(type(features)))ValueError:功能应   是Tensor s的字典。给定类型:

所以我们绝对在这里做错了。任何帮助表示赞赏:)

1 个答案:

答案 0 :(得分:0)

有一个简短的示例,将字符串单词特征作为SparseTensors传递到StateSavingRnnEstimator unit tests中的一步一步标签from pydub import AudioSegment import matplotlib.pyplot as plt import numpy as np import wave import sys (即语言建模)中。看起来您正在尝试做的事情差不多了,但有一点要注意,该估算器已过时;可以从中汲取灵感并定义自己的document.addEventListener("DOMContentLoaded", (event) => { page.getPropsAsync(request).then((props:any) => { const layoutElement = <Layout pageProps={props}/> hydrate(layoutElement, document.getElementById("root-div")) }) })