Udacity深度学习:作业1,第5部分

时间:2017-12-19 01:20:06

标签: numpy python-3.6

我正在开展Udacity深度学习课程并且正在完成第一项任务,问题5是您尝试计算测试集和训练集中的重复数量。 (或验证和培训等)

我看过其他人的答案,但由于各种原因我对他们不满意。例如,我尝试了某人基于哈希的解决方案。但我觉得返回的结果不太可能是正确的。

所以主要的想法是你有一个格式化为数组的图像数组。即你试图比较索引0上的两个三维数组。一个数组是训练数据集,它是200000行,每行包含一个2-D数组,它是图像的值。另一个是测试集,有10000行,每行包含一个图像的二维数组。目标是在测试集中找到匹配(现在,完全匹配很好)训练集中的一行的所有行。因为每一行都是'本身就是一个图像(这是一个二维数组),然后为了使这项工作更快,我必须能够将两个集合进行比较,作为每行的元素比较。

我制定了我自己相当简单的解决方案:

# Find duplicates
# Loop through validation/test set and find ones that are identical matrices to something in the training data
def find_duplicates(compare_set, compare_labels, training_set, training_labels):
    dup_count = 0
    duplicates = []
    for i in range(len(compare_set)):
        if i > 100: continue
        if i % 100 == 0:
            print("i: ", i)
        for j in range(len(training_set)):
            if compare_labels[i] == training_labels[j]:
                if np.array_equal(compare_set[i], training_set[j]):
                    duplicates.append((i,j))
                    dup_count += 1
    return dup_count, duplicates

#print(len(valid_dataset))
print(len(train_dataset))
valid_dup_count, duplicates = find_duplicates(valid_dataset, valid_labels, train_dataset, train_labels)
print(valid_dup_count)
print(duplicates)
#test_dups = find_duplicates(test_dataset, train_dataset)
#print(test_dups)

它只是继续" 100之后是因为单独需要很长时间。如果我试图将所有10,000行验证集与训练集进行比较,则需要永久。

我原则上喜欢我的解决方案,因为它不仅可以计算重复项,还可以获取存在哪些匹配项的列表。 (在我看过的其他解决方案上都缺少了一些东西。)这使我可以手动测试我是否找到了合适的解决方案。

我真正需要的是更快(即内置到Numpy)解决方案来比较像这样的矩阵矩阵。我和'isin'一起玩过。和'其中'但是我们还没弄明白如何使用这些来获得我之后的结果。有人能指出我正确的方向以获得更快的解决方案吗?

1 个答案:

答案 0 :(得分:1)

您应该能够使用np.all()compare_settraining_set的所有图片中的单个图片与单行代码进行比较。您可以在axis参数中提供多个轴作为元组,以检查行和列上的数组相等性,并遍历每个图像。然后np.where()可以为您提供所需的索引。

例如:

n_train = 50
n_validation = 10
h, w = 28, 28

training_set = np.random.rand(n_train, h, w)
validation_set = np.random.rand(n_validation, h, w)

# create some duplicates
training_set[5] = training_set[10]
validation_set[2] = training_set[10]
validation_set[8] = training_set[10]

duplicates = []
for i, img in enumerate(validation_set):
    training_dups = np.where(np.all(training_set == img, axis=(1, 2)))[0]
    for j in training_dups:
        duplicates.append((i, j))

print(duplicates)
[(2, 5), (2, 10), (8, 5), (8, 10)]

包含许多numpy函数np.all(),允许您指定要操作的轴。例如,假设你有两个数组

>>> A = np.array([[1, 2], [3, 4]])
>>> B = np.array([[1, 2], [5, 6]])
>>> A
array([[1, 2],
       [3, 4]])
>>> B
array([[1, 2],
       [5, 6]])

现在,AB具有相同的第一行,但第二行不同。如果我们检查它们的相等性

>>> A == B
array([[ True,  True],
       [False, False]], dtype=bool)

我们得到一个与AB形状相同的数组。但是如果我想要行的索引相等呢?那么在这种情况下,我们可以做的是只返回True如果所有行中的值(即每列中的值)是True& #39 ;.所以我们可以在等式检查后使用np.all(),并为它提供与列对应的轴。

>>> np.all(A == B, axis=1)
array([ True, False], dtype=bool)

所以这个结果让我们知道第一行在两个数组中是相等的,第二行并不都是相等的。然后我们可以使用np.where()

获取行索引
>>> np.where(np.all(A == B, axis=1))
(array([0]),)

所以我们在这里看到第0行,即A[0]B[0]相等。

现在在我提出的解决方案中,你有一个3D阵列而不是这些2D阵列。我们不关心单个是否相等,我们关心 all 列是否相等。因此,如上所述,让我们创建两个随机的5x5图像。我将抓住其中一个图像并检查两个图像阵列之间是否相等:

>>> imgs = np.random.rand(2, 5, 5)
>>> img = imgs[1]
>>> imgs == img
array([[[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]],

       [[ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True]]], dtype=bool)

所以很明显,第二个是正确的,但我想将所有那些True值减少到一个True值;我只想要与每个值相等的图像对应的索引。

如果我们使用axis=1

>>> np.all(imgs == img, axis=1)
array([[False, False, False, False, False],
       [ True,  True,  True,  True,  True]], dtype=bool)

如果每行中的所有列都相同,那么我们会为每个获取True。实际上,我们希望通过检查所有行的相等性来进一步减少这种情况。因此,我们可以将此结果输入np.all()并检查结果数组的行:

>>> np.all(np.all(imgs == img, axis=1), axis=1)
array([False,  True], dtype=bool)

这给了我们一个布尔值,其中imgs内的图像等于img,我们可以简单地用np.where()得到结果。但你实际上并不需要像这样两次致电np.all();相反,你可以在一个元组中提供多个轴,只需一步减少行和列:

>>> np.all(imgs == img, axis=(1, 2))
array([False,  True], dtype=bool)

这就是上面解决方案的作用。希望能搞清楚!