我不明白为什么设置矩阵
a3
具有12个元素,并且12 * 0.8 = 9.6。 9.6个元素如何保留?我的错误在哪里?
我的代码:
import numpy as np
keep_prod = 0.8
a3 = np.random.rand(3,4)
print("a3-before",a3)
d3 = np.random.rand(a3.shape[0],a3.shape[1])<keep_prod ##### attention!!!
print("d3",d3)
输出:
a3-before
[[ 0.6016695 0.733025 0.38694513 0.17916196]
[ 0.39412193 0.22803599 0.16931667 0.30190426]
[ 0.8822327 0.64064634 0.40085393 0.72317028]]
d3
[[False True True False]
[ True False True True]
[ True True True True]]
答案 0 :(得分:0)
您似乎希望< 0.8
将80%的元素保留在原始数组中。
但是,真正发生的是它执行了 elementwise 比较,并返回了一个相同形状的数组。换句话说,d3
包含True
,其中a3
的相应元素大于或等于0.8
,而其他地方的False
。