首先,这是我的代码:
"""Softmax."""
scores = [3.0, 1.0, 0.2]
import numpy as np
def softmax(x):
"""Compute softmax values for each sets of scores in x."""
num = np.exp(x)
score_len = len(x)
y = np.array([0]*score_len)
sum_n = np.sum(num)
#print sum_n
for index in range(1,score_len):
y[index] = (num[index])/sum_n
return y
print(softmax(scores))
错误出现在以下行:
y[index] = (num[index])/sum_n
我用以下代码运行代码:
# Plot softmax curves
import matplotlib.pyplot as plt
x = np.arange(-2.0, 6.0, 0.1)
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])
plt.plot(x, softmax(scores).T, linewidth=2)
plt.show()
这到底出了什么问题?
答案 0 :(得分:2)
只需将print
语句编辑为“调试器”即可显示正在发生的事情:
import numpy as np
def softmax(x):
"""Compute softmax values for each sets of scores in x."""
num = np.exp(x)
score_len = len(x)
y = np.array([0]*score_len)
sum_n = np.sum(num)
#print sum_n
for index in range(1,score_len):
print((num[index])/sum_n)
y[index] = (num[index])/sum_n
return y
x = np.arange(-2.0, 6.0, 0.1)
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])
softmax(scores).T
打印
[ 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504]
所以你试图将这个数组分配给另一个数组的一个元素。这是不允许的!
有几种方法可以实现它。只是改变
y = np.array([0]*score_len)
到多维数组可以工作:
y = np.zeros(score.shape)
应该这样做,但我不确定这是不是你想要的。
修改强>
似乎你不想要多维输入,所以你只需要改变:
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])
到
scores = np.hstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])
通过打印验证这些数组的形状scores.shape
确实可以帮助您自己找到这些错误。第一个沿着第一个轴(vstack)堆叠,而hstack由第零个轴(这是你想要的)堆叠。
答案 1 :(得分:1)
这是一种初始化数组的坏方法:
y = np.array([0]*score_len)
更好地做一些像
这样的事情y = np.zeros((n,m))
其中n
和m
是最终产品的2个维度。我从你的另一个问题中假设你希望y
为2d(毕竟你之后做了.T
)。
注意传递给函数的scores
形状。迭代时,请包含:
。它可以是可选的,但你需要它来保持尺寸直接在你自己的想法:
y[index,:] = (num[index,:])/sum_n
总之 - 专注于理解如何使用多维数组 - 如何创建它们,如何索引它们,如何在不迭代的情况下使用它们,以及如何在需要时正确迭代。
答案 2 :(得分:0)
这应该可以完美而快速地运作
scores = [3.0, 1.0, 0.2]
import numpy as np
def softmax(x):
num = np.exp(x)
score_len = len(x)
y = np.zeros(score_len, object) # or => np.asarray([None]*score_len)
sum_n = np.sum(num)
for i in range(score_len):
y[i] = num[i] / sum_n
return y
print(softmax(scores))
x = np.arange(-2.0, 6.0, 0.1)
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])
printout = softmax(scores).T
print(printout)
<强>输出:强>
[0.8360188027814407 0.11314284146556011 0.050838355752999158]
[ array([ 3.26123038e-05, 3.60421698e-05, 3.98327578e-05,
4.40220056e-05, 4.86518403e-05, 5.37685990e-05,
5.94234919e-05, 6.56731151e-05, 7.25800169e-05,
8.02133239e-05, 8.86494329e-05, 9.79727751e-05,
1.08276662e-04, 1.19664218e-04, 1.32249413e-04,
1.46158206e-04, 1.61529798e-04, 1.78518035e-04,
1.97292941e-04, 2.18042421e-04, 2.40974142e-04,
2.66317614e-04, 2.94326482e-04, 3.25281069e-04,
3.59491177e-04, 3.97299194e-04, 4.39083515e-04,
4.85262332e-04, 5.36297817e-04, 5.92700751e-04,
6.55035633e-04, 7.23926331e-04, 8.00062328e-04,
8.84205618e-04, 9.77198335e-04, 1.07997118e-03,
1.19355274e-03, 1.31907978e-03, 1.45780861e-03,
1.61112768e-03, 1.78057146e-03, 1.96783579e-03,
2.17479489e-03, 2.40352006e-03, 2.65630048e-03,
2.93566604e-03, 3.24441273e-03, 3.58563059e-03,
3.96273465e-03, 4.37949910e-03, 4.84009504e-03,
5.34913227e-03, 5.91170543e-03, 6.53344491e-03,
7.22057331e-03, 7.97996764e-03, 8.81922816e-03,
9.74675448e-03, 1.07718296e-02, 1.19047128e-02,
1.31567424e-02, 1.45404491e-02, 1.60696814e-02,
1.77597446e-02, 1.96275532e-02, 2.16918010e-02,
2.39731477e-02, 2.64944256e-02, 2.92808687e-02,
3.23603645e-02, 3.57637337e-02, 3.95250385e-02,
4.36819230e-02, 4.82759910e-02, 5.33532213e-02,
5.89644285e-02, 6.51657716e-02, 7.20193157e-02,
7.95936532e-02, 8.79645908e-02])
array([ 0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504])
array([ 0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433])]
答案 3 :(得分:0)
数组构造中的不一致可能导致这种问题 例如
[[1,2,3,4], [2,3], [1],[1,2,3,4]]
这是不好的例子。