我有一个用于读取多维数组的API,需要传递范围向量以从后备数组中读取子矩形(或超立方体)。我想要“线性地”读取这个数组,所有元素都按照给定的顺序排列,具有任意的块大小。因此,任务使用off
和len
,将此范围覆盖的元素转换为最小可能的超立方体集,即API中发出的最小数量的读取命令。
例如,我们可以计算给出线性索引的维度集合的索引向量:
def calcIndices(off: Int, shape: Vector[Int]): Vector[Int] = {
val modsDivs = shape zip shape.scanRight(1)(_ * _).tail
modsDivs.map { case (mod, div) =>
(off / div) % mod
}
}
假设形状是这样,表示总共有4级和120级元素的数组:
val sz = Vector(2, 3, 4, 5)
val num = sz.product // 120
为一系列线性偏移打印这些索引向量的实用程序:
def printIndices(off: Int, len: Int): Unit =
(off until (off + len)).map(calcIndices(_, sz))
.map(_.mkString("[", ", ", "]")).foreach(println)
我们可以生成所有这些向量:
printIndices(0, num)
[0, 0, 0, 0]
[0, 0, 0, 1]
[0, 0, 0, 2]
[0, 0, 0, 3]
[0, 0, 0, 4]
[0, 0, 1, 0]
[0, 0, 1, 1]
[0, 0, 1, 2]
[0, 0, 1, 3]
[0, 0, 1, 4]
[0, 0, 2, 0]
[0, 0, 2, 1]
[0, 0, 2, 2]
[0, 0, 2, 3]
[0, 0, 2, 4]
[0, 0, 3, 0]
[0, 0, 3, 1]
[0, 0, 3, 2]
[0, 0, 3, 3]
[0, 0, 3, 4]
[0, 1, 0, 0]
...
[1, 2, 1, 4]
[1, 2, 2, 0]
[1, 2, 2, 1]
[1, 2, 2, 2]
[1, 2, 2, 3]
[1, 2, 2, 4]
[1, 2, 3, 0]
[1, 2, 3, 1]
[1, 2, 3, 2]
[1, 2, 3, 3]
[1, 2, 3, 4]
让我们看一下应该读取的示例块, 前六个要素:
val off1 = 0
val len1 = 6
printIndices(off1, len1)
我已经手动将输出分区为超立方体:
// first hypercube or read
[0, 0, 0, 0]
[0, 0, 0, 1]
[0, 0, 0, 2]
[0, 0, 0, 3]
[0, 0, 0, 4]
// second hypercube or read
[0, 0, 1, 0]
所以任务是定义方法
def partition(shape: Vector[Int], off: Int, len: Int): List[Vector[Range]]
输出正确的列表并使用尽可能小的列表大小。
因此,对于off1
和len1
,我们有预期的结果:
val res1 = List(
Vector(0 to 0, 0 to 0, 0 to 0, 0 to 4),
Vector(0 to 0, 0 to 0, 1 to 1, 0 to 0)
)
assert(res1.map(_.map(_.size).product).sum == len1)
第二个例子,索引6到22的元素,手动分区给出三个超立方体或读取命令:
val off2 = 6
val len2 = 16
printIndices(off2, len2)
// first hypercube or read
[0, 0, 1, 1]
[0, 0, 1, 2]
[0, 0, 1, 3]
[0, 0, 1, 4]
// second hypercube or read
[0, 0, 2, 0]
[0, 0, 2, 1]
[0, 0, 2, 2]
[0, 0, 2, 3]
[0, 0, 2, 4]
[0, 0, 3, 0]
[0, 0, 3, 1]
[0, 0, 3, 2]
[0, 0, 3, 3]
[0, 0, 3, 4]
// third hypercube or read
[0, 1, 0, 0]
[0, 1, 0, 1]
expected result:
val res2 = List(
Vector(0 to 0, 0 to 0, 1 to 1, 1 to 4),
Vector(0 to 0, 0 to 0, 2 to 3, 0 to 4),
Vector(0 to 0, 1 to 1, 0 to 0, 0 to 1)
)
assert(res2.map(_.map(_.size).product).sum == len2)
请注意,对于val off3 = 6; val len3 = 21
,我们需要四个读数。
答案 0 :(得分:0)
以下算法的想法如下:
[0, 0, 0, 1]
和[0, 1, 0, 0]
,poi是1)[0, 0, 1, 1]
和poi = 2
,ceil将为[0, 0, 2, 0]
[0, 0, 1, 1]
和poi = 2
,最低限额为[0, 0, 1, 0]
这是我的实施。首先,一些实用功能:
def calcIndices(off: Int, shape: Vector[Int]): Vector[Int] = {
val modsDivs = (shape, shape.scanRight(1)(_ * _).tail, shape.indices).zipped
modsDivs.map { case (mod, div, idx) =>
val x = off / div
if (idx == 0) x else x % mod
}
}
def calcPOI(a: Vector[Int], b: Vector[Int], min: Int): Int = {
val res = (a.drop(min) zip b.drop(min)).indexWhere { case (ai,bi) => ai != bi }
if (res < 0) a.size else res + min
}
def zipToRange(a: Vector[Int], b: Vector[Int]): Vector[Range] =
(a, b).zipped.map { (ai, bi) =>
require (ai <= bi)
ai to bi
}
def calcOff(a: Vector[Int], shape: Vector[Int]): Int = {
val divs = shape.scanRight(1)(_ * _).tail
(a, divs).zipped.map(_ * _).sum
}
def indexTrunc(a: Vector[Int], poi: Int, inc: Boolean): Vector[Int] =
a.zipWithIndex.map { case (ai, i) =>
if (i < poi) ai
else if (i > poi) 0
else if (inc) ai + 1
else ai
}
然后是实际的算法:
def partition(shape: Vector[Int], off: Int, len: Int): List[Vector[Range]] = {
val rankM = shape.size - 1
def loop(start: Int, stop: Int, poiMin: Int, dir: Boolean,
res0: List[Vector[Range]]): List[Vector[Range]] =
if (start == stop) res0 else {
val last = stop - 1
val s0 = calcIndices(start, shape)
val s1 = calcIndices(stop , shape)
val s1m = calcIndices(last , shape)
val poi = calcPOI(s0, s1m, poiMin)
val ti = if (dir) s0 else s1
val to = if (dir) s1 else s0
val st = if (poi >= rankM) to else indexTrunc(ti, poi, inc = dir)
val trunc = calcOff(st, shape)
val split = trunc != (if (dir) stop else start)
if (split) {
if (dir) {
val res1 = loop(start, trunc, poiMin = poi+1, dir = true , res0 = res0)
loop (trunc, stop , poiMin = 0 , dir = false, res0 = res1)
} else {
val s1tm = calcIndices(trunc - 1, shape)
val res1 = zipToRange(s0, s1tm) :: res0
loop (trunc, stop , poiMin = poi+1, dir = false, res0 = res1)
}
} else {
zipToRange(s0, s1m) :: res0
}
}
loop(off, off + len, poiMin = 0, dir = true, res0 = Nil).reverse
}
示例:
val sz = Vector(2, 3, 4, 5)
partition(sz, 0, 6)
// result:
List(
Vector(0 to 0, 0 to 0, 0 to 0, 0 to 4), // first hypercube
Vector(0 to 0, 0 to 0, 1 to 1, 0 to 0) // second hypercube
)
partition(sz, 6, 21)
// result:
List(
Vector(0 to 0, 0 to 0, 1 to 1, 1 to 4), // first read
Vector(0 to 0, 0 to 0, 2 to 3, 0 to 4), // second read
Vector(0 to 0, 1 to 1, 0 to 0, 0 to 4), // third read
Vector(0 to 0, 1 to 1, 1 to 1, 0 to 1) // fourth read
)
如果我没有弄错,最大读取次数为2 * rank
。