Uncaught (in promise) Error: Error when checking: expected flatten_1_input to have shape [null,7,7,512] but got array with shape [1,224,224,3]

Date: 2019-09-04 00:07:12

Tags: python keras tensorflow.js transfer-learning tensorflowjs-converter

I am using transfer learning to classify images, reusing the code from https://towardsdatascience.com/transfer-learning-from-pre-trained-models-f2393f124751.

The model runs fine on my data (https://www.dropbox.com/s/esirpr6q1lsdsms/ricetransfer1.zip?dl=0) in my Jupyter notebook, where I reshape the images before testing the model.

However, I want to run the same model in the browser with TensorFlow.js, so I saved the model with tfjs.converters.save_keras_model:

from keras.applications import VGG16
import tensorflowjs as tfjs
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

img_width, img_height = 224, 224  # Default input size for VGG16

conv_base = VGG16(weights='imagenet',
                  include_top=False,
                  #input_shape=(224, 224, 3))
                  input_shape=(img_width, img_height, 3))

# Extract features
import os, shutil
from keras.preprocessing.image import ImageDataGenerator
import numpy as np
train_size, validation_size, test_size = 148, 27, 31

datagen = ImageDataGenerator(rescale=1./255)
batch_size = 16
train_dir = "ricetransfer1/train"
validation_dir = "ricetransfer1/validation"
test_dir="ricetransfer1/test"
#indices = np.random.choice(range(len(X_train)))

def extract_features(directory, sample_count):
    # sample_count = X_train.ravel()

    # Must be equal to the output shape of the convolutional base
    features = np.zeros(shape=(sample_count, 7, 7, 512))
    labels = np.zeros(shape=(sample_count))
    # Preprocess data
    generator = datagen.flow_from_directory(directory,
                                            target_size=(img_width, img_height),
                                            batch_size=batch_size,
                                            class_mode='binary')
    # Pass data through convolutional base
    i = 0
    for inputs_batch, labels_batch in generator:
        features_batch = conv_base.predict(inputs_batch)
        features[i * batch_size: (i + 1) * batch_size] = features_batch
        labels[i * batch_size: (i + 1) * batch_size] = labels_batch
        i += 1
        if i * batch_size >= sample_count:
            break
    return features, labels

train_features, train_labels = extract_features(train_dir, train_size)  
validation_features, validation_labels = extract_features(validation_dir, validation_size)
test_features, test_labels = extract_features(test_dir, test_size)


from keras import models
from keras import layers
from keras import optimizers

epochs = 1

ricemodel = models.Sequential()
ricemodel.add(layers.Flatten(input_shape=(7, 7, 512)))
ricemodel.add(layers.Dense(256, activation='relu', input_dim=(7 * 7 * 512)))
ricemodel.add(layers.Dropout(0.5))
ricemodel.add(layers.Dense(1, activation='sigmoid'))
ricemodel.summary()

ricemodel.compile(optimizer=optimizers.Adam(),
                  loss='binary_crossentropy',
                  metrics=['acc'])

import os
history = ricemodel.fit(train_features, train_labels,
                        epochs=epochs,
                        batch_size=batch_size,
                        validation_data=(validation_features, validation_labels))

path='\vgg'
tfjs.converters.save_keras_model(ricemodel, path)

The TensorFlow.js code is:

$(document).ready(function () {
    $('.progress-bar').hide();
});
$("#image-selector").change(function(){
let reader = new FileReader();

reader.onload = function(){
    let dataURL = reader.result;
    $("#selected-image").attr("src",dataURL);
    $("#prediction-list").empty();
}
let file = $("#image-selector").prop('files')[0];
reader.readAsDataURL(file);
});


$("#model-selector").change(function(){
loadModel($("#model-selector").val());
$('.progress-bar').show();
})

let model;
async function loadModel(name){
model=await tf.loadModel(`http://localhost:8081/${name}/model.json`);
$('.progress-bar').hide();
}


$("#predict-button").click(async function(){
let image= $('#selected-image').get(0);
let tensor = preprocessImage(image,$("#model-selector").val());

let prediction = await model.predict(tensor).data();
let top5=Array.from(prediction)
            .map(function(p,i){
return {
    probability: p,
    className: IMAGENET_CLASSES[i]
};
}).sort(function(a,b){
    return b.probability-a.probability;
}).slice(0,5);

$("#prediction-list").empty();
top5.forEach(function(p){
$("#prediction- 
 list").append(`<li>${p.className}:${p.probability.toFixed(6)}</li>`);
});

});


function preprocessImage(image, modelName) {
    let tensor = tf.fromPixels(image)
        .resizeNearestNeighbor([224, 224])
        .toFloat(); // .sub(meanImageNetRGB)

    if (modelName == undefined) {
        return tensor.expandDims();
    }
    else if (modelName == "vgg") {
        let meanImageNetRGB = tf.tensor1d([123.68, 116.779, 103.939]);
        return tensor.sub(meanImageNetRGB)
            .reverse(2)
            .expandDims();
    }
    else if (modelName == "mobilenet") {
        let offset = tf.scalar(127.5);
        return tensor.sub(offset)
            .div(offset)
            .expandDims();
    }
    else {
        throw new Error("Unknown model error");
    }
}

After loading the model into TensorFlow.js, I get the following error in the browser (I can see the error message in the web developer console):

Uncaught (in promise) Error: Error when checking: expected flatten_1_input to have shape [null,7,7,512] but got array with shape [1,224,224,3]

Is there any way to solve this problem? Can I reshape the image in the browser before it is fed to the model?

I have tried every option I could think of, but I am stuck now.

I have already checked the possible solutions on Stack Overflow. How can I run the classification model in the browser?

2 Answers:

Answer 0 (score: 1):

The problem is that you are creating a brand new model, ricemodel, which consists only of the fully connected layers, and saving it as a standalone model without the convolutional base (VGG in your case) underneath. That is why your model's input layer expects shape [null,7,7,512] (the feature vectors) instead of [224,224,3] (raw image data).

To fix this, first load the VGG model with the pre-trained weights (e.g. 'imagenet') and its top layers removed, then add your ricemodel on top. Finally, save the combined model and export it to tfjs.
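
A minimal sketch of this merge, assuming the conv_base and the trained ricemodel from the question are still in memory (the output path 'vgg_full' is only an example):

from keras import models
import tensorflowjs as tfjs

# Stack the pre-trained convolutional base and the trained classifier head,
# so the combined model accepts raw images of shape (224, 224, 3).
full_model = models.Sequential()
full_model.add(conv_base)   # VGG16 with include_top=False, outputs (7, 7, 512)
full_model.add(ricemodel)   # Flatten + Dense head trained on the extracted features
full_model.summary()

# Export the combined model for TensorFlow.js (path is illustrative).
tfjs.converters.save_keras_model(full_model, 'vgg_full')

With the combined model, the browser code can keep feeding a [1,224,224,3] tensor straight to model.predict, because the VGG base now sits in front of the Flatten layer.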

Answer 1 (score: 1):

As pointed out in @andyPotato's answer, you also need to convert the feature-extractor model to js.

from tensorflow.python.keras.applications.vgg16 import VGG16
from tensorflow.python.keras.models import save_model

#download model 
model = VGG16(weights='imagenet',
              include_top=False,
              input_shape=(width, height, 3))  # tune parameters, e.g. width = height = 224 as in the question

#save the model 
save_model(
    model,
    "vgg.h5",
    overwrite=True,
)

Convert the VGG feature extractor to js:

tensorflowjs_converter --input_format keras \
                       path/to/my_model.h5 \
                       path/to/tfjs_target_dir

Now use both models for inference in js:

// Load the converted VGG feature extractor and the trained classifier head
model1 = await tf.loadModel(`/url/of/vgg/converted/model.json`);
model2 = await tf.loadModel(`/url/of/sequential/model/model.json`);

// Run the preprocessed image through the VGG base to get the [1,7,7,512] features
featureExtracted = model1.predict(image);

// Feed the extracted features to the classifier head for the final prediction
prediction = model2.predict(featureExtracted);