Unable to load a saved model in the browser with TensorFlow.js

Date: 2019-08-18 11:07:05

Tags: python-3.x deep-learning tensorflow.js transfer-learning tensorflowjs-converter

I am trying to build a rice classifier with transfer learning for an edge device, following the guide at https://github.com/ADLsourceCode/TensorflowJS

My sample data is at https://www.dropbox.com/s/esirpr6q1lsdsms/ricetransfer1.zip?dl=0

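Using the rice classification code below, I saved the model locally in TensorflowJS/Mobilenet_VGG16_Keras_To_TensorflowJS/static/ alongside the vgg and mobilenet models. However, I cannot load the rice model in the browser with TensorFlow.js.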

If I save the VGG model to my local system and load it with TensorFlow.js (in the browser), it works fine.

# Base variables
import os
base_dir = 'ricetransfer1/'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
test_dir = os.path.join(base_dir, 'test')
train_kn_dir = os.path.join(train_dir, 'KN')
train_dm_dir = os.path.join(train_dir, 'DM')

train_size, validation_size, test_size = 90, 28, 26
#train_size, validation_size, test_size = 20, 23, 14

img_width, img_height = 224, 224  # Default input size for VGG16

# Instantiate convolutional base
from keras.applications import VGG16
import tensorflowjs as tfjs
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

conv_base = VGG16(weights='imagenet',
                  include_top=False,
                  input_shape=(img_width, img_height, 3))  # 3 = number of channels in RGB pictures

# Save the VGG16 convolutional base so it can be run locally with TensorFlow.js
tfjs.converters.save_keras_model(conv_base, '/TensorflowJS/Mobilenet_VGG16_Keras_To_TensorflowJS/static/vgg')

# Check architecture
conv_base.summary()


# Extract features
import os, shutil
from keras.preprocessing.image import ImageDataGenerator
import numpy as np
train_size, validation_size, test_size = 90, 28, 25

datagen = ImageDataGenerator(rescale=1./255)
batch_size = 1
#train_dir = "ricetransfer1/train"
#validation_dir = "ricetransfer1/validation"
#test_dir="ricetransfer1/test"
#indices = np.random.choice(range(len(X_train)))

def extract_features(directory, sample_count):
    #sample_count= X_train.ravel()

    features = np.zeros(shape=(sample_count, 7, 7, 512))  # Must be equal to the output of the convolutional base
    labels = np.zeros(shape=(sample_count))
    # Preprocess data
    generator = datagen.flow_from_directory(directory,
                                            target_size=(img_width, img_height),
                                            batch_size=batch_size,
                                            class_mode='binary')
    # Pass data through convolutional base
    i = 0
    for inputs_batch, labels_batch in generator:
        features_batch = conv_base.predict(inputs_batch)
        features[i * batch_size: (i + 1) * batch_size] = features_batch
        labels[i * batch_size: (i + 1) * batch_size] = labels_batch
        i += 1
        if i * batch_size >= sample_count:
            break  # the generator loops forever, so stop once sample_count images are processed
    return features, labels

train_features, train_labels = extract_features(train_dir, train_size)  # Agree with our small dataset size
validation_features, validation_labels = extract_features(validation_dir, validation_size)
test_features, test_labels = extract_features(test_dir, test_size)



# Define model
from keras import models
from keras import layers
from keras import optimizers

epochs = 2

ricemodel = models.Sequential()
ricemodel.add(layers.Flatten(input_shape=(7,7,512)))
ricemodel.add(layers.Dense(256, activation='relu', input_dim=(7*7*512)))
ricemodel.add(layers.Dropout(0.5))
ricemodel.add(layers.Dense(1, activation='sigmoid'))
ricemodel.summary()

# Compile model
ricemodel.compile(optimizer=optimizers.Adam(),
                  loss='binary_crossentropy',
                  metrics=['acc'])


# Train model
history = ricemodel.fit(train_features, train_labels,
                epochs=epochs,
                batch_size=batch_size, 
                validation_data=(validation_features, validation_labels))


# Save the rice classification model so it can be run locally with TensorFlow.js
tfjs.converters.save_keras_model(ricemodel, '/TensorflowJS/Mobilenet_VGG16_Keras_To_TensorflowJS/static/rice/')

I think there is some error in the rice model. How can I fix this?

The expected result is rice classification running in the browser with TensorFlow.js.

1 answer:

Answer (score: 1)

I think the error may be caused by an outdated version of the tfjs library.

Update the tfjs script tag to a newer version, e.g.

<script  src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.13.5"></script> 

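in your HTML page. However, a new error may then appear because the input sizes differ: the rice model was trained on 7×7×512 VGG16 feature maps (its first layer is Flatten(input_shape=(7,7,512))), not on 224×224×3 images.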
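To check whether the shapes really differ, one can read the input shape recorded in the converted model.json. The helper below is a hypothetical sketch (find_input_shapes is not part of the tensorflowjs package). It searches the JSON for batch_input_shape (Keras 2 style) and batch_shape (Keras 3 style) keys instead of assuming one exact nesting of modelTopology, since that layout may vary between converter versions.

```python
import json
from pathlib import Path


def find_input_shapes(model_json_path):
    """Return every input shape recorded in a TensorFlow.js layers-model model.json.

    Hypothetical helper: walks the JSON recursively and collects the values of
    'batch_input_shape' (Keras 2) and 'batch_shape' (Keras 3) keys. JSON null
    (the batch dimension) comes back as Python None.
    """
    data = json.loads(Path(model_json_path).read_text(encoding="utf-8"))
    shapes = []

    def walk(node):
        if isinstance(node, dict):
            for key, value in node.items():
                if key in ("batch_input_shape", "batch_shape") and isinstance(value, list):
                    shapes.append(value)
                else:
                    walk(value)
        elif isinstance(node, list):
            for item in node:
                walk(item)

    # Only the topology describes layers; the weights manifest is skipped.
    walk(data.get("modelTopology", data))
    return shapes
```

Calling it on the rice model's model.json (in the static/rice/ folder used above) would be expected to report a shape like [None, 7, 7, 512], while the VGG16 base should report [None, 224, 224, 3]. If so, feeding the rice model a 224×224×3 image in the browser cannot work.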

If a new error does appear, I suggest opening the browser's developer tools (console) to see the exact error message.
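One way to avoid the input-size mismatch is to chain the VGG16 base and the trained rice head into a single model that accepts 224×224×3 images, and export that model instead of the head alone. The sketch below assumes tf.keras and uses a hypothetical helper name, build_end_to_end. The question imports the standalone keras package, and mixing keras and tf.keras objects will likely fail, so in that setup the same code would use keras.Input and keras.models.Model instead.

```python
import tensorflow as tf


def build_end_to_end(conv_base, head, input_shape=(224, 224, 3)):
    """Chain a convolutional feature extractor and a classifier head (hypothetical helper).

    The returned model takes images of input_shape and outputs the head's
    prediction, so it no longer expects the 7x7x512 feature maps the head
    was trained on. Input scaling (e.g. 1/255) must still be done by the caller.
    """
    inputs = tf.keras.Input(shape=input_shape)
    features = conv_base(inputs)   # e.g. (None, 7, 7, 512) for VGG16 without its top
    outputs = head(features)       # e.g. (None, 1) sigmoid probability
    return tf.keras.Model(inputs, outputs, name="rice_end_to_end")
```

With the question's variables this would be full_model = build_end_to_end(conv_base, ricemodel), followed by tfjs.converters.save_keras_model(full_model, '/TensorflowJS/Mobilenet_VGG16_Keras_To_TensorflowJS/static/rice/'). The browser code would then resize each image to 224×224 and divide the pixel values by 255, matching ImageDataGenerator(rescale=1./255) used during training.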