About a CNN-LSTM

Date: 2019-03-28 16:54:08

Tags: python neural-network deep-learning lstm recurrent-neural-network

Below is my CNN-LSTM architecture.

import numpy as np
from sklearn.model_selection import train_test_split
from keras.models import Sequential
from keras.layers import (LSTM, Activation, BatchNormalization, Conv2D,
                          Flatten, MaxPooling2D, Reshape, TimeDistributed,
                          UpSampling2D)

model = Sequential()
# Encoder: sequences of 10 grayscale 128x128 frames
model.add(TimeDistributed(Conv2D(64, (2, 2), padding='same'),
                          input_shape=(10, 128, 128, 1)))
model.add(BatchNormalization())
model.add(Activation("relu"))

model.add(TimeDistributed(MaxPooling2D(pool_size=(2, 2))))
model.add(TimeDistributed(Conv2D(32, (2, 2), padding='same')))
model.add(BatchNormalization())
model.add(Activation("relu"))

model.add(TimeDistributed(MaxPooling2D(pool_size=(2, 2))))
model.add(TimeDistributed(Conv2D(16, (2, 2), padding='same')))
model.add(BatchNormalization())
model.add(Activation("relu"))
model.add(TimeDistributed(MaxPooling2D(pool_size=(2, 2))))

# Bottleneck: flatten each frame and compress the sequence with an LSTM
model.add(TimeDistributed(Flatten()))
model.add(LSTM(units=64, return_sequences=True))


# Decoder: reshape each 64-unit timestep to 8x8x1 and upsample back to 128x128
model.add(TimeDistributed(Reshape((8, 8, 1))))
model.add(TimeDistributed(UpSampling2D((2,2))))
model.add(TimeDistributed(Conv2D(16, (2, 2), padding='same')))
model.add(BatchNormalization())
model.add(Activation("relu"))
model.add(TimeDistributed(UpSampling2D((2,2))))
model.add(TimeDistributed(Conv2D(32, (2, 2), padding='same')))
model.add(BatchNormalization())
model.add(Activation("relu"))
model.add(TimeDistributed(UpSampling2D((2,2))))
model.add(TimeDistributed(Conv2D(64, (2, 2), padding='same')))
model.add(BatchNormalization())
model.add(Activation("relu"))
model.add(TimeDistributed(UpSampling2D((2,2))))
model.add(TimeDistributed(Conv2D(1, (2, 2), padding='same')))

model.compile(optimizer='rmsprop', loss='mse',
              metrics=['mean_absolute_error', 'mean_absolute_percentage_error',
                       'mean_squared_error', 'accuracy'])
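As a quick sanity check on the architecture (this walk-through is mine, not from the original post), the three pooling stages shrink each 128x128 frame to 16x16x16, the LSTM compresses every timestep to 64 values, which is exactly what Reshape((8, 8, 1)) expects, and four upsampling stages restore 128x128:

```python
# Hand-derived tensor shapes for the encoder/decoder above.
# 'same' padding preserves spatial size; only pooling/upsampling changes it.
frames, h, w = 10, 128, 128

# Encoder: three MaxPooling2D((2, 2)) stages halve the frame each time
for _ in range(3):
    h, w = h // 2, w // 2
print("encoder output:", (frames, h, w, 16))  # (10, 16, 16, 16)

flat = h * w * 16                  # TimeDistributed(Flatten()) -> 4096 per frame
lstm_units = 64                    # LSTM(64, return_sequences=True)
assert lstm_units == 8 * 8 * 1     # so Reshape((8, 8, 1)) is shape-compatible

# Decoder: four UpSampling2D((2, 2)) stages double 8x8 back to 128x128
h, w = 8, 8
for _ in range(4):
    h, w = h * 2, w * 2
print("decoder output:", (frames, h, w, 1))   # (10, 128, 128, 1)
```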

data = np.load(r"/content/boxing_d1.npy")
print(data.shape)
# Default 75/25 split: 15 training and 5 test sequences
x_train, x_test = train_test_split(data)

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

print (x_train.shape)
print (x_test.shape)
history = model.fit(x_train, x_train,
                epochs=100,
                batch_size=1,
                shuffle=False,
                validation_data=(x_test, x_test))
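For reference on what the logged numbers below measure, the tracked regression metrics can be computed by hand; the toy arrays here are illustrative values of mine, not data from the post:

```python
# Hand-computed MAE, MSE, and MAPE on toy predictions, matching the
# metric definitions Keras logs during training.
y_true = [0.2, 0.5, 0.8, 0.4]
y_pred = [0.26, 0.45, 0.7, 0.5]

n = len(y_true)
mae = sum(abs(t - p) for t, p in zip(y_true, y_pred)) / n
mse = sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n
mape = 100 * sum(abs((t - p) / t) for t, p in zip(y_true, y_pred)) / n

print(mae)   # ≈ 0.0775
print(mse)   # ≈ 0.006525
print(mape)  # ≈ 19.375
```

Note that "accuracy" is not a meaningful metric for a continuous reconstruction target, which is one reason the logged acc values stay near zero.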

# Note: predicting with the full autoencoder yields reconstructions, not the
# compressed codes; extracting embeddings needs a separate encoder sub-model.
decoded_imgs = model.predict(x_test)

I am trying to use the CNN-LSTM to extract a compressed representation of videos, which I then want to cluster with k-means for classification. The output after one set of training is shown below. Why do the metrics fluctuate the way the output shows? I also trained a similar model where further training increased both the training and validation loss. I can post a link if anyone wants to take a look.
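The downstream step would run k-means on those compressed representations. As a minimal pure-Python sketch of that idea (the toy 2-D "embeddings", k=2, and the naive initialization are all illustrative assumptions; in practice one would feed real encoder outputs to an off-the-shelf k-means such as scikit-learn's):

```python
# Lloyd's algorithm on toy 2-D "embeddings" standing in for the
# compressed video representations.
def kmeans(points, k, iters=20):
    centroids = list(points[:k])               # naive init: first k points
    for _ in range(iters):
        clusters = [[] for _ in range(k)]
        for p in points:
            # assign each point to its nearest centroid (squared distance)
            nearest = min(range(k),
                          key=lambda c: sum((a - b) ** 2
                                            for a, b in zip(p, centroids[c])))
            clusters[nearest].append(p)
        # recompute each centroid as the mean of its cluster
        centroids = [tuple(sum(dim) / len(cl) for dim in zip(*cl))
                     if cl else centroids[i]
                     for i, cl in enumerate(clusters)]
    return centroids, clusters

# Two well-separated toy groups of "videos"
emb = [(0.1, 0.2), (0.0, 0.1), (0.2, 0.0),
       (5.0, 5.1), (5.2, 4.9), (4.8, 5.0)]
centroids, clusters = kmeans(emb, k=2)
print([len(c) for c in clusters])  # [3, 3]: each group forms one cluster
```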

[Plots: model accuracy and model loss after the first training set; model accuracy and model loss after the second training set]

Given below are the last 20 epochs:


(20, 10, 128, 128, 1)
(15, 10, 128, 128, 1)
(5, 10, 128, 128, 1)


Each epoch ran 15/15 steps in ~14s (~910-930ms/step).

Epoch    loss    MAE     MAPE     MSE     acc         val_loss  val_MAE  val_MAPE  val_MSE  val_acc
80/100   0.0288  0.1253  24.7657  0.0288  1.2370e-04  0.0869    0.2499   36.1645   0.0869   9.9976e-04
81/100   0.0283  0.1237  24.5083  0.0283  1.3346e-04  0.0859    0.2490   36.0187   0.0859   9.5459e-04
82/100   0.0284  0.1239  24.5969  0.0284  1.2207e-04  0.0841    0.2464   35.9061   0.0841   9.5093e-04
83/100   0.0283  0.1247  24.8455  0.0283  1.3468e-04  0.0827    0.2449   35.9207   0.0827   0.0010
84/100   0.0281  0.1243  24.9100  0.0281  1.3102e-04  0.0854    0.2480   36.2074   0.0854   9.4849e-04
85/100   0.0282  0.1240  24.9614  0.0282  1.1556e-04  0.0852    0.2479   36.0072   0.0852   7.9712e-04
86/100   0.0279  0.1228  24.6722  0.0279  1.1230e-04  0.0847    0.2476   36.1600   0.0847   8.1421e-04
87/100   0.0273  0.1209  24.2113  0.0273  1.2614e-04  0.0818    0.2442   36.0303   0.0818   9.2407e-04
88/100   0.0278  0.1221  24.5280  0.0278  1.1678e-04  0.0823    0.2449   36.0732   0.0823   9.6069e-04
89/100   0.0273  0.1219  24.4659  0.0273  1.1393e-04  0.0815    0.2442   36.1811   0.0815   0.0010
90/100   0.0273  0.1230  24.8766  0.0273  1.1922e-04  0.0783    0.2404   35.8445   0.0783   0.0010
91/100   0.0271  0.1204  24.3801  0.0271  1.3062e-04  0.0814    0.2436   35.6696   0.0814   0.0010
92/100   0.0270  0.1201  24.3027  0.0270  1.3224e-04  0.0813    0.2435   35.6484   0.0813   0.0010
93/100   0.0266  0.1200  24.2725  0.0266  1.2614e-04  0.0814    0.2436   35.6852   0.0814   0.0010
94/100   0.0271  0.1203  24.7179  0.0271  1.1800e-04  0.0837    0.2459   36.1841   0.0837   0.0010
95/100   0.0278  0.1216  24.8603  0.0278  1.2207e-04  0.0823    0.2444   35.7832   0.0823   0.0010
96/100   0.0274  0.1213  24.7304  0.0274  1.2451e-04  0.0816    0.2441   36.0207   0.0816   0.0010
97/100   0.0270  0.1193  24.3427  0.0270  1.3021e-04  0.0821    0.2444   35.8657   0.0821   0.0010
98/100   0.0268  0.1201  24.4321  0.0268  1.1556e-04  0.0813    0.2440   36.0604   0.0813   0.0010
99/100   0.0264  0.1194  24.4055  0.0264  1.1719e-04  0.0817    0.2443   35.9919   0.0817   0.0010
100/100  0.0262  0.1189  24.3198  0.0262  1.1800e-04  0.0816    0.2441   35.9696   0.0816   0.0010

Is there a problem with the architecture? Should I increase the number of neurons per layer, or the number of layers (I have tried that, but it is slow even on Colab and does not converge easily)? Sorry for the long output, and thank you for any reply.

0 Answers:

No answers yet.