Numpy 3D阵列最大值

时间:2018-07-13 20:51:12

标签: python numpy

import numpy as np
a = np.array([[[ 0.25,  0.10 ,  0.50 ,  0.15],
           [ 0.50,  0.60 ,  0.70 ,  0.30]],
          [[ 0.25,  0.50 ,  0.20 ,  0.70],
           [ 0.80,  0.10 ,  0.50 ,  0.15]]])

我需要在a [i]中找到最大值的行和列。 如果i = 0,则a [0,1,2]为最大值,因此我需要编写一种方法,将[1,2]作为a [0]中max的输出。有指针吗? 注意:np.argmax展平a [i] 2D数组,并且当使用axis = 0时,它给出a [0]每行中max的索引

3 个答案:

答案 0 :(得分:3)

您也可以将argmaxunravel_index一起使用:

def max_by_index(idx, arr):
    return (idx,) + np.unravel_index(np.argmax(arr[idx]), arr.shape[1:])

例如

import numpy as np
a = np.array([[[ 0.25,  0.10 ,  0.50 ,  0.15],
               [ 0.50,  0.60 ,  0.70 ,  0.30]],
              [[ 0.25,  0.50 ,  0.20 ,  0.70],
               [ 0.80,  0.10 ,  0.50 ,  0.15]]])

def max_by_index(idx, arr):
    return (idx,) + np.unravel_index(np.argmax(arr[idx]), arr.shape[1:])


print(max_by_index(0, a))

给予

(0, 1, 2)

答案 1 :(得分:2)

您可以使用numpy.where,可以将其包装成一个简单的函数来满足您的要求:

# ----
# Site

title: Base
url: "https://kb.example.com/"
baseurl:
google_analytics_key:
disqus_shortname:
newsletter_action:

# Values for the jekyll-seo-tag gem (https://github.com/jekyll/jekyll-seo-tag)
logo: /siteicon.png
description: Knowledge base template for Jekyll.
author:
  name:
  email:
  twitter: # twitter username without the @ symbol
social:
  name: Base Template
  links:
    - https://github.com/CloudCannon/base-jekyll-template

# -----
# Build

timezone: Etc/UTC

permalink: /:categories/:title/

plugins:
  - jekyll-extract-element
  - jekyll-sitemap
  - jekyll-seo-tag
  - jekyll-feed
  - jekyll-archives

exclude:
  - Gemfile
  - Gemfile.lock
  - README.md
  - LICENCE

collections:
  sets:


jekyll-archives:
  enabled: ['categories']

defaults:
  -
    scope:
      path: ""
    values:
      layout: "default"
  -
    scope:
      type: "posts"
    values:
      layout: "post"
      comments: true
  -
    scope:
      type: "sets"
    values:
      _hide_content: true
  -
    scope:
      path: "index.html"
    values:
      body_class: "show_hero_search"

# -----------
# CloudCannon

social_icons:
  - Facebook
  - Google Plus
  - Instagram
  - LinkedIn
  - Pinterest
  - Tumblr
  - Twitter
  - YouTube
  - Email
  - RSS

types:
  - Document
  - Video

_comments:

实际情况:

def max_by_index(idx, arr):
    return np.where(arr[idx] == np.max(arr[idx]))

您可以使用此结果对数组进行索引以访问最大值:

>>> max_by_index(0, a)
(array([1], dtype=int64), array([2], dtype=int64))

这将返回最大值的所有个位置,如果只希望一次出现,则可以将 >>> a[0][max_by_index(0, a)] array([0.7]) 替换为 {{1 }}

答案 2 :(得分:0)

col = (np.argmax(a[i])) % (a[i].shape[1])
row = (np.argmax(a[i])) // (a[i].shape[1])

这也有帮助