Keras自定义度量标准邻居类的准确性

时间:2018-04-12 15:20:01

标签: python tensorflow keras

我正在尝试在Keras模型中添加自定义指标。基本上,自定义度量不仅考虑了预测的类与真实类相同,而且当预测的类是真实类的邻居时也考虑正确的预测。但是,根据我写的代码,它给了我这个错误:

AttributeError:'int'对象没有属性'dtype'

以下是代码:

from keras.models import Sequential
from keras.layers import Conv2D
from keras.layers import MaxPooling2D
from keras.layers import Flatten
from keras.layers import Dense
from keras.models import load_model

input_shape = (32, 32, 3)
model = Sequential()
model.add(Conv2D(64, (3, 3), input_shape=input_shape, padding='same', activation='relu'))
model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Flatten())
model.add(Dense(4096, activation='relu'))
model.add(Dense(4096, activation='relu'))
model.add(Dense(8, activation='softmax'))

def one_off_accuracy(y_true, y_pred):
    return 1 if y_pred == y_true or y_pred == y_true + 1 or y_pred == y_true - 1 else 0

model.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy', one_off_accuracy])

from keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1./255)

training_set = train_datagen.flow_from_directory('dataset/training',
                                                target_size=(32, 32),
                                                batch_size=1024,
                                                class_mode='categorical')


test_set = test_datagen.flow_from_directory('dataset/test',
                                            #target_size=(224, 224),
                                            target_size=(32, 32),
                                            #batch_size=256,
                                            batch_size=1024,
                                            class_mode='categorical')

model.fit_generator(training_set,
                        steps_per_epoch= (20985/1024), 
                        epochs=25,
                        validation_data=test_set,
                        validation_steps= (5248/1024))

import numpy as np
from keras.preprocessing import image
test_image = image.load_img('dataset/1398.10693712846_42bd993ea5_o.jpg', target_size = (32, 32))
test_image = image.img_to_array(test_image)
test_image = np.expand_dims(test_image, axis = 0)
training_set.class_indices
result = model.predict_classes(test_image)

1 个答案:

答案 0 :(得分:0)

尝试返回一个keras变量,而不是返回一个整数。我认为keras期待的是:

void MainWindow::on_actionSave_triggered()  // Saves the file
{
    if (!saveFile.isEmpty())
    {
        QString wSave;
        QFile file(saveFile);
        wSave = ui->textEdit->toPlainText();
        wSave.replace("\n", "\r\n");
        wSave.replace("\r\r\n", "\r\n");
        if (!file.open(QIODevice::WriteOnly))   // This if statement makes sure that the file can be written to.
        {
            // error message

        }
        else
        {
      QTextStream stream(&file);        // Prepares the file to receive the QTextStream
      if (encodeString == "Plain Text")     //The next set of "if" statements set the encoding that the file will be saved in. This is set by the combo box in the text editor.
            {

                qDebug() << wSave;
                stream << wSave;
                stream.flush();
                file.close();
            }

            if (encodeString == "UTF-8")
            {

                stream.setCodec("UTF-8");
                stream.setGenerateByteOrderMark(true);
                stream << wSave;
                stream.flush();
                file.close();
            }

            if (encodeString == "UTF-16")
            {
                stream.setCodec("UTF-16");
                stream.setGenerateByteOrderMark(true);
                stream << wSave;
                stream.flush();
                file.close();
            }

            if (encodeString == "UTF-16BE")
            {
                stream.setCodec("UTF-16BE");
                stream.setGenerateByteOrderMark(true);
                stream << wSave;
                stream.flush();
                file.close();
            }

            if (encodeString == "UTF-32")
            {
                stream.setCodec("UTF-32");
                stream.setGenerateByteOrderMark(true);
                stream << wSave;
                stream.flush();
                file.close();
            }

            if (encodeString == "UTF-32BE")
            {
                stream.setCodec("UTF-32BE");
                stream.setGenerateByteOrderMark(true);
                stream << wSave;
                stream.flush();
                file.close();
            }


        }
    }


}