Is there a way to get the shape of a tensorflow.data.Dataset, similar to pandas_df.shape? Thanks.
Answer 0 (score: 1)
I'm not aware of anything built in, but you can retrieve the shapes from the private Dataset._tensors attribute. Example:
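import tensorflow as tf

def dataset_shapes(dataset):
    # Relies on the private attribute Dataset._tensors, which only exists on
    # datasets built from in-memory tensors and may change between TF versions.
    try:
        # Tuple/list of tensors, e.g. from_tensor_slices((x, y)).
        return [x.get_shape().as_list() for x in dataset._tensors]
    except TypeError:
        # A single tensor that cannot be iterated over (graph mode).
        return dataset._tensors.get_shape().as_list()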
And the usage:
from sklearn.datasets import make_blobs
x_train, y_train = make_blobs(n_samples=10,
                              n_features=2,
                              centers=[[1, 1], [-1, -1]],
                              cluster_std=0.5)
dataset = tf.data.Dataset.from_tensor_slices(x_train)
print(dataset_shapes(dataset)) # [10, 2]
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
print(dataset_shapes(dataset)) # [[10, 2], [10]]
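Since _tensors is private, here is a sketch that uses only public tf.data APIs instead, assuming TensorFlow 2.3 or later (where Dataset.cardinality() and Dataset.element_spec are available). The helper name dataset_shape is hypothetical. cardinality() gives the number of elements, or a negative sentinel when the count is unknown (e.g. after filter()) or infinite (e.g. after repeat()); element_spec describes a single element, so its shapes omit the leading dimension that from_tensor_slices slices away.

```python
import numpy as np
import tensorflow as tf


def dataset_shape(dataset):
    """Hypothetical helper: a pandas-like shape for a tf.data.Dataset.

    Returns (num_elements, element_shapes), where element_shapes mirrors the
    structure of one dataset element (a list of dims per tensor, None for
    unknown dims). num_elements is None when TensorFlow cannot determine the
    count statically or the dataset is infinite.
    """
    # cardinality() returns a scalar int64 tensor; negative values are the
    # sentinels for infinite (-1) and unknown (-2) cardinality.
    num_elements = int(dataset.cardinality().numpy())
    if num_elements < 0:
        num_elements = None
    # element_spec is a (possibly nested) structure of tf.TensorSpec objects,
    # one per component of a single element, without the sliced leading dim.
    element_shapes = tf.nest.map_structure(
        lambda spec: spec.shape.as_list(), dataset.element_spec
    )
    return num_elements, element_shapes


# Same layout as the make_blobs example: 10 rows, 2 features, 1 label each.
x_train = np.random.rand(10, 2)
y_train = np.random.randint(0, 2, size=10)
print(dataset_shape(tf.data.Dataset.from_tensor_slices(x_train)))  # (10, [2])
print(dataset_shape(tf.data.Dataset.from_tensor_slices((x_train, y_train))))  # (10, ([2], []))
```

Unlike the _tensors approach, this also works after transformations: batching the first dataset with .batch(4) gives (3, [None, 2]), because the last batch may be smaller.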
Answer 1 (score: 0)
To add to Vlad's answer: in case anyone is trying this with a dataset downloaded via tfds, one possible approach is to use the dataset info:
info.features['image'].shape # shape of 1 feature in dataset
info.features['label'].num_classes # number of classes
info.splits['train'].num_examples # number of training examples
For example, tf_flowers:
import tensorflow as tf
import tensorflow_datasets as tfds
dataset, info = tfds.load("tf_flowers", with_info=True) # download data with info
image_size = info.features['image'].shape # (None, None, 3)
num_classes = info.features['label'].num_classes # 5
data_size = info.splits['train'].num_examples # 3670
For example, fashion_mnist:
import tensorflow as tf
import tensorflow_datasets as tfds
dataset, info = tfds.load("fashion_mnist", with_info=True) # download data with info
image_size = info.features['image'].shape # (28, 28, 1)
num_classes = info.features['label'].num_classes # 10
data_splits = {k:v.num_examples for k,v in info.splits.items()} # {'test': 10000, 'train': 60000}
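Combining the split size with a feature's shape gives something close to pandas_df.shape without iterating over the data. A minimal sketch: split_shape is a hypothetical helper name, and info is assumed to be the DatasetInfo returned by tfds.load(..., with_info=True) as above.

```python
def split_shape(info, split="train", feature="image"):
    """Hypothetical helper: pandas-like (num_examples, *feature_shape) tuple
    built from a tfds DatasetInfo object, without iterating the data."""
    num_examples = info.splits[split].num_examples
    # features[...].shape is a tuple such as (28, 28, 1); None marks
    # variable-size dimensions, e.g. (None, None, 3) for tf_flowers images.
    return (num_examples, *info.features[feature].shape)
```

With the fashion_mnist info above, split_shape(info) would give (60000, 28, 28, 1) and split_shape(info, "test") would give (10000, 28, 28, 1); for tf_flowers the image dimensions stay None because the images vary in size.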
Hope this helps.