我正在跟着 this tutorial 学习。我卡在了下面这段代码以及与之配套的这张图片上:
class DataGeneratorSeq(object):
    """Hold a price series plus one read cursor per batch row.

    Only the constructor is shown in this excerpt.
    """

    def __init__(self, prices, batch_size, num_unroll):
        """Store the series and place one cursor at the start of each segment.

        Args:
            prices: Sequence of prices; must support ``len()``.
            batch_size: Number of cursors (one per batch row).
            num_unroll: Number of trailing prices excluded from the usable
                length.
        """
        self._prices = prices
        # Usable length: the last ``num_unroll`` prices are held back.
        self._prices_length = len(self._prices) - num_unroll
        self._batch_size = batch_size
        self._num_unroll = num_unroll
        # Split the usable range into ``batch_size`` equally long segments.
        self._segments = self._prices_length // self._batch_size
        # Cursor ``b`` starts at the first index of segment ``b``.
        self._cursor = [
            offset * self._segments for offset in range(self._batch_size)
        ]
我对这里发生的事情感到困惑。为什么看起来像是“批次之中还有批次”?这些批次是按照什么逻辑生成的?
我甚至试着写了一个小例子来弄清楚,但我仍然不明白到底发生了什么:
# Small example: recompute the generator's bookkeeping by hand, then print
# every unrolled batch it produces.
batch_size = 6
prices = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
num_unroll = 4

# Same arithmetic as DataGeneratorSeq.__init__, printed for inspection.
prices_length = len(prices) - num_unroll
print('prices_length =', prices_length)
segments = prices_length // batch_size
print('segments =', segments)
cursor = [offset * segments for offset in range(batch_size)]
print('cursor :', cursor, '\n')

dg2 = DataGeneratorSeq(prices, batch_size, num_unroll)
u_data2, u_labels2 = dg2.unroll_batches()

for ui, (dat, lbl) in enumerate(zip(u_data2, u_labels2)):
    print('\n\nUnrolled index %d'%ui)
    # Aliases kept from the original snippet; not read in the visible code.
    dat_ind = dat
    lbl_ind = lbl
    print('\tInputs: ',dat )
    print('\n\tOutput:',lbl)
输出:
prices_length = 6
segments = 1
cursor : [0, 1, 2, 3, 4, 5]
b: 0 , [10. 0. 0. 0. 0. 0.]
b: 1 , [10. 11. 0. 0. 0. 0.]
b: 2 , [10. 11. 12. 0. 0. 0.]
b: 3 , [10. 11. 12. 13. 0. 0.]
b: 4 , [10. 11. 12. 13. 14. 0.]
b: 5 , [10. 11. 12. 13. 14. 14.]
ui: 0 , [array([10., 11., 12., 13., 14., 14.], dtype=float32)]
ui: 0 , [array([14., 13., 15., 16., 16., 17.], dtype=float32)]
b: 0 , [11. 0. 0. 0. 0. 0.]
b: 1 , [11. 12. 0. 0. 0. 0.]
b: 2 , [11. 12. 13. 0. 0. 0.]
b: 3 , [11. 12. 13. 14. 0. 0.]
b: 4 , [11. 12. 13. 14. 12. 0.]
b: 5 , [11. 12. 13. 14. 12. 15.]
ui: 1 , [array([10., 11., 12., 13., 14., 14.], dtype=float32), array([11., 12., 13., 14., 12., 15.], dtype=float32)]
ui: 1 , [array([14., 13., 15., 16., 16., 17.], dtype=float32), array([14., 15., 16., 16., 16., 18.], dtype=float32)]
b: 0 , [12. 0. 0. 0. 0. 0.]
b: 1 , [12. 13. 0. 0. 0. 0.]
b: 2 , [12. 13. 14. 0. 0. 0.]
b: 3 , [12. 13. 14. 13. 0. 0.]
b: 4 , [12. 13. 14. 13. 13. 0.]
b: 5 , [12. 13. 14. 13. 13. 10.]
ui: 2 , [array([10., 11., 12., 13., 14., 14.], dtype=float32), array([11., 12., 13., 14., 12., 15.], dtype=float32), array([12., 13., 14., 13., 13., 10.], dtype=float32)]
ui: 2 , [array([14., 13., 15., 16., 16., 17.], dtype=float32), array([14., 15., 16., 16., 16., 18.], dtype=float32), array([16., 15., 16., 17., 16., 11.], dtype=float32)]
b: 0 , [13. 0. 0. 0. 0. 0.]
b: 1 , [13. 14. 0. 0. 0. 0.]
b: 2 , [13. 14. 12. 0. 0. 0.]
b: 3 , [13. 14. 12. 14. 0. 0.]
b: 4 , [13. 14. 12. 14. 14. 0.]
b: 5 , [13. 14. 12. 14. 14. 11.]
ui: 3 , [array([10., 11., 12., 13., 14., 14.], dtype=float32), array([11., 12., 13., 14., 12., 15.], dtype=float32), array([12., 13., 14., 13., 13., 10.], dtype=float32), array([13., 14., 12., 14., 14., 11.], dtype=float32)]
ui: 3 , [array([14., 13., 15., 16., 16., 17.], dtype=float32), array([14., 15., 16., 16., 16., 18.], dtype=float32), array([16., 15., 16., 17., 16., 11.], dtype=float32), array([15., 18., 15., 18., 15., 12.], dtype=float32)]
---------- [array([10., 11., 12., 13., 14., 14.], dtype=float32), array([11., 12., 13., 14., 12., 15.], dtype=float32), array([12., 13., 14., 13., 13., 10.], dtype=float32), array([13., 14., 12., 14., 14., 11.], dtype=float32)] --------------
[(array([10., 11., 12., 13., 14., 14.], dtype=float32), array([14., 13., 15., 16., 16., 17.], dtype=float32)), (array([11., 12., 13., 14., 12., 15.], dtype=float32), array([14., 15., 16., 16., 16., 18.], dtype=float32)), (array([12., 13., 14., 13., 13., 10.], dtype=float32), array([16., 15., 16., 17., 16., 11.], dtype=float32)), (array([13., 14., 12., 14., 14., 11.], dtype=float32), array([15., 18., 15., 18., 15., 12.], dtype=float32))]
<enumerate object at 0x7fb354a12a68>
---------- [array([10., 11., 12., 13., 14., 14.], dtype=float32), array([11., 12., 13., 14., 12., 15.], dtype=float32), array([12., 13., 14., 13., 13., 10.], dtype=float32), array([13., 14., 12., 14., 14., 11.], dtype=float32)] --------------
为什么 `prices_length` 被计算为价格数组的总长度减去这个神秘的展开步数?`num_unroll` 看起来像是批次的数量,但是那张图片让我感到困惑。这个 `segments` 变量又是什么?
我已经阅读了很多其他关于 LSTM 的教程,但都没有讲得这么深入。我觉得自己理解 LSTM 本身在理论上应该如何工作,但就是弄不明白这段代码。
这是整个类的代码:
class DataGeneratorSeq(object):
    """Generate (input, label) price batches for an unrolled sequence model.

    The first ``len(prices) - num_unroll`` prices are split into
    ``batch_size`` equally long segments. One cursor per batch row starts at
    the beginning of its segment and advances by one on every call to
    ``next_batch``.
    """

    def __init__(self, prices, batch_size, num_unroll):
        """Store the series and place one cursor at the start of each segment.

        Args:
            prices: Indexable sequence of prices (supports ``len()`` and
                integer indexing).
            batch_size: Number of values per batch, i.e. number of cursors.
            num_unroll: Number of batches returned by ``unroll_batches``;
                also the number of trailing prices kept out of the range the
                cursors read inputs from.
        """
        self._prices = prices
        self._prices_length = len(self._prices) - num_unroll
        self._batch_size = batch_size
        self._num_unroll = num_unroll
        self._segments = self._prices_length // self._batch_size
        self._cursor = [
            offset * self._segments for offset in range(self._batch_size)
        ]

    def next_batch(self):
        """Return one batch and advance every cursor by one position.

        Returns:
            Tuple ``(batch_data, batch_labels)`` of float32 arrays of length
            ``batch_size``. ``batch_data[b]`` is the price under cursor ``b``;
            ``batch_labels[b]`` is the price 0 to 4 positions after it (offset
            drawn at random), clamped to the last price of the series.
        """
        batch_data = np.zeros((self._batch_size), dtype=np.float32)
        batch_labels = np.zeros((self._batch_size), dtype=np.float32)
        last_index = len(self._prices) - 1

        for b in range(self._batch_size):
            if self._cursor[b] + 1 >= self._prices_length:
                # Cursor reached the end of the usable range: restart it at a
                # random index below the end of segment ``b`` instead of at
                # the fixed segment start ``b * self._segments``.
                self._cursor[b] = np.random.randint(
                    0, (b + 1) * self._segments
                )

            batch_data[b] = self._prices[self._cursor[b]]
            # The random look-ahead can point past the end of the series
            # (e.g. when num_unroll < 4), so clamp it to the last valid index
            # rather than raising IndexError.
            label_index = self._cursor[b] + np.random.randint(0, 5)
            batch_labels[b] = self._prices[min(label_index, last_index)]

            self._cursor[b] = (self._cursor[b] + 1) % self._prices_length

        return batch_data, batch_labels

    def unroll_batches(self):
        """Collect ``num_unroll`` consecutive batches.

        Returns:
            Tuple ``(unroll_data, unroll_labels)``: two lists of length
            ``num_unroll`` holding the arrays returned by successive
            ``next_batch`` calls.
        """
        unroll_data, unroll_labels = [], []
        for _ in range(self._num_unroll):
            data, labels = self.next_batch()
            unroll_data.append(data)
            unroll_labels.append(labels)

        return unroll_data, unroll_labels

    def reset_indices(self):
        """Move every cursor to a random index.

        Cursor ``b`` is drawn from ``[0, min((b + 1) * segments,
        prices_length - 1))``.
        """
        for b in range(self._batch_size):
            upper = min((b + 1) * self._segments, self._prices_length - 1)
            self._cursor[b] = np.random.randint(0, upper)
# Build 5 unrolled batches of 5 values each and print them.
# NOTE(review): train_data is defined outside the visible snippet.
dg = DataGeneratorSeq(train_data, 5, 5)
u_data, u_labels = dg.unroll_batches()

for ui, (dat, lbl) in enumerate(zip(u_data, u_labels)):
    print('\n\nUnrolled index %d'%ui)
    # Aliases kept from the original snippet; not read in the visible code.
    dat_ind = dat
    lbl_ind = lbl
    print('\tInputs: ',dat )
    print('\n\tOutput:',lbl)
编辑:好的,我又读了一些关于展开(unroll)LSTM 网络的资料,现在 `num_unroll` 这个变量的含义清楚了。它表示我们想要回看的数据点个数:代码依次遍历这些数据点,把它们送入网络,并在每一步进行更新。
但我仍然不确定“段”(segments)是什么,以及为什么要引入一些随机性。这样做难道不会破坏“根据前 n 个点来预测下一个点”这一前提吗?