如何将2d numpy数组转换为二进制指示符矩阵以获得最大值

时间:2016-03-22 11:50:42

标签: python python-2.7 numpy machine-learning

假设我有一个2d numpy数组,表示n个类中m个样本的概率(每个样本的概率总和为1)。

假设每个样本只能在一个类别中,我想创建一个与原始样本具有相同形状的新数组,但只有二进制值表示哪个类的概率最高。

示例:

[[0.2, 0.3, 0.5], [0.7, 0.1, 0.1]]

应转换为:

[[0, 0, 1], [1, 0, 0]]

似乎amax几乎已经做了我想要的,但是我想要一个指标矩阵而不是索引,如上所述。

看起来很简单,但不知怎的,我无法使用标准的numpy函数来解决这个问题。我当然可以使用常规的python循环,但似乎应该有一种更简单的方法。

如果多个类具有相同的概率,我宁愿选择只选择其中一个类的解决方案(在这种情况下我不关心它)。

谢谢!

2 个答案:

答案 0 :(得分:10)

这是一种方式:

In [112]: a
Out[112]: 
array([[ 0.2,  0.3,  0.5],
       [ 0.7,  0.1,  0.1]])

In [113]: a == a.max(axis=1, keepdims=True)
Out[113]: 
array([[False, False,  True],
       [ True, False, False]], dtype=bool)

In [114]: (a == a.max(axis=1, keepdims=True)).astype(int)
Out[114]: 
array([[0, 0, 1],
       [1, 0, 0]])

(但这会给每行出现最大值的一个真值。请参阅Divakar的答案,以便选择最初出现的最大值。)

答案 1 :(得分:5)

如果是关系(两个或多个元素是连续的最高元素),您只想选择一个元素,这是使用np.argmax和{{3}执行此操作的方法之一} -

<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE log4j:configuration PUBLIC "-//APACHE//DTD LOG4J 1.2//EN" "log4j.dtd">
<log4j:configuration xmlns:log4j='http://jakarta.apache.org/log4j/'>

  <appender name="STDOUT" class="org.apache.log4j.ConsoleAppender">
    <layout class="org.apache.log4j.PatternLayout">
      <param name="ConversionPattern" value="%d [%p] %c{1} - %m%n"/>
    </layout>
  </appender>

  <logger name="org.mybatis.spring" additivity="false">
    <level value="debug"/>
    <appender-ref ref="STDOUT"/>
  </logger>

  <logger name="com.sample.mappers">
    <level value="debug"/>
    <appender-ref ref="STDOUT"/>
  </logger>

  <!-- Other custom 3rd party logger configs -->

  <root>
    <priority value ="debug" />
    <appender-ref ref="STDOUT" />
  </root>

</log4j:configuration>

示例运行 -

(A.argmax(1)[:,None] == np.arange(A.shape[1])).astype(int)