Tensorflow数据集:检索Just标签的Numpy数组

时间:2019-03-11 23:55:20

标签: python numpy tensorflow

我有一个tf.data.Datasetdataset,它具有以下功能描述:

    feature_description = {
        'text_features': tf.FixedLenFeature([100], tf.int64),
        'numeric_features': tf.FixedLenFeature([200], tf.float32),
        'label': tf.FixedLenFeature([1], tf.int64),
    }

我想从每个样本中检索仅包含标签的NumPy数组。通过执行以下操作,我可以获得完整的NumPy数组:

def load_dataset(dataset):
    """ Load an entire tf dataset into memory
    """
    max_elems = np.iinfo(np.int32).max

    # Make a single batch out of the entire dataset and get that element
    dataset = dataset.batch(max_elems)
    dataset_tensors = tf.contrib.data.get_single_element(dataset)

    # Create a session and evaluate `whole_dataset_tensors` to get arrays.
    with tf.Session() as sess:
        return sess.run(dataset_tensors)

但这会将完整的dataset作为NumPy数组加载到内存中(并在笔记本电脑上导致OutOfMemoryError)。我只想获取标签。

一个想法:也许我可以做类似的事情:

dataset = dataset.map(lambda x: x['label'] result = load_dataset(dataset)

有什么建议吗?

0 个答案:

没有答案