我试图在逐步遍历数组的每个 n 元素时获取 m 值。例如,对于 m = 2且 n = 5,并给出
a = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
我想要检索
b = [1, 2, 6, 7]
有没有办法使用切片来做到这一点?我可以使用嵌套列表理解来做到这一点,但我想知道是否有办法只使用索引来做到这一点。作为参考,列表理解方式是:
b = [k for j in [a[i:i+2] for i in range(0,len(a),5)] for k in j]
答案 0 :(得分:24)
我同意wim,你不能只用切片来做。但是你可以只用一个列表理解来做到这一点:
>>> [x for i,x in enumerate(a) if i%n < m]
[1, 2, 6, 7]
答案 1 :(得分:6)
不,切片不可能。切片仅支持开始,停止和步骤 - 无法用大于1的“组”表示步进。
答案 2 :(得分:5)
总之,不,你不能。但您可以使用itertools
来删除对中间列表的需求:
from itertools import chain, islice
res = list(chain.from_iterable(islice(a, i, i+2) for i in range(0, len(a), 5)))
print(res)
[1, 2, 6, 7]
借用@Kevin的逻辑,如果你想要一个矢量化解决方案以避免for
循环,你可以使用第三方库numpy
:
import numpy as np
m, n = 2, 5
a = np.array(a) # convert to numpy array
res = a[np.where(np.arange(a.shape[0]) % n < m)]
答案 3 :(得分:3)
还有其他方法可以做到这一点,在某些情况下都有优势,但没有一种只是“切片”。
最常见的解决方案可能是对输入进行分组,对组进行切片,然后将切片展平。这个解决方案的一个优点是你可以懒得地做,而不需要构建大的中间列表,你可以对任何迭代,包括一个惰性迭代器,而不仅仅是一个列表。
# from itertools recipes in the docs
def grouper(iterable, n, fillvalue=None):
"Collect data into fixed-length chunks or blocks"
# grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx"
args = [iter(iterable)] * n
return itertools.zip_longest(*args, fillvalue=fillvalue)
groups = grouper(a, 5)
truncated = (group[:2] for group in groups)
b = [elem for group in truncated for elem in group]
你可以把它转换成一个非常简单的单行,虽然你还需要grouper
函数:
b = [elem for group in grouper(a, 5) for elem in group[:2]]
另一种选择是建立索引列表,并使用itemgetter
来获取所有值。对于更复杂的功能而言,这可能比“每5个中的前2个”更具可读性,但对于像您使用这样简单的东西,它可能不太可读:
indices = [i for i in range(len(a)) if i%5 < 2]
b = operator.itemgetter(*indices)(a)
......可以变成单行:
b = operator.itemgetter(*[i for i in range(len(a)) if i%5 < 2])(a)
你可以通过编写自己的itemgetter
版本来结合这两种方法的优点,这些版本采用了一个懒惰的索引迭代器 - 我不会展示它,因为你可以通过编写一个索引过滤器功能:
def indexfilter(pred, a):
return [elem for i, elem in enumerate(a) if pred(i)]
b = indexfilter((lambda i: i%5<2), a)
(要使indexfilter
懒惰,只需用parens替换括号。)
......或者,作为一个单行:
b = [elem for i, elem in enumerate(a) if i%5<2]
我认为最后一个可能是最具可读性的。它适用于任何可迭代而不仅仅是列表,并且它可以变得懒惰(再次,只需用parens替换括号)。但我仍然不认为它比原来的理解更简单,而且它不只是切片。
答案 4 :(得分:2)
这个问题陈述了数组,如果我们谈论的是NumPy数组,我们肯定会使用一些明显的NumPy技巧和一些不那么明显的技巧。我们当然可以使用 public function GetSession($sessionId) {
$conn = mysqli_connect(DbConstants::$servername, DbConstants::$username, DbConstants::$password, DbConstants::$dbname);
if ($conn->connect_error) {
die("Connection failed: " . $conn->connect_error);
}
$query = 'CALL Sp_Session_GetById(' . $sessionId . ');';
mysqli_multi_query($conn, $query);
$sessionResult = mysqli_store_result($conn);
$sessionRow = mysqli_fetch_row($sessionResult);
$session = new Session(
$sessionRow[0],
$sessionRow[1],
$sessionRow[2],
$sessionRow[3],
$sessionRow[4],
$sessionRow[5],
$sessionRow[6],
$sessionRow[7],
$sessionRow[8],
[]);
mysqli_free_result($sessionResult);
mysqli_next_result($conn);
$sessionTypeResult = mysqli_store_result($conn);
while($sessionTypeRow = mysqli_fetch_row($sessionTypeResult)) {
array_push($session->sessionTypesForSession, $sessionTypeRow[0]);
}
$conn->close();
return $session;
}
在某些条件下将2D视图输入到输入中。
现在,根据数组长度,我们称之为slicing
和l
,我们将有三种情况:
m
可被l
我们可以使用切片和整形来获取输入数组的视图,从而获得恒定的运行时间。
验证视图概念:
n
检查非常大的数组上的时序,从而确定持续的运行时声明:
In [108]: a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
In [109]: m = 2; n = 5
In [110]: a.reshape(-1,n)[:,:m]
Out[110]:
array([[1, 2],
[6, 7]])
In [111]: np.shares_memory(a, a.reshape(-1,n)[:,:m])
Out[111]: True
要获得扁平化版本:
如果我们 得到一个展平的数组作为输出,我们只需要使用In [118]: a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
In [119]: %timeit a.reshape(-1,n)[:,:m]
1000000 loops, best of 3: 563 ns per loop
In [120]: a = np.arange(10000000)
In [121]: %timeit a.reshape(-1,n)[:,:m]
1000000 loops, best of 3: 564 ns per loop
的展平操作,就像这样 -
.ravel()
Timings表示,与其他帖子中的其他循环和矢量化numpy.where版本相比,它并不算太糟糕 -
In [127]: a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
In [128]: m = 2; n = 5
In [129]: a.reshape(-1,n)[:,:m].ravel()
Out[129]: array([1, 2, 6, 7])
In [143]: a = np.arange(10000000)
# @Kevin's soln
In [145]: %timeit [x for i,x in enumerate(a) if i%n < m]
1 loop, best of 3: 1.23 s per loop
# @jpp's soln
In [147]: %timeit a[np.where(np.arange(a.shape[0]) % n < m)]
10 loops, best of 3: 145 ms per loop
In [144]: %timeit a.reshape(-1,n)[:,:m].ravel()
100 loops, best of 3: 16.4 ms per loop
不能被l
整除,但这些组最后以完整的一个结尾我们使用np.lib.stride_tricks.as_strided
转到非显而易见的NumPy方法,允许使用beyoond内存块边界(因此我们需要注意不要写入那些)以使用{{1}来促进解决方案}。实现看起来像这样 -
n
运行示例以验证输出是slicing
-
def select_groups(a, m, n):
a = np.asarray(a)
strided = np.lib.stride_tricks.as_strided
# Get params defining the lengths for slicing and output array shape
nrows = len(a)//n
add0 = len(a)%n
s = a.strides[0]
out_shape = nrows+int(add0!=0),m
# Finally stride, flatten with reshape and slice
return strided(a, shape=out_shape, strides=(s*n,s))
要获得展平版本,请附加view
。
让我们进行一些时间比较 -
In [151]: a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
In [152]: m = 2; n = 5
In [153]: select_groups(a, m, n)
Out[153]:
array([[ 1, 2],
[ 6, 7],
[11, 12]])
In [154]: np.shares_memory(a, select_groups(a, m, n))
Out[154]: True
如果我们需要一个扁平的版本,它仍然不是太糟糕 -
.ravel()
In [158]: a = np.arange(10000003)
In [159]: m = 2; n = 5
# @Kevin's soln
In [161]: %timeit [x for i,x in enumerate(a) if i%n < m]
1 loop, best of 3: 1.24 s per loop
# @jpp's soln
In [162]: %timeit a[np.where(np.arange(a.shape[0]) % n < m)]
10 loops, best of 3: 148 ms per loop
In [160]: %timeit select_groups(a, m=m, n=n)
100000 loops, best of 3: 5.8 µs per loop
不能被In [163]: %timeit select_groups(a, m=m, n=n).ravel()
100 loops, best of 3: 16.5 ms per loop
整除,并且这些群组最后以不完整的结尾对于这种情况,我们需要在前一个方法的基础上进行额外的切片,就像这样 -
l
示例运行 -
n
计时 -
def select_groups_generic(a, m, n):
a = np.asarray(a)
strided = np.lib.stride_tricks.as_strided
# Get params defining the lengths for slicing and output array shape
nrows = len(a)//n
add0 = len(a)%n
lim = m*(nrows) + add0
s = a.strides[0]
out_shape = nrows+int(add0!=0),m
# Finally stride, flatten with reshape and slice
return strided(a, shape=out_shape, strides=(s*n,s)).reshape(-1)[:lim]
答案 5 :(得分:0)
使用itertools你可以得到一个迭代器:
from itertools import compress, cycle
a = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
n = 5
m = 2
it = compress(a, cycle([1, 1, 0, 0, 0]))
res = list(it)
答案 6 :(得分:0)
我意识到递归并不受欢迎,但这样的工作会不会这样?此外,不确定添加到混合的递归计数只是使用切片。
def get_elements(A, m, n):
if(len(A) < m):
return A
else:
return A[:m] + get_elements(A[n:], m, n)
A是数组,m和n在问题中定义。第一个if覆盖基本情况,其中有一个数组,其长度小于您尝试检索的元素数,第二个if是递归情况。我对python有点新鲜,请原谅我对该语言的不了解,如果这种方法不能正常工作,虽然我测试了它并且似乎工作正常。