How to return a dictionary with multiple features from a tf.data.Dataset created from a generator?

Date: 2018-06-26 14:30:59

Tags: python tensorflow tensorflow-datasets tensorflow-estimator

I have a sample dataset that looks like this:

feature_1    feature_2    label
4            5            1
4            3            1
4            6            2
...

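I created a tf.feature_column.embedding_column for each feature (feature_1 and feature_2), so my train_input_fn has to return a dictionary of features whose keys match the feature names. My input function looks like this: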

import tensorflow as tf

def train_input_fn(features, labels, output_types, output_shapes, batch_size, feature_names):
    """
    Provides the data pipeline for the training process.
    :param features: (numpy.array) A numpy array that holds the training features.
    :param labels: (numpy.array) A numpy array that holds the target variable.
    :param output_types: (tuple(tensorflow.DType)) A tuple containing the data type of each component yielded.
    :param output_shapes: (tuple(tensorflow.TensorShape)) A tuple containing the shape of each component yielded.
    :param batch_size: (int) The size of every batch.
    :param feature_names: (list(str)) The name of each feature column.
    :return: (dict, tensorflow.Tensor) A dictionary of key -> value for every feature and the target label.
    """
    def gen():
        for f, l in zip(features, labels):
            yield f, l

    ds = tf.data.Dataset.from_generator(gen, output_types, output_shapes)
    # Calling repeat() without an argument creates an infinite loop.
    # That is intended, since the number of training iterations can then be
    # controlled by the caller (e.g. via epochs or steps).
    ds = ds.repeat().batch(batch_size)
    feature, label = ds.make_one_shot_iterator().get_next()

    return {'feature': feature}, label

How can I return something like this instead:

{'feature_1': x_1, 'feature_2': x_2}

1 Answer:

Answer 0 (score: 0)

A small change to the return statement should do it:

import tensorflow as tf

def train_input_fn(features, labels, output_types, output_shapes, batch_size, feature_names):
    """
    Provides the data pipeline for the training process.
    :param features: (numpy.array) A numpy array that holds the training features.
    :param labels: (numpy.array) A numpy array that holds the target variable.
    :param output_types: (tuple(tensorflow.DType)) A tuple containing the data type of each component yielded.
    :param output_shapes: (tuple(tensorflow.TensorShape)) A tuple containing the shape of each component yielded.
    :param batch_size: (int) The size of every batch.
    :param feature_names: (list(str)) The name of each feature column.
    :return: (dict, tensorflow.Tensor) A dictionary of key -> value for every feature and the target label.
    """
    def gen():
        for f, l in zip(features, labels):
            yield f, l

    ds = tf.data.Dataset.from_generator(gen, output_types, output_shapes)
    # Calling repeat() without an argument creates an infinite loop.
    # That is intended, since the number of training iterations can then be
    # controlled by the caller (e.g. via epochs or steps).
    ds = ds.repeat().batch(batch_size)
    feature, label = ds.make_one_shot_iterator().get_next()

    # feature has shape (batch_size, 2): column 0 is feature_1, column 1 is feature_2.
    return {'feature_1': feature[:, 0], 'feature_2': feature[:, 1]}, label
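A more general variant is to make the generator itself yield a dictionary keyed by `feature_names`, so the keys no longer depend on hard-coded column positions. `from_generator` accepts nested structures for `output_types`/`output_shapes`; they just have to mirror the `(dict, label)` pair being yielded. The sketch below assumes TensorFlow 2.x with eager execution. The helper name `dict_input_dataset`, the `tf.int64` dtypes and the scalar per-feature shapes are illustrative assumptions, not part of the original answer. In a TF 1.x `train_input_fn` you would keep the `make_one_shot_iterator().get_next()` call on the resulting dataset instead of iterating over it in Python.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def dict_input_dataset(features, labels, feature_names, batch_size):
    """Hypothetical helper: builds batches of ({name: column}, label)."""

    def gen():
        # One (dict, label) pair per row; the dict keys are the feature
        # names that the feature columns look up.
        for row, label in zip(features, labels):
            yield {name: row[i] for i, name in enumerate(feature_names)}, label

    # output_types/output_shapes mirror the nested structure yielded by gen().
    # int64 scalars are an assumption matching the sample data.
    output_types = ({name: tf.int64 for name in feature_names}, tf.int64)
    output_shapes = ({name: tf.TensorShape([]) for name in feature_names},
                     tf.TensorShape([]))
    ds = tf.data.Dataset.from_generator(gen, output_types, output_shapes)
    return ds.batch(batch_size)


# The three sample rows shown in the question (columns: feature_1, feature_2).
features = np.array([[4, 5], [4, 3], [4, 6]], dtype=np.int64)
labels = np.array([1, 1, 2], dtype=np.int64)

ds = dict_input_dataset(features, labels, ["feature_1", "feature_2"], batch_size=2)
for batch_features, batch_labels in ds.take(1):
    print({k: v.numpy().tolist() for k, v in batch_features.items()},
          batch_labels.numpy().tolist())
```

With the three sample rows and `batch_size=2`, the first batch should contain `feature_1 = [4, 4]`, `feature_2 = [5, 3]` and labels `[1, 1]`. Newer TF 2.x releases prefer the `output_signature` argument of `from_generator`, but `output_types`/`output_shapes` still work there.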