每当我尝试使用4d numpy数组训练CNN模型时,就会出现上述错误。模型构建器的功能如下:
def build_model(input_shape, LR=.001, phone_count=43):
#build the network
model= keras.Sequential()
#conv layer 1
#model.add(keras.layers.Conv2D(64,(3,3), activation='relu', input_shape=input_shape, kernel_regularizer=keras.regularizers.l2(0.001), data_format="channels_first"))
model.add(keras.layers.Conv2D(64,(3,3), activation='relu', input_shape=input_shape, kernel_regularizer=keras.regularizers.l2(0.001)))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.MaxPooling2D((3,3), strides=(2,2), padding='same'))
#conv layer 2
model.add(keras.layers.Conv2D(32,(3,3), activation='relu', kernel_regularizer=keras.regularizers.l2(0.001)))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.MaxPooling2D((3,3), strides=(2,2), padding='same'))
#conv layer 3
model.add(keras.layers.Conv2D(32,(2,2), activation='relu', kernel_regularizer=keras.regularizers.l2(0.001)))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.MaxPooling2D((2,2), strides=(2,2), padding='same'))
#Flatten the output
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(64, activation='relu'))
model.add(keras.layers.Dropout(0.3))
#softmax layer
model.add(keras.layers.Dense(phone_count, activation='softmax'))
#compile the model
the_optimizer= keras.optimizers.Adamax(learning_rate=LR)
model.compile(optimizer=the_optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()
return model
此功能用于从npy文件中提取数据作为数据框(以便更轻松地删除不必要的数据/行),然后转换为numpy数组以训练和测试模型。
def prep_data(shape_num=3, train_min=50):
#read data from npy file
data1=np.load(os.path.join(rootdir+"/Train(Small Win).npy"), allow_pickle=True)
data2=np.load(os.path.join(rootdir+"/Test(Small Win).npy"), allow_pickle=True)
train= pd.DataFrame(data1, columns=['Phone', 'Signal'])
test= pd.DataFrame(data2, columns=['Phone', 'Signal'])
train['Phone']=train['Phone'].astype(int)
test['Phone']=test['Phone'].astype(int)
#shuffle data before splitting
train= sklearn.utils.shuffle(train)
train= train.reset_index(drop=True)
test= sklearn.utils.shuffle(test)
test= test.reset_index(drop=True)
#filter unnecessary data based on criteria
train, test= data_filter(train, test, shape_num, train_min)
#split data
test, validation= sklearn.model_selection.train_test_split(test, test_size=0.4, shuffle=False)
validation=validation.reset_index(drop=True)
x_train, y_train, x_test, y_test= train['Signal'], train['Phone'], test['Signal'], test['Phone']
x_validation, y_validation= validation['Signal'], validation['Phone']
x_train=x_train.to_numpy()
y_train=y_train.to_numpy()
x_test=x_test.to_numpy()
y_test=y_test.to_numpy()
x_validation=x_validation.to_numpy()
y_validation=y_validation.to_numpy()
#convert from 2D -> 3D
# x_test=x_test[..., np.newaxis]
# x_train=x_train[..., np.newaxis]
# x_validation=x_validation[..., np.newaxis]
return x_train, y_train, x_test, y_test, x_validation, y_validation
这是我用来实际训练模型的主要功能,并且出现以下错误。
def main():
x_train, y_train, x_test, y_test, x_validation, y_validation= prep_data(shape_num=12, train_min=5)
#build the model
model= build_model((12, 16, 1), phone_count=42)
#train the model
model.fit(x_train, y_train, epochs=40, batch_size=32,
validation_data=(x_validation, y_validation))
#evaluate the model
error, accuracy= model.evaluate(x_test, y_test)
print(f"Test error: {error}, Test accuracy: {accuracy}")
#save model
model.save(model.h5)
我尝试将tf.convert_to_tensor(x_train)
用于所有训练和验证数据,但返回的错误相同。下面,我验证了所有元素的shape(16,12)和type(
print(x_train.shape)
print(x_train[0].shape)
print(type(x_train[0]))
print(x_validation[0].shape)
print(type(x_validation[0]))
print(x_train.__len__())
print(y_train.__len__())
print(x_train[0])
for i in range(x_train.__len__()):
if x_train[i].shape != x_train[0].shape or type(x_train[i]) != type(x_train[0]):
print("Index ",i," is: ", x_train[i].shape)
print("type is: ", x_validation[i].shape)
for i in range(x_validation.__len__()):
if x_validation[i].shape != x_validation[0].shape or type(x_validation[i]) != type(x_validation[0]):
print("Index ",i," is: ", x_validation[i].shape)
print("type is: ", x_validation[i].shape)
结果:
(9592,)
(16, 12)
<class 'numpy.ndarray'>
(16, 12)
<class 'numpy.ndarray'>
9592
9592
[[-6.21559034e+02 -6.03861092e+02 -6.44275070e+02 -6.21108087e+02
-6.33902502e+02 -6.10202319e+02 -6.38066649e+02 -6.14831453e+02
-6.08786659e+02 -5.54751120e+02 -5.59108425e+02 -5.90645072e+02]
[ 1.62488404e+02 1.97702161e+02 1.93338819e+02 2.00768320e+02
1.96443196e+02 2.06493442e+02 1.95534267e+02 2.05817857e+02
1.97520058e+02 1.80149207e+02 1.51495948e+02 1.36431422e+02]
[-1.68000882e+01 -1.71383988e+01 -2.19247949e+01 -3.50331972e+01
-3.22168674e+01 -3.66609409e+01 -2.59634154e+01 -3.14709250e+01
-4.25033816e+01 -4.43741521e+01 -3.70234923e+01 -2.95822967e+01]
[ 4.10432947e+01 4.87035804e+01 5.52045579e+01 5.74210124e+01
5.85131273e+01 5.99729371e+01 5.79334290e+01 5.10461554e+01
4.97165940e+01 5.25213143e+01 5.17593922e+01 5.76097801e+01]
[-3.43969005e+01 -5.62415603e+01 -6.69549272e+01 -7.22399067e+01
-7.08110925e+01 -7.29823164e+01 -7.44342335e+01 -7.85082031e+01
-7.90021868e+01 -9.42318038e+01 -1.07264419e+02 -7.69748910e+01]
[-3.27785843e+01 -4.15082566e+01 -3.81622595e+01 -4.26821583e+01
-4.28947818e+01 -4.35927895e+01 -4.20522233e+01 -4.07253579e+01
-3.54880234e+01 -1.24468099e+01 7.66788474e+00 2.93352370e+01]
[ 8.13909911e+00 8.86097919e+00 4.62275236e+00 6.80030771e+00
7.73078526e+00 3.24659302e+00 4.32970141e+00 1.21840123e+00
-5.97645909e-01 -1.59230216e+01 -1.83660643e+01 -2.36423881e+01]
[ 4.72338015e+00 1.05095395e+01 1.83643214e+01 1.62532773e+01
1.63903496e+01 1.48676056e+01 1.80896153e+01 1.67417707e+01
1.57838347e+01 2.75592116e+01 3.30267593e+01 1.33186030e+01]
[-1.70412263e+01 -2.89931507e+01 -3.17646436e+01 -3.13428307e+01
-3.21056855e+01 -3.13025217e+01 -2.92779319e+01 -3.28383410e+01
-2.86348024e+01 -3.56308366e+01 -3.72640420e+01 -2.96905489e+01]
[ 1.23463742e+01 2.39805333e+00 -3.08857470e+00 -3.72663302e+00
-5.67306760e+00 -6.75569345e+00 -6.68106808e+00 -8.97175994e+00
-1.01855553e+01 -1.85199448e+01 -1.87053295e+01 -1.61857339e+00]
[ 4.64371993e+00 -6.73981799e+00 -1.21180531e+01 -6.47924891e+00
-4.09369632e+00 -7.61701294e-01 8.39204718e-01 4.33606029e+00
2.50502410e+00 5.49674950e+00 1.75774383e+01 3.36114716e+01]
[-1.34966338e+01 -2.31340623e+01 -2.69256468e+01 -2.52562186e+01
-2.49909368e+01 -3.01723141e+01 -2.65044181e+01 -2.28321862e+01
-1.76064424e+01 -1.96508533e+01 -4.32244206e+00 1.07479438e+01]
[ 2.00379681e+00 1.15979206e+00 5.48741698e+00 9.24981351e+00
1.23602788e+01 7.61237738e+00 7.21907492e+00 6.76921775e+00
1.54061820e+01 1.40574085e+01 2.71901434e+01 1.73292818e+01]
[-3.95202463e-01 -4.64221159e+00 -2.56725829e+00 -9.75469429e+00
-1.53314060e+00 -7.13854658e+00 6.17673619e-01 1.11667664e-01
5.55916296e+00 1.19341530e+01 2.47728353e+01 4.94114903e+00]
[ 1.13522109e+01 1.48745532e+01 1.94529171e+01 1.26263470e+01
1.71353642e+01 1.04325893e+01 1.76135506e+01 1.50917866e+01
2.00581224e+01 1.27428903e+01 1.49328414e+01 -4.96430014e+00]
[ 1.08379527e+01 1.16791172e+01 9.86733739e+00 8.62182114e+00
1.09428703e+01 1.16916648e+01 1.07482436e+01 9.04880045e+00
8.05048866e+00 -1.34804509e+00 -5.49922438e+00 -5.12267269e+00]]
如果您知道我为什么会出现此错误,请告诉我
(对于那些对数据使用感兴趣的人,从用于分类的语音数据中提取其MFCCs(16))