我想在Tensorflow中从Keras训练pb和检查点。
我已经成功地将Keras模型转换为Tensorflow pb和检查点。 而且我已经成功推断了。 但是问题是,我不知道该怎么做。 这种Keras模型似乎没有培训内容,或者我只是不知道在培训中应该输入什么信息。
此代码将Keras模型转换为Tensorflow pb和检查点。
from keras import backend as K
from keras.models import load_model
import tensorflow as tf
model = load_model('model/my_model.h5')
K.set_learning_phase(0) #0 : test, 1 : train
sess = K.get_session()
saver = tf.train.Saver()
saver.save(sess, 'keras/keras.ckpt')
sess.graph.as_default()
graph = sess.graph
with open('keras/keras.pb', 'wb') as f:
f.write(graph.as_graph_def().SerializeToString())
这是读取pb和检查点的代码
def keras_model():
sess = tf.Session()
saver = tf.train.import_meta_graph('keras/keras.ckpt.meta')
saver.restore(sess, "keras/keras.ckpt")
sess.graph.as_default()
graph = tf.get_default_graph()
a = [x for x in tf.get_default_graph().get_operations() if x.type == "Placeholder"]
#print(a)
img = cv2.imread("data/wqds_backbead_0_3.png", cv2.IMREAD_COLOR)
img = img[...,::-1] # bgr to rgb
img = img.astype('float32')
img = np.expand_dims(img, axis=0)
INPUT1 = graph.get_tensor_by_name("input_1:0")
OUTPUT1 = graph.get_tensor_by_name("softmax/Softmax:0")
TARGET1 = graph.get_tensor_by_name("softmax_target:0")
print(TARGET1)
pred = sess.run(OUTPUT1, feed_dict={INPUT1: img})
print(pred, pred.shape, pred.dtype)