来自Keras的冷冻模型在恢复后没有预测

时间:2017-03-17 11:53:03

标签: android tensorflow keras

我使用Keras来构建和训练我的模型。模型看起来像这样:

inputs = Input(shape=(input_size, 3), dtype='float32', name='input')
lstm1 = LSTM(128, return_sequences=True)(inputs)
dropout1 = Dropout(0.5)(lstm1)
lstm2 = LSTM(128)(dropout1)
dropout2 = Dropout(0.5)(lstm2)
outputs = Dense(output_size, activation='softmax', name='output')(dropout2)

在制作检查点之前,我的模型可以很好地预测类(softmax之后的类分布):

[[ 0.00117011  0.00631532  0.10080294  0.84386677  0.04784485]]

但是在下一个代码之后:

all_saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
print save_path + '/model_predeploy.chkp'
all_saver.save(sess, save_path + '/model_predeploy.chkp', meta_graph_suffix='meta', write_meta_graph=True)
tf.train.write_graph(sess.graph_def, save_path, "model.pb", False)

使用

冻结它
bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/Users/denisermolin/Work/Projects/MotionRecognitionTraining/model/graph/model.pb --input_checkpoint=/Users/denisermolin/Work/Projects/MotionRecognitionTraining/model/graph/model_predeploy.chkp --output_graph=/Users/denisermolin/Work/Projects/MotionRecognitionTraining/model/graph/output.pb --output_node_names=Softmax --input_binary=true

之后加载它
graph = load_graph(args.frozen_model_filename)

    # We can verify that we can access the list of operations in the graph
    for op in graph.get_operations():
        print(op.name)

    # We access the input and output nodes
    x = graph.get_tensor_by_name('input:0')
    y = graph.get_tensor_by_name('Softmax:0')

    data = [7.4768066E-4,-0.02217102,0.07727051,7.4768066E-4,-0.02217102,0.07727051,7.4768066E-4,-0.02217102,0.07727051,0.004989624,-0.020874023,0.09140015,0.004989624,-0.020874023,0.09140015,0.010604858,-0.010665894,0.025527954,0.010299683,0.018035889,-0.052749634,-0.012786865,0.017837524,-0.020828247,-0.045898438,0.007095337,0.01550293,-0.06680298,0.013702393,0.02687073,-0.061767578,0.026550291,-1.373291E-4,-0.036621094,0.041778564,-0.011276245,-0.042678833,0.054336548,0.036697388,-0.07182312,0.036483765,0.081726074,-0.08639526,0.041793823,0.07392883,-0.051788326,0.07649231,0.092178345,-0.056396484,0.0771637,0.11044311,-0.08444214,0.06201172,0.0920105,-0.12609863,0.06137085,0.104537964,-0.14356995,0.079071045,0.11187744,-0.17756653,0.08576965,0.16818237,-0.2379303,0.07879639,0.19819641,-0.2631073,0.13290405,0.19137573,-0.23666382,0.21955872,0.16033936,-0.23666382,0.21955872,0.16033936,-0.22547913,0.23838806,0.27246094,-0.26376343,0.19580078,0.33566284,-0.26376343,0.19580078,0.33566284,-0.4733429,0.19911194,-0.0050811768,-0.48905945,0.14544678,-0.21205139,-0.48905945,0.14544678,-0.21205139,-0.37893677,0.15655518,-0.1382904,-0.27426147,0.16381836,-0.052841187,-0.21949767,0.18780518,-0.045913696,-0.28207397,0.17993164,-0.1550293,-0.37120056,0.13322449,-0.4617462,-0.3773346,0.17321777,-0.7678375,-0.20349121,0.12588501,-0.7908478,-4.8828125E-4,0.116516106,-0.57121277,-0.090042114,0.08895874,-0.3849945,-0.232193,-0.028884886,-0.4724579,-0.19163513,-0.06340027,-0.5598297,-0.068481445,-0.025268555,-0.54397583,-0.03288269,-0.12750244,-0.48367307,0.0057525635,-0.030532837,-0.45234683,0.099868774,-0.0070648193,-0.57225037,0.21514893,0.05860901,-0.5052185,0.3602295,0.14176941,-0.4087372,0.57940674,0.16700745,-0.35438538,0.75743103,0.2631073,-0.5294647,0.75743103,0.2631073,-0.5294647,0.74624634,0.2193451,-0.70674133,0.91960144,0.29077148,-0.7026367,0.91960144,0.29077148,-0.7026367,0.81611633,0.34953308,-0.50927734,0.8429718,0.41278076,-0.38298035,0.84576416,0.4597778,-0.15159607,0.9177856,0.47735596,0.099731445,0.9820862,0.57232666,0.20970154,0.9269562,0.5357971,0.45666504,0.7898865,0.48097226,0.5698242,0.5332794,0.4213867,0.6626892,0.5032654,0.4464111,0.59614563,0.5827484,0.4588318,0.8383636,0.60975647,0.46882626,1.050766,0.58917236,0.52201843,0.9510345,0.48217773,0.502121,0.8063202,0.24050903,0.42752075,0.81951904,0.10655212,0.43006897,0.7798157,0.15496826,0.5040283,0.7533417,0.18733215,0.55770874,0.63716125,0.22062683,0.5880585,0.503067,0.06762695,0.49337766,0.6584778,-0.14086914,0.4414215,0.615036,-0.14086914,0.4414215,0.615036,-0.03614807,0.6751251,0.06636047,-0.03614807,0.6751251,0.06636047,0.17774963,0.741272,-0.09466553,0.21842958,0.7971039,-0.050811768,0.06843567,0.7729645,-0.34933472,-0.2092285,0.5443878,-0.5428009,-0.43028256,0.37249756,-0.5168762,-0.23457338,0.3491211,-0.45985416,0.15863037,0.49960327,-0.5370636,0.31782532,0.5680084,-0.8007355,0.1651001,0.5300598,-0.87919617,-0.086135864,0.49140927,-0.6066437,-0.20877077,0.4261017,-0.55911255,-0.33840942,0.34194946,-0.7007904,-0.36250305,0.27163696,-0.76208496,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]
    data = np.reshape(data, [1, 142, 3])

    # We launch a Session
    with tf.Session(graph=graph) as sess:
        # Note: we didn't initialize/restore anything, everything is stored in the graph_def
        y_out = sess.run(y, feed_dict={
            x: data
        })
        print(y_out)

让我在所有标签上均匀分布:

[[ 0.20328824  0.19835895  0.19692752  0.20159255  0.19983278]]

我做错了吗?使用tensorflow 0.12,因为我无法在android上运行1.0。这是另一个故事。所有内容都是使用0.12构建,培训和导出的。

1 个答案:

答案 0 :(得分:6)

all_saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
print save_path + '/model_predeploy.chkp'
all_saver.save(sess, save_path + '/model_predeploy.chkp', meta_graph_suffix='meta', write_meta_graph=True)
tf.train.write_graph(sess.graph_def, save_path, "model.pb", False)

在第2行中,您从头开始重新初始化所有变量(不仅是未初始化的变量)。这意味着您的训练模型在那时消失了,您保存的模型只是随机/恒定权重(取决于您的初始化器)。

演示脚本:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

var = tf.get_variable('demo', dtype=tf.float32, shape=[], 
                      initializer=tf.zeros_initializer())

sess = tf.Session()

sess.run(tf.assign(var, 42));

print(var.eval(session=sess))

这打印42。

sess.run(tf.global_variables_initializer())

print(var.eval(session=sess))

这会打印0,因为变量已重新初始化为0。

因此,在训练模型之前初始化变量,并且在写出模型之前不要重新初始化它们。