你好,有人可以一步一步地解释下面的代码是什么? 特别是部分类和重塑? TNX
def load_data():
train_dataset = h5py.File('datasets/train_catvnoncat.h5', "r")
train_set_x_orig = np.array(train_dataset["train_set_x"][:]) # your train set features
train_set_y_orig = np.array(train_dataset["train_set_y"][:]) # your train set labels
test_dataset = h5py.File('datasets/test_catvnoncat.h5', "r")
test_set_x_orig = np.array(test_dataset["test_set_x"][:]) # your test set features
test_set_y_orig = np.array(test_dataset["test_set_y"][:]) # your test set labels
classes = np.array(test_dataset["list_classes"][:]) # the list of classes
train_set_y_orig = train_set_y_orig.reshape((1, train_set_y_orig.shape[0]))
test_set_y_orig = test_set_y_orig.reshape((1, test_set_y_orig.shape[0]))
return train_set_x_orig, train_set_y_orig, test_set_x_orig, test_set_y_orig, classes
答案 0 :(得分:1)
大多数行只是从datasets
文件加载h5
。不需要np.array(...)
包装器。 test_dataset[name][:]
足以加载数组。
test_set_y_orig = test_dataset["test_set_y"][:]
test_dataset
是已打开的文件。 test_dataset["test_set_y"]
是该文件的dataset
。 [:]
将数据集加载到numpy
数组中。查看h5py
文档,了解有关加载dataset
的详细信息。
我从
推断出来train_set_y_orig = train_set_y_orig.reshape((1, train_set_y_orig.shape[0]))
加载的数组为1d,形状为(n,)
,此重塑只是添加初始维度,使其成为(1,n)
。我会把它编码为
train_set_y_orig = train_set_y_orig[None,:]
但结果是一样的。
classes
数组没有什么特别之处(虽然它可能是一个字符串数组)。