通过feature_columns将自由文本功能添加到具有数据集API的Tensorflow预设估算器中

时间:2018-04-16 15:40:38

标签: tensorflow google-cloud-ml tensorflow-datasets tensorflow-estimator

我正在尝试构建一个提供reddit_score = f('subreddit','comment')

的模型

主要是作为一个例子,我可以为工作项目进行构建。

我的代码是here

我的问题是我看到了罐装估算器,例如DNNLinearCombinedRegressor必须包含属于FeatureColumn类的feature_columns。

我有我的词汇文件并且知道如果我只是限制评论的第一个字我可以做类似的事情

tf.feature_column.categorical_column_with_vocabulary_file(
        key='comment',
        vocabulary_file='{}/vocab.csv'.format(INPUT_DIR)
        )

但是,如果我传入评论中的前10个单词,那么我不确定如何从"this is a pre padded 10 word comment xyzpadxyz xyzpadxyz"这样的字符串转到feature_column,这样我就可以构建嵌入传递到广泛而深入的模型中的deep要素。

似乎它必须是非常明显或简单的东西但是我不能在生活中找到这个特定设置的任何现有示例(canned wide and deep,dataset api,以及各种功能,例如subreddit和raw文字功能,如评论)。

我甚至考虑自己进行词汇整数查询,这样我传入的comment功能就像[23,45,67,12,1,345,7,99,999,999]那样也许我可以通过带有形状的numeric_feature获取它,然后从那里做一些事情。但这感觉有点奇怪。

2 个答案:

答案 0 :(得分:2)

您可以使用tf.string_split(),然后执行tf.slice()来对其进行切片,首先注意tf.pad()字符串为零。查看标题预处理操作:  https://towardsdatascience.com/how-to-do-text-classification-using-tensorflow-word-embeddings-and-cnn-edae13b3e575

一旦有了单词,就可以创建十个要素列

答案 1 :(得分:0)

根据来自@Lak的帖子的方法添加答案,但是对数据集api进行了一些调整。

# Create an input function reading a file using the Dataset API
# Then provide the results to the Estimator API
def read_dataset(prefix, mode, batch_size):

    def _input_fn():

        def decode_csv(value_column):

            columns = tf.decode_csv(value_column, field_delim='|', record_defaults=DEFAULTS)
            features = dict(zip(CSV_COLUMNS, columns))

            features['comment_words'] = tf.string_split([features['comment']])
            features['comment_words'] = tf.sparse_tensor_to_dense(features['comment_words'], default_value=PADWORD)
            features['comment_padding'] = tf.constant([[0,0],[0,MAX_DOCUMENT_LENGTH]])
            features['comment_padded'] = tf.pad(features['comment_words'], features['comment_padding'])
            features['comment_sliced'] = tf.slice(features['comment_padded'], [0,0], [-1, MAX_DOCUMENT_LENGTH])
            features['comment_words'] = tf.pad(features['comment_sliced'], features['comment_padding'])
            features['comment_words'] = tf.slice(features['comment_words'],[0,0],[-1,MAX_DOCUMENT_LENGTH])

            features.pop('comment_padding')
            features.pop('comment_padded')
            features.pop('comment_sliced')

            label = features.pop(LABEL_COLUMN)

            return features, label

        # Use prefix to create file path
        file_path = '{}/{}*{}*'.format(INPUT_DIR, prefix, PATTERN)

        # Create list of files that match pattern
        file_list = tf.gfile.Glob(file_path)

        # Create dataset from file list
        dataset = (tf.data.TextLineDataset(file_list)  # Read text file
                    .map(decode_csv))  # Transform each elem by applying decode_csv fn

        tf.logging.info("...dataset.output_types={}".format(dataset.output_types))
        tf.logging.info("...dataset.output_shapes={}".format(dataset.output_shapes))

        if mode == tf.estimator.ModeKeys.TRAIN:

            num_epochs = None # indefinitely
            dataset = dataset.shuffle(buffer_size = 10 * batch_size)

        else:

            num_epochs = 1 # end-of-input after this

        dataset = dataset.repeat(num_epochs).batch(batch_size)

        return dataset.make_one_shot_iterator().get_next()

    return _input_fn

然后在下面的函数中,我们可以引用我们在decode_csv()

的一部分中创建的字段
# Define feature columns
def get_wide_deep():

    EMBEDDING_SIZE = 10

    # Define column types
    subreddit = tf.feature_column.categorical_column_with_vocabulary_list('subreddit', ['news', 'ireland', 'pics'])

    comment_embeds = tf.feature_column.embedding_column(
        categorical_column = tf.feature_column.categorical_column_with_vocabulary_file(
            key='comment_words',
            vocabulary_file='{}/vocab.csv-00000-of-00001'.format(INPUT_DIR),
            vocabulary_size=100
            ),
        dimension = EMBEDDING_SIZE
        )

    # Sparse columns are wide, have a linear relationship with the output
    wide = [ subreddit ]

    # Continuous columns are deep, have a complex relationship with the output
    deep = [ comment_embeds ]

    return wide, deep