我可以从tensorflow 2.0数据集中获取numpy数组吗?

时间:2019-11-06 14:13:21

标签: tensorflow

我可以从numpy获得一个tensorflow dataset数组吗?在下面的示例中,我可以迭代并从每个numpy array中获取一个tensor。但是我可以直接从dataset获得它吗?

>>> X = tf.reshape(tf.range(2*3), (2, 3))
<tf.Tensor: id=33, shape=(2, 3), dtype=int32, numpy=
 array([[0, 1, 2],
        [3, 4, 5]], dtype=int32)>

>>> dataset = tf.data.Dataset.from_tensor_slices(X)
<TensorSliceDataset shapes: (3,), types: tf.int32>

>>> t = next(iter(dataset))
<tf.Tensor: id=40, shape=(3,), dtype=int32, numpy=array([0, 1, 2], dtype=int32)> 

>>> t.numpy()
array([0, 1, 2], dtype=int32)

1 个答案:

答案 0 :(得分:0)

一种可能的解决方案(请参见here

def dataset_to_numpy_util(dataset, N):
    dataset = dataset.unbatch().batch(N)
    for images, labels in dataset:
        numpy_images = images.numpy()
        numpy_labels = labels.numpy()
        break;  
    return numpy_images, numpy_labels