如果这个问题看起来很愚蠢,我事先表示歉意,但是我试图理解有关LSTM类型ML算法的machinelearningmastery blog,更具体地说是关于如何重塑输入数据的形式。而且我在这方面没有很多智慧,也没有CS学位
Jason在LSTM CNN
博客的一半处谈到:
第一步是将输入序列分为多个子序列 可以由CNN模型处理。例如,我们可以先拆分我们的 通过四个步骤将时间序列数据单变量化为输入/输出样本 作为输入,一个作为输出。然后可以将每个样本分成两部分 子样本,每个样本有两个时间步长。 CNN可以解释每个 两个时间步的子序列,并提供一个时间序列 LSTM模型子序列的解释,以作为 输入。
我们可以将此参数化并定义子序列的数量为 n_seq,每个子序列的时间步数为n_steps。的 然后可以将输入数据重塑为所需的结构:
[samples, subsequences, timesteps, features]
我的问题是仅将数据整形为4 steps
的要求吗?还是可以更大?下面的代码将尝试打印数组,我在git帐户上使用自己的数据示例here。
import pandas as pd
import numpy as np
# univariate data preparation
from numpy import array
df = pd.read_csv("trainData.csv")
df = df[['kW']].shift(-1)
df = df.dropna()
raw_seq = df.values.tolist()
# split a univariate sequence into samples
def split_sequence(sequence, n_steps):
X, y = list(), list()
for i in range(len(sequence)):
# find the end of this pattern
end_ix = i + n_steps
# check if we are beyond the sequence
if end_ix > len(sequence)-1:
break
# gather input and output parts of the pattern
seq_x, seq_y = sequence[i:end_ix], sequence[end_ix]
X.append(seq_x)
y.append(seq_y)
return array(X), array(y)
# define input sequence
#raw_seq = [10, 20, 30, 40, 50, 60, 70, 80, 90]
# choose a number of time steps
n_steps = 4
# split into samples
X, y = split_sequence(raw_seq, n_steps)
# reshape from [samples, timesteps] into [samples, subsequences, timesteps, features]
n_features = 1
n_seq = 2
n_steps = 2
X = X.reshape((X.shape[0], n_seq, n_steps, n_features))
# summarize the data
for i in range(len(X)):
print(X[i], y[i])
上面的代码有效,但是当我更改n_steps = 7
(从4开始)时,出现此形状错误。
File "convArray.py", line 39, in <module>
X = X.reshape((X.shape[0], n_seq, n_steps, n_features))
ValueError: cannot reshape array of size 2499 into shape (357,2,2,1)
我想尝试使用7个时间步长的原因是,我正在试验的数据是建筑物的每天电力需求单位,而一周7天是理想的实验时间步长!
非常感谢任何提示
答案 0 :(得分:0)
在这种情况下,错误中提到了问题。问题是您正在尝试将数组重塑为特定的形状,但是这是不可能的。数组 X 具有2499个元素,不能将2499整形为(357,2,2,1)形状。形状中数字的乘积是元素总数。 (357,2,2,1)具有357 * 2 * 2 * 1 = 1428个元素。
因此,当n_steps = 2
时,您的代码返回了一个总共包含1428个元素的数组。
我认为在您的情况下,它取决于raw_seq
,因为它的长度决定split_sequence()
内部的for循环运行的次数以及返回的数组。raw_seq
取决于数据,因此对于此数据集,n_steps
可能会受到限制。
不过,我不确定100%。