来自tf.session.run的网络输出与使用keras.Model.predict获得的输出大不相同

时间:2017-11-23 15:52:11

标签: python tensorflow keras

我正在尝试通过Tensorflow会话使用Keras模型。但结果与model.predictsess.run不同。有没有办法通过Tensorflow会话使用Kers模型?

  

Tensorflow版本:1.4.0
  Keras版本:2.1.1

from sklearn.datasets.samples_generator import make_circles
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
import numpy as np
import tensorflow as tf
from keras import backend as K

sess = tf.Session()
K.tensorflow_backend.set_session(sess)

X, y = make_circles(n_samples=1000,
                    noise=0.1,
                    factor=0.2,
                    random_state=0)

model = Sequential()
model.add(Dense(4, input_shape=(2,), activation='tanh'))
model.add(Dense(1, activation='sigmoid'))
model.compile(SGD(lr=0.5), 'binary_crossentropy', metrics=['accuracy'])
model.fit(X, y, epochs=20)

print("Keras model. First prediction: " + str(model.predict(np.c_[0, 0])))
print("Keras model. Second prediction: " + str(model.predict(np.c_[1.5, 1.5])))

with sess.as_default():

    y_tensor = model.outputs[0]
    x_tensor = model.inputs[0]
    sess.run(tf.global_variables_initializer())

    print("TF model. First prediction: " + str(sess.run(y_tensor, feed_dict={x_tensor: np.c_[0, 0]} )))
    print("TF model. Second prediction: " + str(sess.run(y_tensor, feed_dict={x_tensor: np.c_[1.5, 1.5]} )))

1 个答案:

答案 0 :(得分:6)

好的,它是K.set_session(s)而不是K.tensorflow_backend.set_session(s)

第二:sess.run(tf.global_variables_initializer())使用各自的初始化程序重置所有变量,包括网络权重(默认使用xavier初始化程序)。

所以你是:

  
      
  1. 训练keras模型
  2.   
  3. 打印keras模型的预测
  4.   
  5. 重新设定为随机权重
  6.   
  7. 打印相同型号的预测
  8.   

评论sess.run(tf.global_variables_initializer())可解决问题:

Keras model. First prediction: [[ 0.99195099]]
Keras model. Second prediction: [[ 0.03110269]]
TF model. First prediction: [[ 0.99195099]]
TF model. Second prediction: [[ 0.03110269]]