如何获得tf.data.Dataset的特征和张量的字典?

时间:2019-11-27 15:16:38

标签: python tensorflow tensorflow-datasets

给出以下数据集:

import pandas as pd
import tensorflow as tf

df = pd.DataFrame({
    'feat_binomial': [5, 1, 7, 4, 6],
    'feat_normal': [5.001512, 5.346654, -0.480363,4.821558,-2.080958],
    'feat_ordinal': ['low', 'low', 'low','low','low'],
    'feat_string': ['a', 'b', 'b','b','b'],
})

dataset = tf.data.Dataset.from_tensor_slices(dict(df))

我想得到特征和张量的字典,将map函数应用于数据集时,会得到它。在这里可以看到它已打印:

def print_input_dict(input_dict):
    features_tensors = dict(input_dict)
    print(features_tensors)
    return features_tensors

dataset.map(print_input_dict)
{'feat_binomial': <tf.Tensor 'args_0:0' shape=() dtype=int32>,
 'feat_normal': <tf.Tensor 'args_1:0' shape=() dtype=float64>,
 'feat_ordinal': <tf.Tensor 'args_2:0' shape=() dtype=string>,
 'feat_string': <tf.Tensor 'args_3:0' shape=() dtype=string>}

为了获得这本词典,我设法做到了以下几点:

def _extract_labels(input_dict):
    global features_tensors
    features_tensors = dict(input_dict)
    return features_tensors

dataset.map(print_input_dict)

,以便将其存储在我以后可以访问的全局变量中,但这并不像是正确的方法。还有其他获取这本词典的方法吗?

谢谢

1 个答案:

答案 0 :(得分:1)

您可以尝试使用迭代器,

iterator = dataset.make_one_shot_iterator()
feature_dict = iterator.get_next() # Returns Dictionary of features