我有一个矩阵。我想在矩阵的每一行中找到最大值的索引。 问题是这个矩阵是另一种算法的结果,所以我不能用笨拙的方法来做。
[[ 6.02 6.02 6.02 6.02
6.02 6.02 6.02 9.94
6.02 6.02]
[ 4.63 4.63 4.63 4.63
4.63 4.63 4.63 4.63
9.95 4.63]
[ 4.54 4.54 4.54 4.54
4.54 9.95 4.54 4.54
4.54 4.54]]
所以这个woill的输出是: [7,8,5]
我想循环遍历矩阵中的每一行,然后我有一行,但问题是它是这样的:
a =[ 6.02 6.02 6.02 6.02
6.02 6.02 6.02 9.94
6.02 6.02]
如您所见,没有格式。我的意思是它们之间没有逗号,所以我不能再像set或list那样表现。
当我写a.
时,没有方法,
更新
这实际上是此lda_x=lda.fit_transform(corpus)
赞赏任何想法
答案 0 :(得分:0)
一般性评论:
很明显,你缺少很多编程基础知识,应该阅读更多的教程。在不知情的情况下使用sklearn,它完全基于numpy和scipy数组,显示出一些缺乏知识。
我不认为你应该在没有理解这些概念的情况下做那样复杂的ML之类的东西。
import numpy as np
from sklearn.decomposition import LatentDirichletAllocation
X = np.random.uniform(size=(10,20))
lda = LatentDirichletAllocation(max_iter=5,
learning_method='online',
learning_offset=50.,
random_state=0)
X_new = lda.fit_transform(X)
print('transformed data')
print(X_new)
print('row-wise maximums')
print(np.amax(X_new, axis=1))
transformed data
[[ 0.00827835 0.00827806 0.00827819 0.92549556 0.00827816 0.00827812
0.00827862 0.00827827 0.00827855 0.00827812]
[ 0.00864627 0.00864591 0.00864594 0.92218549 0.00864589 0.00864588
0.00864651 0.0086461 0.00864617 0.00864584]
[ 0.00761386 0.00761349 0.00761361 0.93147708 0.0076135 0.00761359
0.00761399 0.00761363 0.00761376 0.0076135 ]
[ 0.00987778 0.0098773 0.00987748 0.91110236 0.00987731 0.00987739
0.00987797 0.00987747 0.00987762 0.00987731]
[ 0.00902489 0.00902451 0.00902467 0.9187762 0.00902463 0.00902457
0.00902629 0.00902484 0.00902484 0.00902455]
[ 0.0106997 0.01069926 0.01069952 0.90370346 0.01069944 0.0106994
0.01070011 0.01069987 0.01069985 0.01069938]
[ 0.00869799 0.00869761 0.00869767 0.92172003 0.00869759 0.0086976
0.00869823 0.00869782 0.00869789 0.00869756]
[ 0.00830867 0.0083084 0.00830849 0.92522278 0.00830843 0.00830836
0.00830908 0.00830855 0.00830887 0.00830837]
[ 0.00873503 0.00873454 0.00873457 0.92138779 0.00873453 0.00873451
0.00873501 0.00873467 0.00873481 0.00873453]
[ 0.01097472 0.01097405 0.01097429 0.90123112 0.01097416 0.01097427
0.01097458 0.01097437 0.01097436 0.01097409]]
row-wise maximums
[ 0.92549556 0.92218549 0.93147708 0.91110236 0.9187762 0.90370346
0.92172003 0.92522278 0.92138779 0.90123112]
如果您想要索引而不是值:
print(np.argmax(X_new, axis=1))
输出:
row-wise maximums
[3 3 3 3 3 3 3 3 3 3]