My tflite model gives me wrong predictions

Asked: 2019-12-28 17:17:49

Tags: python android-studio tensorflow machine-learning keras

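I converted my Keras model (.h5) to a .tflite model, but after putting it into an Android app it does not give correct predictions when I feed it images. The original Keras model has 95% accuracy.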

My model

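# Imports and model construction were not shown in the original post;
# the tensorflow.keras module paths below are an assumption.
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Conv2D, BatchNormalization, MaxPooling2D,
                                     Flatten, Dropout, Dense)
from tensorflow.keras.preprocessing.image import ImageDataGenerator

model = Sequential()
model.add(Conv2D(128, (3, 3), input_shape=(64, 64, 3), activation='relu'))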
model.add(BatchNormalization())
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=2))
model.add(BatchNormalization())
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=2))
model.add(BatchNormalization())
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=2))
model.add(BatchNormalization())
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=2))
model.add(Flatten())
model.add(Dropout(.50))
model.add(Dense(500, activation='relu'))
model.add(Dropout(.50))
model.add(Dense(100, activation='relu'))
model.add(Dropout(.50))
model.add(Dense(9, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

batch_size = 60
pic_size = 64

train_datagen = ImageDataGenerator(rescale=1. / 255)

test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
        '/DATASET/Training_Samples',
        target_size=(64, 64),
        color_mode='rgb',
        batch_size=batch_size,
        class_mode="categorical",
        shuffle=True)

validation_generator = test_datagen.flow_from_directory(
        '/DATASET/Test_Samples',
        target_size=(64, 64),
        color_mode='rgb',
        batch_size=batch_size,
        class_mode="categorical",
        shuffle=False)



history = model.fit_generator(generator=train_generator,
                            steps_per_epoch=train_generator.n//train_generator.batch_size,
                            epochs=150,
                            validation_data=validation_generator,
                            validation_steps = validation_generator.n//validation_generator.batch_size)

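# Assumes the trained model was saved beforehand, e.g. with model.save("model.h5").
converter = tf.lite.TFLiteConverter.from_keras_model_file("model.h5")
tflite_model = converter.convert()
# Use a context manager so the file is flushed and closed.
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)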
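One thing worth checking on my side is the label order. As far as I understand, flow_from_directory assigns class indices to the subdirectory names in sorted order, so the labels file bundled with the app has to list the 9 classes in that same order. The sketch below rebuilds that order from the training directory. The helper names class_labels_from_directory and write_labels_file are hypothetical, and train_generator.class_indices remains the authoritative mapping.

```python
import os


def class_labels_from_directory(root):
    """Return the subdirectory names of `root` in sorted order.

    This mirrors the order in which flow_from_directory is understood to
    assign class indices (index 0 = first name in sorted order).
    Plain files inside `root` are ignored.
    """
    return sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )


def write_labels_file(labels, path):
    """Write one label per line, e.g. for a labels.txt shipped with the app."""
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(labels) + "\n")
```

For example, write_labels_file(class_labels_from_directory('/DATASET/Training_Samples'), 'labels.txt') would produce a file whose line order matches the model's output indices, under the sorting assumption above. Note that the sort is by string, so a folder named "10" comes before "2".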

I followed this tutorial when building the app.
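The app code is not shown here, so the following is only a guess at where a mismatch could come from. Training used rescale=1. / 255, color_mode='rgb' and 64x64 inputs, so the app would need to feed the model float values in [0, 1], in R, G, B order, for a 64x64 image. Below is a minimal sketch of that conversion, written in Python to match the rest of the post; the function names are hypothetical, and the Android code would do the equivalent on the packed ARGB ints that Bitmap.getPixels returns.

```python
def argb_to_normalized_rgb(pixel):
    """Unpack one packed ARGB int into (r, g, b) floats in [0, 1].

    Works for both unsigned values and the signed 32-bit ints that
    Android uses, because masking with 0xFF keeps only the low byte.
    The alpha channel is dropped.
    """
    r = (pixel >> 16) & 0xFF
    g = (pixel >> 8) & 0xFF
    b = pixel & 0xFF
    # Same scaling as ImageDataGenerator(rescale=1. / 255) during training.
    return (r / 255.0, g / 255.0, b / 255.0)


def preprocess(pixels):
    """Flatten packed ARGB pixels into an R, G, B float sequence.

    `pixels` is assumed to hold the 64 * 64 pixels of an image that was
    already resized to the model's input size, in row-major order.
    """
    values = []
    for p in pixels:
        values.extend(argb_to_normalized_rgb(p))
    return values
```

For a 64x64 image, preprocess returns 64 * 64 * 3 = 12288 floats, which matches the (64, 64, 3) input shape of the model. If the app instead passes raw 0 to 255 values, or a different channel order, the predictions could differ from what Keras produced.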

0 answers:

No answers yet