我正在使用python burst_detection包链接(https://github.com/nmarinsek/burst_detection/blob/master/README.rst)来尝试结果。但程序总是有一些错误:ValueError:用序列设置数组元素。 有趣的是当我删除r数组中大于94的值时,一切正常。它停在第84个元素,因为r [83] = 342> 94。
错误的详细信息如下:
C:\Python27\lib\site-packages\burst_detection\__init__.py:29: RuntimeWarning: invalid value encountered in double_scalars
return -np.log(np.float(c.binomial(d,r)) * (p**r) * (1-p)**(d-r))
Traceback (most recent call last):
File "burstConstructDataPrepare.py", line 33, in <module>
q, d, r, p = bd.burst_detection(r,d,n,s=2,gamma=1,smooth_win=1)
File "C:\Python27\lib\site-packages\burst_detection\__init__.py", line 82, in burst_detection
q[t] = np.where(cost[t,:] == min(cost[t,:]))
ValueError: setting an array element with a sequence.
源代码:
import pandas as pd
import os
import burst_detection as bd
import numpy as np
r = np.array([5, 8, 12, 12, 2, 4, 11, 15, 2, 3, 4, 29, 30, 10, 7, 1, 24, 18, 2, 2, 2, 2, 54, 2, 8, 2, 3, 12, 4, 2, 6, 18, 4, 4, 12, 2, 8, 2, 3, 2, 2, 5, 2, 2, 3, 9, 7, 8, 6, 9, 6, 1, 4, 20, 2, 16, 19, 2, 11, 4, 2, 38, 6, 7, 1, 2, 14, 4, 8, 2, 4, 2, 2, 6, 8, 27, 4, 2, 14, 2, 14, 8, 4, 2, 342, 4, 2, 2, 14, 14, 6, 2, 2, 6, 2, 4, 2, 1, 5, 10, 27, 6, 2, 2, 2, 14, 12, 16, 2, 48, 16, 6, 3, 2, 4, 2, 2, 1, 320, 4, 4, 8, 6, 238, 12, 6, 4, 10, 6, 2, 10, 4, 19, 10, 3, 1, 2, 32, 8, 4, 6, 2, 4, 2, 18, 10, 18, 4, 4, 4, 6, 2, 13, 2, 4, 47, 2, 2, 4, 10, 5, 4, 2, 12, 34, 4, 6, 8, 8, 8, 20, 2, 1, 4, 6, 2, 8, 29, 4, 14, 6, 8, 2, 28, 4, 18, 2, 2, 2, 7, 4, 2, 2, 8, 2, 6, 8, 1, 2, 6, 4, 1, 246, 6, 43, 14, 16, 2, 7, 4, 4, 12, 8, 8, 2, 14, 4, 19, 4, 2, 8, 16, 8, 14, 3, 12, 3, 4, 3, 4, 6, 2, 16, 5, 2, 3, 2, 11, 301, 10, 2, 2, 8, 1, 2, 4, 4, 4, 8, 2, 2, 4, 2, 8, 4, 6, 2, 34, 14, 25, 11, 2, 5, 34, 2, 1, 2, 2, 6, 2, 1, 6, 4, 4, 2, 5, 2, 2, 1, 25, 1, 21, 10, 14, 10, 4, 4, 6, 4, 4, 4, 28, 36, 2, 7, 2, 1, 5, 5, 2, 8, 23, 2, 104, 4, 2, 81, 5, 10, 4, 2, 20, 4, 4, 12, 4, 4, 7, 2, 2, 6, 1, 2, 4, 2, 16, 4, 2, 2, 32, 26, 2, 3, 5, 8, 34, 11, 2, 15, 4803, 6, 4, 2, 7, 2, 6, 54, 15, 5, 2, 10, 8, 6, 10, 2, 2, 4, 4, 4, 10, 6, 4, 7, 10, 12, 2, 2, 4, 10, 18, 2, 2, 4, 6, 4, 2, 10, 4, 2, 3, 4, 7, 5, 5, 10, 2, 24, 14, 2, 2, 14, 2, 4, 2, 5, 4, 4, 20, 2, 6, 8, 2, 4, 2, 14, 6, 2, 5, 2, 56, 4, 4, 4, 1, 4, 15, 22, 7, 4, 4, 6, 4, 12, 6, 2, 1, 8, 8, 6, 8, 4, 2, 2, 4, 16, 4, 16, 4, 11, 4, 16, 2, 4, 18, 10, 6, 4, 10, 5, 4, 4, 2, 1, 2, 6, 7, 2, 1, 12, 15, 6, 8, 3, 10, 6, 6, 15, 2, 2, 22, 2, 2, 4, 14, 4, 4, 2, 4, 4, 7, 2, 3, 4, 20, 4, 6, 6, 6, 6, 2, 2, 2, 40, 5, 57, 16, 9, 39, 23, 14, 4, 2, 4, 4, 1, 3, 2, 14, 18, 14, 4, 2, 4, 5], dtype=float)
d = np.array([85204, 52148, 51493, 49650, 71615, 40589, 64427, 82750, 106819, 74787, 85377, 103583, 105085, 182878, 62091, 57892, 50195, 93694, 109417, 73217, 55927, 72714, 63947, 55296, 90402, 88750, 65165, 45275, 96197, 25340, 21605, 35532, 47485, 26538, 24425, 23869, 26354, 22754, 21407, 55827, 21632, 22906, 28906, 24859, 21307, 30817, 17375, 9858, 18313, 232498, 19294, 97136, 51202, 37572, 54557, 70766, 57097, 114500, 56602, 45331, 40991, 39157, 38712, 55311, 31137, 97381, 34769, 27199, 26256, 54649, 55692, 20187, 29983, 38937, 18890, 27164, 40477, 26669, 17575, 29507, 24172, 32419, 20765, 22351, 47418, 32246, 30448, 19956, 12941, 14893, 13225, 35730, 19355, 25819, 54119, 110946, 65895, 87889, 39733, 83585, 64361, 59842, 38631, 58091, 48131, 103965, 93565, 79384, 45332, 95638, 95917, 36616, 48542, 33289, 41184, 24751, 30682, 43297, 24261, 32458, 30735, 25613, 26687, 44623, 21578, 26780, 27357, 24233, 43731, 65909, 39591, 46440, 34500, 63942, 37372, 110805, 100447, 36929, 90949, 82508, 50365, 60669, 85053, 99502, 31699, 43013, 33139, 20893, 25020, 21772, 19161, 19574, 25259, 21669, 29567, 25830, 28928, 39378, 18625, 35365, 20150, 28809, 20157, 12458, 18612, 18951, 36695, 30367, 27935, 17827, 38211, 47211, 32846, 50263, 33010, 64535, 38932, 64244, 47936, 65004, 93956, 121679, 135349, 92521, 57248, 60545, 96179, 61296, 82527, 129102, 57743, 27099, 27526, 77945, 47691, 72254, 30493, 31201, 30027, 18517, 51830, 45119, 34492, 19145, 28046, 39817, 32189, 26649, 22680, 35697, 20009, 27063, 22006, 17429, 21850, 36964, 22783, 29280, 17551, 16710, 21856, 56571, 29912, 78352, 47126, 119015, 58523, 53470, 117603, 113078, 90366, 152835, 84910, 44449, 48737, 90380, 36751, 21517, 61106, 43713, 31013, 32848, 35680, 34321, 28032, 28998, 21993, 17442, 25893, 35886, 18890, 21404, 21068, 31837, 24098, 22238, 22874, 79276, 59551, 56030, 51519, 56161, 69342, 44812, 75656, 147183, 101918, 117934, 101308, 48944, 72581, 83213, 39094, 45053, 41858, 93766, 24785, 44300, 32312, 44351, 46431, 22052, 22717, 30312, 17618, 18482, 32053, 30781, 23381, 25085, 38552, 17232, 7446, 31514, 31803, 52601, 59064, 57327, 83281, 52313, 102054, 81384, 46131, 41147, 92202, 83376, 71833, 81751, 50042, 139309, 88766, 65899, 73897, 13498, 24365, 54974, 55356, 55169, 70458, 34189, 22680, 30344, 29500, 29508, 20756, 23004, 28998, 19748, 24246, 45116, 19455, 45974, 34776, 28558, 52504, 100644, 36647, 34962, 62670, 86643, 72854, 149450, 144966, 64755, 160236, 107847, 169807, 45338, 118870, 59907, 54753, 95093, 53541, 40316, 41518, 30616, 62899, 24178, 32149, 31800, 16248, 26890, 23822, 24347, 21534, 32738, 34430, 26452, 22751, 24797, 32184, 21959, 14426, 21093, 22693, 35388, 38083, 20472, 96580, 48828, 41702, 80508, 59900, 69009, 150823, 50156, 124057, 79322, 99327, 37164, 48447, 73061, 23266, 60930, 32194, 27431, 48665, 23644, 23114, 19836, 20855, 26099, 18557, 34715, 69252, 35369, 24415, 27723, 75882, 80116, 93982, 93600, 48240, 30530, 70467, 48381, 68071, 111650, 49644, 48321, 67175, 36795, 38703, 45768, 25008, 101486, 118644, 28467, 27976, 40437, 35654, 36718, 37723, 27918, 21014, 39288, 36771, 35745, 59531, 15936, 18691, 13160, 22037, 20341, 26777, 33097, 21118, 10619, 51840, 37167, 37119, 17560, 95647, 49868, 102463, 43691, 58891, 117927, 59001, 55306, 60960, 62106, 37680, 71267, 69463, 103003, 39256, 57695, 36497, 32099, 42077, 29675, 39168, 22153, 65598, 45505, 27221, 24101, 35195, 25168, 21496, 20437, 27609, 22490, 31317, 23305, 20425, 14547, 32549, 37037, 19987, 49525, 81588, 74676, 88173, 63231, 124179, 88269, 66430, 48514, 63505, 52358, 45551, 32175, 16347, 47466, 49138, 32070, 19425, 39898, 25584, 29265, 28505, 36205], dtype=float)
n = len(r)
q, d, r, p = bd.burst_detection(r,d,n,s=2,gamma=1,smooth_win=1)
bursts = bd.enumerate_bursts(q, 'burstLabel')
weighted_bursts = bd.burst_weights(bursts,r,d,p)
print('observed probabilities: ')
print(str(r/d))
print('optimal state sequence: ')
print(str(q.T))
print('baseline probability: ' + str(p[0]))
print('bursty probability: ' + str(p[1]))
print('weighted bursts:')
print(weighted_bursts)
然后我尝试只将10或20个元素放到r和d,它可以工作。当它有524个元素时,它会显示这种错误。 有人知道为什么吗? 谢谢!
答案 0 :(得分:1)
它似乎是burst_detection
包中的错误。我不知道该行应该做什么,但行
q[t] = np.where(cost[t,:] == min(cost[t,:]))
将尝试将左侧q[t]
设置为右侧,这是一个单个数组元素。当只有一个成本等于最低成本时,这种方法很好,但是当两个状态共享相同的最低成本时,右侧将具有长度2.要让它只选择其中一个最佳状态,您需要将该行更改为
q[t] = np.where(cost[t,:] == min(cost[t,:]))[0]
您可以通过转到安装软件包的任何位置并手动编辑该行来执行此操作。如果你认为软件包维护者会解决这个问题,你也可以在GitHub上的跟踪器上打开一个问题。