在python中查找矩阵中的最大索引

时间:2017-08-13 00:58:51

标签: python list matrix set

我有一个矩阵。我想在矩阵的每一行中找到最大值的索引。 问题是这个矩阵是另一种算法的结果,所以我不能用笨拙的方法来做。

[[  6.02   6.02   6.02   6.02
    6.02  6.02   6.02  9.94
    6.02  6.02]
 [  4.63  4.63  4.63  4.63
    4.63  4.63   4.63  4.63
    9.95  4.63]
 [  4.54   4.54   4.54  4.54
    4.54   9.95   4.54   4.54
    4.54   4.54]]

所以这个woill的输出是: [7,8,5]

我想循环遍历矩阵中的每一行,然后我有一行,但问题是它是这样的:

   a =[  6.02   6.02   6.02   6.02
        6.02  6.02   6.02  9.94
        6.02  6.02]

如您所见,没有格式。我的意思是它们之间没有逗号,所以我不能再像set或list那样表现。

当我写a.时,没有方法,

更新

这实际上是此lda_x=lda.fit_transform(corpus)

的输出

赞赏任何想法

1 个答案:

答案 0 :(得分:0)

一般性评论:

很明显,你缺少很多编程基础知识,应该阅读更多的教程。在不知情的情况下使用sklearn,它完全基于numpy和scipy数组,显示出一些缺乏知识。

我不认为你应该在没有理解这些概念的情况下做那样复杂的ML之类的东西。

代码(只是为了显示步骤;代码没有多大意义!):

import numpy as np
from sklearn.decomposition import LatentDirichletAllocation


X = np.random.uniform(size=(10,20))

lda = LatentDirichletAllocation(max_iter=5,
                                learning_method='online',
                                learning_offset=50.,
                                random_state=0)
X_new = lda.fit_transform(X)

print('transformed data')
print(X_new)

print('row-wise maximums')
print(np.amax(X_new, axis=1))

输出:

transformed data
[[ 0.00827835  0.00827806  0.00827819  0.92549556  0.00827816  0.00827812
   0.00827862  0.00827827  0.00827855  0.00827812]
 [ 0.00864627  0.00864591  0.00864594  0.92218549  0.00864589  0.00864588
   0.00864651  0.0086461   0.00864617  0.00864584]
 [ 0.00761386  0.00761349  0.00761361  0.93147708  0.0076135   0.00761359
   0.00761399  0.00761363  0.00761376  0.0076135 ]
 [ 0.00987778  0.0098773   0.00987748  0.91110236  0.00987731  0.00987739
   0.00987797  0.00987747  0.00987762  0.00987731]
 [ 0.00902489  0.00902451  0.00902467  0.9187762   0.00902463  0.00902457
   0.00902629  0.00902484  0.00902484  0.00902455]
 [ 0.0106997   0.01069926  0.01069952  0.90370346  0.01069944  0.0106994
   0.01070011  0.01069987  0.01069985  0.01069938]
 [ 0.00869799  0.00869761  0.00869767  0.92172003  0.00869759  0.0086976
   0.00869823  0.00869782  0.00869789  0.00869756]
 [ 0.00830867  0.0083084   0.00830849  0.92522278  0.00830843  0.00830836
   0.00830908  0.00830855  0.00830887  0.00830837]
 [ 0.00873503  0.00873454  0.00873457  0.92138779  0.00873453  0.00873451
   0.00873501  0.00873467  0.00873481  0.00873453]
 [ 0.01097472  0.01097405  0.01097429  0.90123112  0.01097416  0.01097427
   0.01097458  0.01097437  0.01097436  0.01097409]]
row-wise maximums
[ 0.92549556  0.92218549  0.93147708  0.91110236  0.9187762   0.90370346
  0.92172003  0.92522278  0.92138779  0.90123112]

如果您想要索引而不是值:

print(np.argmax(X_new, axis=1))

输出:

row-wise maximums
[3 3 3 3 3 3 3 3 3 3]