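I have a model that I trained using the tensorflow.contrib Python API on TensorFlow 1.3 (Keras 2.0.6-tf). It works like a charm.
But when I load the model in a TensorFlow 1.4 (or later) environment, the predictions are constant, i.e. incorrect. There is no error message.
This is what I do: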
from tensorflow.contrib.keras.api.keras.models import load_model
model = load_model(..)
predictions = model.predict(input, batch_size=batch_size)
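To confirm that the output really is constant across different inputs (rather than merely inaccurate), a check along these lines could be used. This is only a sketch: `predictions_are_constant` and the tolerance `tol` are hypothetical names that are not part of Keras, and the arrays are synthetic stand-ins for the result of `model.predict(...)`.

```python
import numpy as np


def predictions_are_constant(predictions, tol=1e-6):
    """Return True if every row of `predictions` equals the first row within `tol`."""
    predictions = np.asarray(predictions, dtype=float)
    # Largest absolute deviation of any row from the first row.
    max_deviation = np.max(np.abs(predictions - predictions[0]))
    return bool(max_deviation <= tol)


# Synthetic stand-ins for the output of model.predict(...)
constant = np.tile([0.1, 0.2, 0.7], (4, 1))
varying = np.array([[0.1, 0.9], [0.8, 0.2]])

print(predictions_are_constant(constant))  # True
print(predictions_are_constant(varying))   # False
```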
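Loading the model and the weights separately, instead of just loading the model's .h5 file, makes no difference.
Is this a known issue? If so, is there a workaround?
Thanks for your help.
Here is the model's h5 file. In case it helps solve the puzzle, this is the model summary: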
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
input_3 (InputLayer)             (None, 40, 256, 1)    0
____________________________________________________________________________________________________
BN0 (BatchNormalization)         (None, 40, 256, 1)    4           input_3[0][0]
____________________________________________________________________________________________________
Conv1 (Conv2D)                   (None, 40, 256, 16)   96          BN0[0][0]
____________________________________________________________________________________________________
BN1 (BatchNormalization)         (None, 40, 256, 16)   64          Conv1[0][0]
____________________________________________________________________________________________________
Conv2 (Conv2D)                   (None, 40, 256, 16)   1296        BN1[0][0]
____________________________________________________________________________________________________
BN2 (BatchNormalization)         (None, 40, 256, 16)   64          Conv2[0][0]
____________________________________________________________________________________________________
Conv3 (Conv2D)                   (None, 40, 256, 16)   1296        BN2[0][0]
____________________________________________________________________________________________________
average_pooling2d_9 (AveragePool (None, 8, 256, 16)    0           Conv3[0][0]
____________________________________________________________________________________________________
BN3 (BatchNormalization)         (None, 8, 256, 16)    64          average_pooling2d_9[0][0]
____________________________________________________________________________________________________
Conv4.1 (Conv2D)                 (None, 8, 256, 24)    12312       BN3[0][0]
____________________________________________________________________________________________________
Conv4.2 (Conv2D)                 (None, 8, 256, 24)    24600       BN3[0][0]
____________________________________________________________________________________________________
Conv4.3 (Conv2D)                 (None, 8, 256, 24)    36888       BN3[0][0]
____________________________________________________________________________________________________
Conv4.4 (Conv2D)                 (None, 8, 256, 24)    49176       BN3[0][0]
____________________________________________________________________________________________________
Conv4.5 (Conv2D)                 (None, 8, 256, 24)    73752       BN3[0][0]
____________________________________________________________________________________________________
Conv4.6 (Conv2D)                 (None, 8, 256, 24)    98328       BN3[0][0]
____________________________________________________________________________________________________
Concat.Conv4 (Concatenate)       (None, 8, 256, 144)   0           Conv4.1[0][0]
                                                                   Conv4.2[0][0]
                                                                   Conv4.3[0][0]
                                                                   Conv4.4[0][0]
                                                                   Conv4.5[0][0]
                                                                   Conv4.6[0][0]
____________________________________________________________________________________________________
Conv4.1x1 (Conv2D)               (None, 8, 256, 36)    5220        Concat.Conv4[0][0]
____________________________________________________________________________________________________
average_pooling2d_10 (AveragePoo (None, 4, 256, 36)    0           Conv4.1x1[0][0]
____________________________________________________________________________________________________
BN4 (BatchNormalization)         (None, 4, 256, 36)    144         average_pooling2d_10[0][0]
____________________________________________________________________________________________________
Conv5.1 (Conv2D)                 (None, 4, 256, 24)    27672       BN4[0][0]
____________________________________________________________________________________________________
Conv5.2 (Conv2D)                 (None, 4, 256, 24)    55320       BN4[0][0]
____________________________________________________________________________________________________
Conv5.3 (Conv2D)                 (None, 4, 256, 24)    82968       BN4[0][0]
____________________________________________________________________________________________________
Conv5.4 (Conv2D)                 (None, 4, 256, 24)    110616      BN4[0][0]
____________________________________________________________________________________________________
Conv5.5 (Conv2D)                 (None, 4, 256, 24)    165912      BN4[0][0]
____________________________________________________________________________________________________
Conv5.6 (Conv2D)                 (None, 4, 256, 24)    221208      BN4[0][0]
____________________________________________________________________________________________________
Concat.Conv5 (Concatenate)       (None, 4, 256, 144)   0           Conv5.1[0][0]
                                                                   Conv5.2[0][0]
                                                                   Conv5.3[0][0]
                                                                   Conv5.4[0][0]
                                                                   Conv5.5[0][0]
                                                                   Conv5.6[0][0]
____________________________________________________________________________________________________
Conv5.1x1 (Conv2D)               (None, 4, 256, 36)    5220        Concat.Conv5[0][0]
____________________________________________________________________________________________________
average_pooling2d_11 (AveragePoo (None, 2, 256, 36)    0           Conv5.1x1[0][0]
____________________________________________________________________________________________________
BN5 (BatchNormalization)         (None, 2, 256, 36)    144         average_pooling2d_11[0][0]
____________________________________________________________________________________________________
Conv6.1 (Conv2D)                 (None, 2, 256, 24)    27672       BN5[0][0]
____________________________________________________________________________________________________
Conv6.2 (Conv2D)                 (None, 2, 256, 24)    55320       BN5[0][0]
____________________________________________________________________________________________________
Conv6.3 (Conv2D)                 (None, 2, 256, 24)    82968       BN5[0][0]
____________________________________________________________________________________________________
Conv6.4 (Conv2D)                 (None, 2, 256, 24)    110616      BN5[0][0]
____________________________________________________________________________________________________
Conv6.5 (Conv2D)                 (None, 2, 256, 24)    165912      BN5[0][0]
____________________________________________________________________________________________________
Conv6.6 (Conv2D)                 (None, 2, 256, 24)    221208      BN5[0][0]
____________________________________________________________________________________________________
Concat.Conv6 (Concatenate)       (None, 2, 256, 144)   0           Conv6.1[0][0]
                                                                   Conv6.2[0][0]
                                                                   Conv6.3[0][0]
                                                                   Conv6.4[0][0]
                                                                   Conv6.5[0][0]
                                                                   Conv6.6[0][0]
____________________________________________________________________________________________________
Conv6.1x1 (Conv2D)               (None, 2, 256, 36)    5220        Concat.Conv6[0][0]
____________________________________________________________________________________________________
average_pooling2d_12 (AveragePoo (None, 1, 256, 36)    0           Conv6.1x1[0][0]
____________________________________________________________________________________________________
BN6 (BatchNormalization)         (None, 1, 256, 36)    144         average_pooling2d_12[0][0]
____________________________________________________________________________________________________
Conv7.1 (Conv2D)                 (None, 1, 256, 24)    27672       BN6[0][0]
____________________________________________________________________________________________________
Conv7.2 (Conv2D)                 (None, 1, 256, 24)    55320       BN6[0][0]
____________________________________________________________________________________________________
Conv7.3 (Conv2D)                 (None, 1, 256, 24)    82968       BN6[0][0]
____________________________________________________________________________________________________
Conv7.4 (Conv2D)                 (None, 1, 256, 24)    110616      BN6[0][0]
____________________________________________________________________________________________________
Conv7.5 (Conv2D)                 (None, 1, 256, 24)    165912      BN6[0][0]
____________________________________________________________________________________________________
Conv7.6 (Conv2D)                 (None, 1, 256, 24)    221208      BN6[0][0]
____________________________________________________________________________________________________
Concat.Conv7 (Concatenate)       (None, 1, 256, 144)   0           Conv7.1[0][0]
                                                                   Conv7.2[0][0]
                                                                   Conv7.3[0][0]
                                                                   Conv7.4[0][0]
                                                                   Conv7.5[0][0]
                                                                   Conv7.6[0][0]
____________________________________________________________________________________________________
Conv7.1x1 (Conv2D)               (None, 1, 256, 36)    5220        Concat.Conv7[0][0]
____________________________________________________________________________________________________
BN7 (BatchNormalization)         (None, 1, 256, 36)    144         Conv7.1x1[0][0]
____________________________________________________________________________________________________
flatten_3 (Flatten)              (None, 9216)          0           BN7[0][0]
____________________________________________________________________________________________________
dropout_3 (Dropout)              (None, 9216)          0           flatten_3[0][0]
____________________________________________________________________________________________________
dense_7 (Dense)                  (None, 64)            589888      dropout_3[0][0]
____________________________________________________________________________________________________
batch_normalization_5 (BatchNorm (None, 64)            256         dense_7[0][0]
____________________________________________________________________________________________________
dense_8 (Dense)                  (None, 64)            4160        batch_normalization_5[0][0]
____________________________________________________________________________________________________
batch_normalization_6 (BatchNorm (None, 64)            256         dense_8[0][0]
____________________________________________________________________________________________________
dense_9 (Dense)                  (None, 256)           16640       batch_normalization_6[0][0]
====================================================================================================
Total params: 2,921,684
Trainable params: 2,921,042
Non-trainable params: 642
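As a sanity check on the summary, the 642 non-trainable parameters can be accounted for entirely by the BatchNormalization layers: each one stores 4 values per channel (gamma, beta, moving mean, moving variance), and the two moving statistics are not trainable. The sketch below only redoes that arithmetic from the channel counts in the table. It does not show by itself that these statistics are what gets lost on loading.

```python
# Channel counts of the BatchNormalization layers, read off the summary above.
bn_channels = {
    "BN0": 1,
    "BN1": 16, "BN2": 16, "BN3": 16,
    "BN4": 36, "BN5": 36, "BN6": 36, "BN7": 36,
    "batch_normalization_5": 64, "batch_normalization_6": 64,
}

# Each layer holds 4 values per channel: gamma and beta (trainable),
# moving mean and moving variance (non-trainable).
total_bn_params = sum(4 * c for c in bn_channels.values())
non_trainable = sum(2 * c for c in bn_channels.values())

print(total_bn_params)           # 1284
print(non_trainable)             # 642
print(2921684 - non_trainable)   # 2921042, the reported trainable count
```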
Answer 0 (score: 0)
This is how I finally got my Keras / TF 1.3 models to work with Keras / TF > 1.3:
# Step 1: run this in the original Keras / TF 1.3 environment.
import tensorflow as tf
from tensorflow.contrib.keras.python.keras import backend
from tensorflow.contrib.keras.python.keras.models import load_model

name = 'my_model_name'
model = load_model('{}.h5'.format(name))

# save state using TensorFlow
saver = tf.train.Saver()
saver.save(backend.get_session(), '{}_weights.tf'.format(name))
backend.clear_session()

# Step 2: run this in the newer Keras / TF > 1.3 environment.
import tensorflow as tf
from tensorflow.python.keras import backend  # <- different import!
from tensorflow.contrib.keras.api.keras.models import load_model

name = 'my_model_name'

# first load model architecture
model = load_model('{}.h5'.format(name))

# then load correct state using TensorFlow
all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
sess = backend.get_session()
sess.run(tf.variables_initializer(all_variables))

# create a list of variables that does not include the state of
# the Adam optimizer (it is missing in the .h5 file).
# however, I believe THAT WAS NOT THE ISSUE.
var_list = [v for v in all_variables if "Adam" not in v.name]
saver = tf.train.Saver(var_list=var_list)
saver.restore(sess, '{}_weights.tf'.format(name))

# now save the whole model again using Keras (this time the correct way)
model.save('{}_new.h5'.format(name))
The workaround is based on this post. Apparently, the problem lies in how Keras (not TensorFlow) restores the state of a saved model.
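To check that the re-saved model behaves like the original, one option is to compare predictions made on the same inputs in both environments. The following is a minimal sketch under stated assumptions: `max_abs_difference` is a hypothetical helper, the threshold `1e-5` is an arbitrary choice, and the arrays are synthetic stand-ins for `model.predict(...)` output.

```python
import numpy as np


def max_abs_difference(reference, candidate):
    """Largest element-wise absolute difference between two prediction arrays."""
    reference = np.asarray(reference, dtype=float)
    candidate = np.asarray(candidate, dtype=float)
    if reference.shape != candidate.shape:
        raise ValueError(
            "shape mismatch: {} vs {}".format(reference.shape, candidate.shape))
    return float(np.max(np.abs(reference - candidate)))


# Synthetic stand-ins; in practice these would be model.predict(...) outputs
# from the old and the new environment on the same inputs.
old_env = np.array([[0.10, 0.90], [0.75, 0.25]])
new_env = old_env + 1e-7

print(max_abs_difference(old_env, new_env) < 1e-5)  # True
```

In practice the reference predictions could be written with `np.save` in the TF 1.3 environment and read back with `np.load` in the newer one. Any file name used for that is up to you.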