我可以从numpy
获得一个tensorflow dataset
数组吗?在下面的示例中,我可以迭代并从每个numpy array
中获取一个tensor
。但是我可以直接从dataset
获得它吗?
>>> X = tf.reshape(tf.range(2*3), (2, 3))
<tf.Tensor: id=33, shape=(2, 3), dtype=int32, numpy=
array([[0, 1, 2],
[3, 4, 5]], dtype=int32)>
>>> dataset = tf.data.Dataset.from_tensor_slices(X)
<TensorSliceDataset shapes: (3,), types: tf.int32>
>>> t = next(iter(dataset))
<tf.Tensor: id=40, shape=(3,), dtype=int32, numpy=array([0, 1, 2], dtype=int32)>
>>> t.numpy()
array([0, 1, 2], dtype=int32)
答案 0 :(得分:0)
一种可能的解决方案(请参见here)
def dataset_to_numpy_util(dataset, N):
dataset = dataset.unbatch().batch(N)
for images, labels in dataset:
numpy_images = images.numpy()
numpy_labels = labels.numpy()
break;
return numpy_images, numpy_labels