在尝试使用numba加速某些代码时,我注意到与串行执行相比,并行执行给出了不同的(错误的)结果。考虑以下示例:
import numba
@numba.jit(nopython=True, parallel=False)
def get_nkj_serial(labels1, labels2, nunique_1, nunique_2):
labels1_uniques = np.arange(nunique_1)
labels2_uniques = np.arange(nunique_2)
n_kj = np.zeros((nunique_1, nunique_2))
for l in numba.prange(labels1.size):
l1, l2 = labels1[l], labels2[l]
for k, unique_1 in enumerate(labels1_uniques):
for j, unique_2 in enumerate(labels2_uniques):
if l2 == unique_2 and l1 == unique_1:
n_kj[k, j] += 1
return n_kj
@numba.jit(nopython=True, parallel=True)
def get_nkj_parallel(labels1, labels2, nunique_1, nunique_2):
labels1_uniques = np.arange(nunique_1)
labels2_uniques = np.arange(nunique_2)
n_kj = np.zeros((nunique_1, nunique_2))
for l in numba.prange(labels1.size):
l1, l2 = labels1[l], labels2[l]
for k, unique_1 in enumerate(labels1_uniques):
for j, unique_2 in enumerate(labels2_uniques):
if l2 == unique_2 and l1 == unique_1:
n_kj[k, j] += 1
return n_kj
labels1 = np.random.randint(0, 20, size=1000000)
labels2 = np.random.randint(0, 10, size=labels1.size)
nkj_serial = get_nkj_serial(labels1, labels2, 20, 10)
nkj_parallel = get_nkj_parallel(labels1, labels2, 20, 10)
print(nkj_serial)
print('\n')
print(nkj_parallel)
导致:
[[4974. 5088. 4973. 5002. 4870. 4890. 5056. 5090. 4851. 5069.]
[5023. 5098. 4928. 5033. 5039. 4994. 5001. 4927. 5101. 4988.]
[4945. 5028. 4845. 4819. 5053. 4983. 5072. 5100. 4884. 4989.]
[5029. 5032. 4872. 4989. 4989. 5058. 4930. 5067. 4887. 5024.]
[5012. 5006. 5045. 5003. 4942. 5124. 5036. 5145. 5004. 4929.]
[5137. 5061. 5109. 5064. 4944. 5177. 4937. 5049. 5008. 5079.]
[4889. 5086. 4888. 4956. 5107. 4983. 4979. 5008. 5004. 5098.]
[4997. 5004. 4998. 5101. 4864. 4949. 5062. 5055. 5008. 5036.]
[5087. 4953. 4972. 5022. 4889. 4816. 5017. 4947. 5011. 5054.]
[5150. 4994. 5091. 4905. 5019. 4940. 4965. 5011. 5129. 4952.]
[4934. 4912. 5019. 5171. 5070. 4950. 4956. 4889. 5063. 4949.]
[5000. 4906. 4971. 4888. 5129. 4981. 4955. 4924. 4965. 4985.]
[4985. 4979. 5034. 4970. 4935. 4824. 5074. 5005. 5047. 4897.]
[5024. 4950. 5012. 5005. 5061. 4960. 5091. 4926. 5040. 5147.]
[4943. 5070. 4984. 4899. 4948. 4904. 4972. 5015. 4928. 4979.]
[5000. 4944. 5010. 4986. 4947. 5029. 5002. 5066. 4978. 5033.]
[5013. 5146. 5021. 5061. 5079. 5087. 4971. 4995. 5020. 4958.]
[5088. 4875. 4976. 4923. 5033. 5117. 5154. 5112. 4911. 4976.]
[4961. 4908. 5027. 5093. 4993. 4980. 5099. 5018. 4932. 5092.]
[5010. 4842. 4993. 5005. 5013. 5061. 4947. 4972. 4919. 5002.]]
[[4573. 4685. 4605. 4629. 4517. 4560. 4650. 4706. 4436. 4639.]
[4551. 4634. 4479. 4584. 4604. 4534. 4554. 4463. 4629. 4534.]
[4515. 4570. 4411. 4394. 4585. 4522. 4568. 4602. 4469. 4528.]
[4543. 4499. 4418. 4556. 4579. 4569. 4496. 4624. 4476. 4576.]
[4527. 4521. 4603. 4517. 4500. 4652. 4547. 4644. 4574. 4458.]
[4629. 4582. 4611. 4592. 4475. 4663. 4512. 4571. 4540. 4576.]
[4445. 4619. 4462. 4493. 4612. 4493. 4467. 4519. 4520. 4628.]
[4362. 4504. 4456. 4574. 4384. 4497. 4541. 4529. 4491. 4506.]
[4533. 4429. 4435. 4468. 4351. 4305. 4498. 4422. 4495. 4520.]
[4582. 4502. 4560. 4444. 4531. 4441. 4425. 4454. 4590. 4410.]
[4409. 4385. 4503. 4561. 4565. 4462. 4443. 4396. 4515. 4451.]
[4494. 4419. 4440. 4344. 4608. 4462. 4419. 4413. 4453. 4446.]
[4505. 4518. 4548. 4490. 4443. 4387. 4527. 4387. 4565. 4461.]
[4536. 4481. 4545. 4521. 4603. 4490. 4636. 4514. 2677. 4645.]
[4503. 4611. 4520. 4462. 4509. 4500. 4526. 4570. 4485. 4538.]
[4555. 4489. 4576. 4573. 4524. 4597. 4540. 4606. 4561. 4583.]
[4547. 4658. 4582. 4604. 4610. 4638. 4506. 4564. 4587. 4565.]
[4651. 4451. 4535. 4536. 4566. 4659. 4679. 4688. 4495. 4547.]
[4535. 4479. 4608. 4591. 4565. 4571. 4659. 4617. 4509. 4674.]
[4570. 4452. 4577. 4573. 4583. 4638. 4582. 4544. 4512. 4600.]]
如您所见,串行版本的结果与并行版本不同。这是一个错误,还是我做错了什么?
我有numba=0.39
,llvmlite=0.24.0