import numpy as np
a = np.array([[[ 0.25, 0.10 , 0.50 , 0.15],
[ 0.50, 0.60 , 0.70 , 0.30]],
[[ 0.25, 0.50 , 0.20 , 0.70],
[ 0.80, 0.10 , 0.50 , 0.15]]])
我需要在a [i]中找到最大值的行和列。 如果i = 0,则a [0,1,2]为最大值,因此我需要编写一种方法,将[1,2]作为a [0]中max的输出。有指针吗? 注意:np.argmax展平a [i] 2D数组,并且当使用axis = 0时,它给出a [0]每行中max的索引
答案 0 :(得分:3)
您也可以将argmax
与unravel_index
一起使用:
def max_by_index(idx, arr):
return (idx,) + np.unravel_index(np.argmax(arr[idx]), arr.shape[1:])
例如
import numpy as np
a = np.array([[[ 0.25, 0.10 , 0.50 , 0.15],
[ 0.50, 0.60 , 0.70 , 0.30]],
[[ 0.25, 0.50 , 0.20 , 0.70],
[ 0.80, 0.10 , 0.50 , 0.15]]])
def max_by_index(idx, arr):
return (idx,) + np.unravel_index(np.argmax(arr[idx]), arr.shape[1:])
print(max_by_index(0, a))
给予
(0, 1, 2)
答案 1 :(得分:2)
您可以使用numpy.where
,可以将其包装成一个简单的函数来满足您的要求:
# ----
# Site
title: Base
url: "https://kb.example.com/"
baseurl:
google_analytics_key:
disqus_shortname:
newsletter_action:
# Values for the jekyll-seo-tag gem (https://github.com/jekyll/jekyll-seo-tag)
logo: /siteicon.png
description: Knowledge base template for Jekyll.
author:
name:
email:
twitter: # twitter username without the @ symbol
social:
name: Base Template
links:
- https://github.com/CloudCannon/base-jekyll-template
# -----
# Build
timezone: Etc/UTC
permalink: /:categories/:title/
plugins:
- jekyll-extract-element
- jekyll-sitemap
- jekyll-seo-tag
- jekyll-feed
- jekyll-archives
exclude:
- Gemfile
- Gemfile.lock
- README.md
- LICENCE
collections:
sets:
jekyll-archives:
enabled: ['categories']
defaults:
-
scope:
path: ""
values:
layout: "default"
-
scope:
type: "posts"
values:
layout: "post"
comments: true
-
scope:
type: "sets"
values:
_hide_content: true
-
scope:
path: "index.html"
values:
body_class: "show_hero_search"
# -----------
# CloudCannon
social_icons:
- Facebook
- Google Plus
- Instagram
- LinkedIn
- Pinterest
- Tumblr
- Twitter
- YouTube
- Email
- RSS
types:
- Document
- Video
_comments:
实际情况:
def max_by_index(idx, arr):
return np.where(arr[idx] == np.max(arr[idx]))
您可以使用此结果对数组进行索引以访问最大值:
>>> max_by_index(0, a)
(array([1], dtype=int64), array([2], dtype=int64))
这将返回最大值的所有个位置,如果只希望一次出现,则可以将 >>> a[0][max_by_index(0, a)]
array([0.7])
替换为 {{1 }} 。
答案 2 :(得分:0)
col = (np.argmax(a[i])) % (a[i].shape[1])
row = (np.argmax(a[i])) // (a[i].shape[1])
这也有帮助