Saving a TensorFlow model and loading it into TensorFlow.js

Date: 2018-10-08 05:20:06

Tags: python tensorflow conv-neural-network tensorflow.js

In my regular Python code I implemented a CNN. I saved it with model.save, which produced four files (a checkpoint, a .meta file, an .index file, and a .data file). However, I cannot load these four files directly into TensorFlow.js. Here is the sample CNN:

import numpy as np
import tflearn
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.estimator import regression

# IMG_SIZE, FILTER_SIZE, FIRST_NUM_CHANNEL, NUM_OUTPUT, LR, NUM_EPOCHS,
# MODEL_NAME and train_data are defined elsewhere
convnet = input_data(shape=[None, IMG_SIZE, IMG_SIZE, 1], name='input')

convnet = conv_2d(convnet, FIRST_NUM_CHANNEL, FILTER_SIZE, activation='relu')
convnet = max_pool_2d(convnet, 2)

convnet = conv_2d(convnet, FIRST_NUM_CHANNEL*2, FILTER_SIZE, activation='relu')
convnet = max_pool_2d(convnet, 2)

convnet = conv_2d(convnet, FIRST_NUM_CHANNEL*4, FILTER_SIZE, activation='relu')
convnet = max_pool_2d(convnet, 2)

convnet = fully_connected(convnet, FIRST_NUM_CHANNEL*8, activation='relu')
convnet = dropout(convnet, 0.7)

convnet = fully_connected(convnet, NUM_OUTPUT, activation='softmax')
convnet = regression(convnet, optimizer='adam', learning_rate=LR, loss='categorical_crossentropy', name='targets')

model = tflearn.DNN(convnet, tensorboard_dir='log')

train = train_data[:7000]
test = train_data[-1000:]


X = np.array([i[0] for i in train]).reshape(-1,IMG_SIZE,IMG_SIZE,1)
Y = [i[1] for i in train]

test_x = np.array([i[0] for i in test]).reshape(-1,IMG_SIZE,IMG_SIZE,1)
test_y = [i[1] for i in test]


model.fit({'input': X}, {'targets': Y}, n_epoch=NUM_EPOCHS, validation_set=({'input': test_x}, {'targets': test_y}), 
    snapshot_step=500, show_metric=True, run_id=MODEL_NAME)

model.save(MODEL_NAME)
print('MODEL SAVED:', MODEL_NAME)

The last two lines of this snippet save the model. I can load the model into a Flask application, but I want to port it to TensorFlow.js. Can anyone point me to a tutorial on how to do this?

1 Answer:

Answer 0 (score: 0)

tensorflowjs_converter outputs, among other files, the weights file weights_manifest.json and the model topology file tensorflowjs_model.pb. To load the model into TensorFlow.js, follow these steps:
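For context, the conversion itself is done with the tensorflowjs_converter CLI (installed via `pip install tensorflowjs`). A minimal sketch is below; the paths and the output node name are placeholders for your own model, and note that a TFLearn checkpoint must first be exported as a TensorFlow SavedModel before the converter can read it:

```shell
# Sketch only: convert a SavedModel to the TensorFlow.js web format.
# './saved_model', './web_model' and the node name 'targets' are
# placeholders -- adjust them to your own export.
tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='targets' \
    --saved_model_tags=serve \
    ./saved_model \
    ./web_model
```

The `./web_model` directory then contains tensorflowjs_model.pb, weights_manifest.json, and the binary weight shards referenced by the manifest.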

  • Serve the folder containing the files with a web server
# cd to the directory containing the files

# then launch the built-in Python server
python3 -m http.server

# or install and launch the npm module http-server
npm install -g http-server
http-server --cors -c1 .
  • Create a JS script to load the model
(async () => {
  ...
  const model = await tf.loadFrozenModel(
    'http://localhost:8080/tensorflowjs_model.pb',
    'http://localhost:8080/weights_manifest.json'
  );
})()

There is a difference between loadModel and loadFrozenModel.

  • loadModel is used to load a model saved locally. The model can be retrieved from the browser's IndexedDB or from localStorage. It might also be usable to retrieve a model saved by a TensorFlow API other than the JS one, but then the user would be asked to select the files with tf.io.browserFiles (I have not tried this).

  • loadFrozenModel is used to load a model served by a server.
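To make the contrast concrete, here is a sketch of both paths, assuming the 0.x TensorFlow.js API described above; the 'my-model' key and the URLs are placeholders:

```javascript
// Sketch: the two loading paths in the 0.x TensorFlow.js API.
import * as tf from '@tensorflow/tfjs';

// loadModel: a Layers model previously saved in the browser itself,
// e.g. with model.save('indexeddb://my-model'); 'localstorage://'
// works the same way.
const localModel = await tf.loadModel('indexeddb://my-model');

// loadFrozenModel: a converted graph model fetched over HTTP from
// the server started in the previous step.
const frozenModel = await tf.loadFrozenModel(
  'http://localhost:8080/tensorflowjs_model.pb',
  'http://localhost:8080/weights_manifest.json'
);
```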