我有一个tf.data.Dataset
,dataset
,它具有以下功能描述:
feature_description = {
'text_features': tf.FixedLenFeature([100], tf.int64),
'numeric_features': tf.FixedLenFeature([200], tf.float32),
'label': tf.FixedLenFeature([1], tf.int64),
}
我想从每个样本中检索仅包含标签的NumPy数组。通过执行以下操作,我可以获得完整的NumPy数组:
def load_dataset(dataset):
""" Load an entire tf dataset into memory
"""
max_elems = np.iinfo(np.int32).max
# Make a single batch out of the entire dataset and get that element
dataset = dataset.batch(max_elems)
dataset_tensors = tf.contrib.data.get_single_element(dataset)
# Create a session and evaluate `whole_dataset_tensors` to get arrays.
with tf.Session() as sess:
return sess.run(dataset_tensors)
但这会将完整的dataset
作为NumPy数组加载到内存中(并在笔记本电脑上导致OutOfMemoryError
)。我只想获取标签。
一个想法:也许我可以做类似的事情:
dataset = dataset.map(lambda x: x['label']
result = load_dataset(dataset)
?
有什么建议吗?