这段代码中numpy sum方法是否多余?

时间:2019-03-23 08:37:04

标签: numpy

我正在读一本书,发现以下错误:

def relu(x):
    return (x>0)*x

def relu2dev(x):
    return (x>0)

street_lights = np.array([[1,0,1],[0,1,1],[0,0,1],[1,1,1]])

walk_stop = np.array([[1,1,0,0]]).T

alpha = 0.2
hidden_size = 4

weights_0_1 = 2*np.random.random((3,hidden_size))-1
weights_1_2 = 2*np.random.random((hidden_size,1))-1

for it in range(60):
    layer_2_error = 0;

    for i in range(len(street_lights)):
        layer_0 = street_lights[i:i+1]
        layer_1 = relu(np.dot(layer_0,weights_0_1))
        layer_2 = np.dot(layer_1,weights_1_2)

        layer_2_delta = (layer_2-walk_stop[i:i+1])

        # -> layer_2_delta's shape is (1,1), so why np.sum?
        layer_2_error += np.sum((layer_2_delta)**2)

        layer_1_delta = layer_2_delta.dot(weights_1_2.T) * relu2dev(layer_1)

        weights_1_2 -= alpha * layer_1.T.dot(layer_2_delta)
        weights_0_1 -= alpha * layer_0.T.dot(layer_1_delta)

    if(it % 10 == 9):
        print("Error: " + str(layer_2_error))

错误位置用# ->注释:

layer_2_delta的形状为(1,1),那么为什么要使用np.sum?我认为np.sum可以删除,但不太确定,因为它来自一本书。

1 个答案:

答案 0 :(得分:1)

如您所说,layer_2_delta的形状为(1,1)。这意味着它是一个二维数组,其中包含一个元素:layer_2_delta = np.array([[X]])。但是,layer_2_error是标量。因此,您可以通过选择第一个索引(layer_2_delta[0,0])的值或将所有元素相加(在这种情况下只是一个元素)来从数组中获取标量。由于该书似乎使用“平方误差之和”,因此保留数组中每个元素均为平方的表示法,然后将所有这些都加起来(出于指示目的)似乎是很自然的:这将更为通用(例如,针对案例(其中该图层包含多个元素)比使用索引方法。但是你是对的,可能还有其他方法可以做到这一点:)。