将输入字符串转换为numpy切片

时间:2018-01-29 04:53:13

标签: python numpy

我正在使用numpy切片作为输入,并希望知道numpy将其转换为什么。

例如,如果用户输入字符串是user_input = '[1:10, 2:20:2]',并且我有一个数组arr,我可以使用eval('arr' + user_input)从数组中获取结果切片。

但我正在寻找的东西更接近于slice(1,10,1), slice(2,20,2)或者在进行实际索引之前的任何numpy转换。

无论如何都可以访问它吗?

如何获得numpy切片解释的中间步骤,而不是实际的输出数组?或者我在哪里可以查看numpy如何解释其索引?

2 个答案:

答案 0 :(得分:0)

要将用户输入字符串转换为切片,您可以解析字符串并将其传递给切片,如:

代码:

def slice_from_string(slice_string):
    slices = slice_string.split(',')
    if len(slices) > 1:
        return [slice_from_string(s.strip()) for s in slices]
    return slice(*[int(x) for x in slice_string.split(':')])

测试代码:

import numpy as np

nums = np.arange(100)
print(nums[slice_from_string('3:7')])
print(nums[slice_from_string('2:20:2')])

nums = np.mgrid[1:10, 1:10][0]
print(nums)
print(nums[slice_from_string('3:7, 1:3')])

测试结果:

[3 4 5 6]

[ 2  4  6  8 10 12 14 16 18]

[[1 1 1 1 1 1 1 1 1]
 [2 2 2 2 2 2 2 2 2]
 [3 3 3 3 3 3 3 3 3]
 [4 4 4 4 4 4 4 4 4]
 [5 5 5 5 5 5 5 5 5]
 [6 6 6 6 6 6 6 6 6]
 [7 7 7 7 7 7 7 7 7]
 [8 8 8 8 8 8 8 8 8]
 [9 9 9 9 9 9 9 9 9]]

[[4 4]
 [5 5]
 [6 6]
 [7 7]]

答案 1 :(得分:0)

如果我定义一个虚拟类

class Foo():
    def __getitem__(self, atuple):
        #print(atuple)
        return atuple

然后我可以使用eval将表达式转换为带有切片的元组:

In [140]: x=eval('f[0:2, 1:10:2,[1,2,2],4]')
In [141]: x
Out[141]: (slice(0, 2, None), slice(1, 10, 2), [1, 2, 2], 4)

In [142]: astr = '[0:2, 1:10:2,[1,2,2],4]'
In [143]: eval('f{}'.format(astr))
Out[143]: (slice(0, 2, None), slice(1, 10, 2), [1, 2, 2], 4)

此元组可用于索引:

In [144]: arr = np.arange(3*10*3*5).reshape(3,10,3,5)
In [145]: arr[x]
Out[145]: 
array([[[ 24,  29,  29],
        [ 54,  59,  59],
        [ 84,  89,  89],
        [114, 119, 119],
        [144, 149, 149]],

       [[174, 179, 179],
        [204, 209, 209],
        [234, 239, 239],
        [264, 269, 269],
        [294, 299, 299]]])

ast模块具有更安全的ast.literal_eval,但它不处理索引。

正如我评论的那样,numpy/lib/index_tricks.py有一些有趣的例子,使用类语法作为函数语法的替代。 np.r_np.ogridnp.s_是最有用的示例:

In [150]: np.s_[0:2, 1:10:2,[1,2,2],4]
Out[150]: (slice(0, 2, None), slice(1, 10, 2), [1, 2, 2], 4)

实际上我不需要定义Foo类;我可以使用np.s_代替

In [151]: eval('np.s_{}'.format(astr))
Out[151]: (slice(0, 2, None), slice(1, 10, 2), [1, 2, 2], 4)

arr[x]已翻译为arr.__getitem__(x)。但是,正确解释编译数组__getitem__内部的内容超出了我的专业知识。