我正在使用numpy
切片作为输入,并希望知道numpy
将其转换为什么。
例如,如果用户输入字符串是user_input = '[1:10, 2:20:2]'
,并且我有一个数组arr
,我可以使用eval('arr' + user_input)
从数组中获取结果切片。
但我正在寻找的东西更接近于slice(1,10,1), slice(2,20,2)
或者在进行实际索引之前的任何numpy转换。
无论如何都可以访问它吗?
如何获得numpy切片解释的中间步骤,而不是实际的输出数组?或者我在哪里可以查看numpy如何解释其索引?
答案 0 :(得分:0)
要将用户输入字符串转换为切片,您可以解析字符串并将其传递给切片,如:
def slice_from_string(slice_string):
slices = slice_string.split(',')
if len(slices) > 1:
return [slice_from_string(s.strip()) for s in slices]
return slice(*[int(x) for x in slice_string.split(':')])
import numpy as np
nums = np.arange(100)
print(nums[slice_from_string('3:7')])
print(nums[slice_from_string('2:20:2')])
nums = np.mgrid[1:10, 1:10][0]
print(nums)
print(nums[slice_from_string('3:7, 1:3')])
[3 4 5 6]
[ 2 4 6 8 10 12 14 16 18]
[[1 1 1 1 1 1 1 1 1]
[2 2 2 2 2 2 2 2 2]
[3 3 3 3 3 3 3 3 3]
[4 4 4 4 4 4 4 4 4]
[5 5 5 5 5 5 5 5 5]
[6 6 6 6 6 6 6 6 6]
[7 7 7 7 7 7 7 7 7]
[8 8 8 8 8 8 8 8 8]
[9 9 9 9 9 9 9 9 9]]
[[4 4]
[5 5]
[6 6]
[7 7]]
答案 1 :(得分:0)
如果我定义一个虚拟类
class Foo():
def __getitem__(self, atuple):
#print(atuple)
return atuple
然后我可以使用eval
将表达式转换为带有切片的元组:
In [140]: x=eval('f[0:2, 1:10:2,[1,2,2],4]')
In [141]: x
Out[141]: (slice(0, 2, None), slice(1, 10, 2), [1, 2, 2], 4)
In [142]: astr = '[0:2, 1:10:2,[1,2,2],4]'
In [143]: eval('f{}'.format(astr))
Out[143]: (slice(0, 2, None), slice(1, 10, 2), [1, 2, 2], 4)
此元组可用于索引:
In [144]: arr = np.arange(3*10*3*5).reshape(3,10,3,5)
In [145]: arr[x]
Out[145]:
array([[[ 24, 29, 29],
[ 54, 59, 59],
[ 84, 89, 89],
[114, 119, 119],
[144, 149, 149]],
[[174, 179, 179],
[204, 209, 209],
[234, 239, 239],
[264, 269, 269],
[294, 299, 299]]])
ast
模块具有更安全的ast.literal_eval
,但它不处理索引。
正如我评论的那样,numpy/lib/index_tricks.py
有一些有趣的例子,使用类语法作为函数语法的替代。 np.r_
,np.ogrid
和np.s_
是最有用的示例:
In [150]: np.s_[0:2, 1:10:2,[1,2,2],4]
Out[150]: (slice(0, 2, None), slice(1, 10, 2), [1, 2, 2], 4)
实际上我不需要定义Foo
类;我可以使用np.s_
代替
In [151]: eval('np.s_{}'.format(astr))
Out[151]: (slice(0, 2, None), slice(1, 10, 2), [1, 2, 2], 4)
arr[x]
已翻译为arr.__getitem__(x)
。但是,正确解释编译数组__getitem__
内部的内容超出了我的专业知识。