python3 ValueError:形状(4,1)和(4,3)未对齐:1(暗淡1)!= 4(暗淡0)

时间:2017-09-20 18:14:03

标签: python python-3.x numpy machine-learning

我正在编写一个Planar数据分类程序,其中有一个来自Coursera课程的隐藏层。

这段代码应该做什么以及为什么它不起作用?

def backward_propagation(parameters, cache, X, Y):
    """
    Implement the backward propagation using the instructions above.

    """
    m = X.shape[1]

    # First, retrieve W1 and W2 from the dictionary "parameters".
    ### START CODE HERE ### (≈ 2 lines of code)
    W1 = parameters["W1"]
    W2 = parameters["W2"]
    ### END CODE HERE ###

    # Retrieve also A1 and A2 from dictionary "cache".
    ### START CODE HERE ### (≈ 2 lines of code)
    A1 = cache["A1"]
    A2 = cache["A1"]
    ### END CODE HERE ###

    # Backward propagation: calculate dW1, db1, dW2, db2. 
    ### START CODE HERE ### (≈ 6 lines of code, corresponding to 6 equations on slide above)
    dZ2= A2-Y
    dW2 = (1/m)*np.dot(dZ2,A1.T)
    db2 = (1/m)*np.sum(dZ2, axis=1, keepdims=True)
    dZ1 = np.multiply(np.dot(W2.T, dZ2),1 - np.power(A1, 2))
    dW1 = (1 / m) * np.dot(dZ1, X.T)
    db1 = (1/m)*np.sum(dZ1,axis1,keepdims=True)
    ### END CODE HERE ###

    grads = {"dW1": dW1,
             "db1": db1,
             "dW2": dW2,
             "db2": db2}

    return grads
    parameters, cache, X_assess, Y_assess =      backward_propagation_test_case()

    grads = backward_propagation(parameters, cache, X_assess, Y_assess)
    print ("dW1 = "+ str(grads["dW1"]))
    print ("db1 = "+ str(grads["db1"]))
    print ("dW2 = "+ str(grads["dW2"]))
    print ("db2 = "+ str(grads["db2"]))

当我运行此代码时,我收到此错误:

ValueError: shapes (4,1) and (4,3) not aligned: 1 (dim 1) != 4 (dim 0)

4 个答案:

答案 0 :(得分:4)

当乘以两个矩阵时,即np.dot。第一个矩阵的列和第二个矩阵的行应该相等。这就是numpy抛出错误的原因。您不能将4x1矩阵与4x3矩阵相乘。

答案 1 :(得分:0)

您已经初始化了A2=cache["A1"],但是它应该是A2=cache["A2"]

答案 2 :(得分:0)

尝试重新整理您的指标 您的代码:

dW1 = (1 / m) * np.dot(dZ1, X.T)

尝试此代码:

enter code here
dW1 = (1 / m) * np.dot( X.T,dZ1)

答案 3 :(得分:0)

dZ1 = np.multiply(np.dot(W2.T, dZ2),1替换为dZ1 = np.multiply(np.dot(W2.T, dZ2),1),即在末尾添加np.multiply的右括号。并且也将A2=cache["A1"]替换为A2=cache["A2"]