我正在构建一个 TensorFlow 模型，其中需要加载另一个用 pickle 保存的模型，并把它作为 TensorFlow 模型的一部分。代码分为两部分：创建并保存模型（保存部分），以及加载模型进行预测（加载部分）。加载时出现错误：ValueError：找不到回调 pyfunc_0。
导出的 .pb 文件非常小，看起来 .pickle 文件中的模型并没有被存入 .pb 文件。我不确定该如何解决。
保存部分
import tensorflow as tf
from keras import backend as K
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants, signature_constants, signature_def_utils_impl
from keras.callbacks import TensorBoard
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.optimizers import SGD
import numpy as np
import pickle
# Export-directory name (the SavedModel is written under './' + model_version)
# and training settings. NOTE(review): `epoch` and `tensorboard` are not used
# by the visible code.
model_version = "465555564"
epoch = 100

# Bind Keras to an explicit TF session and set the learning phase to 0
# (test/inference mode).
sess = tf.Session()
K.set_session(sess)
K.set_learning_phase(0)

# TensorBoard callback: write the graph, no histograms, no images.
tensorboard = TensorBoard(
    log_dir='./logs',
    histogram_freq=0,
    write_graph=True,
    write_images=False,
)
# Unpickled models keyed by file path, so each pickle file is read once per
# process instead of on every tf.py_func invocation.
_cf_model_cache = {}


def _load_cf_model(path):
    """Return the unpickled object stored at *path*, loading it at most once."""
    if path not in _cf_model_cache:
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(path, "rb") as f:
            _cf_model_cache[path] = pickle.load(f)
    return _cf_model_cache[path]


def my_func(x, path=None):
    """Run the pickled model on one input and return element [1] as float32.

    This is the callable wrapped by ``tf.py_func``, which invokes it with the
    evaluated input tensor.

    Args:
        x: Value passed to the model as the one-element batch ``[x]``.
        path: Pickle file holding the model. Defaults to the global
            ``PATH_TO_PICKLE`` (looked up at call time; it is not defined in
            the visible script and must be set before this is called).

    Returns:
        ``np.float32(model.predict([x])[1])``.
    """
    if path is None:
        path = PATH_TO_PICKLE
    model = _load_cf_model(path)
    # NOTE(review): takes element 1 of predict()'s result for a one-element
    # batch; confirm this matches the pickled model's predict() API.
    return np.float32(model.predict([x])[1])
# Placeholder for the value handed to the pickled model. Named `input_ph`
# rather than `input` so it does not shadow the Python builtin.
input_ph = tf.placeholder(tf.float32)

# WARNING: per the tf.py_func docs, the body of `my_func` is NOT serialized
# into the GraphDef/SavedModel -- the graph only references a callback token
# (e.g. "pyfunc_0") registered in this Python process. Running the exported
# model from another process therefore fails to find the callback unless the
# same py_func is registered there again before running (TODO: verify token
# ordering), or the pickled model is reimplemented with TF ops.
y = tf.py_func(my_func, [input_ph], tf.float32)

prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def(
    {"inputs": input_ph},
    {"prediction": y},
)

builder = saved_model_builder.SavedModelBuilder('./' + model_version)
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder.add_meta_graph_and_variables(
    sess,
    [tag_constants.SERVING],
    signature_def_map={
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature,
    },
    legacy_init_op=legacy_init_op,
)
builder.save()
加载部分
# Serving signature and the input/output keys used when the model was exported.
signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_key = 'inputs'
output_key = 'prediction'
# Must match the './' + model_version directory written by the save script.
export_path = './465555564/'

# Single feature vector fed to the signature's input placeholder.
sample_input = [0.0, 3.0,2.0,1.0,1.0,0.0,1.0,3.0,1.0,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000,0.000,0.000,
    0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,
    1.000,1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000,
    0.000,0.000,1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,
    0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,
    0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.0281021,
    1.1674791,0.0772629,1.00919452640745377359,-0.40733408431212109191,0.27344889607694411460,-0.27692477736208176431,
    0.90979100598229301067,0.30854060293899643330,-0.89088669667641318117,0.71015013257662451540,-0.45934534155660206034,
    -1.5771756172180175781,-0.44342430101500618367,0.99046792752212953204,0.77406677189800476846,0.22008506072840341994,
    -0.31012541014287209329,-0.30062459437047234223,-0.02684695402988129115,0.17956349253654479980,
    -0.46235901945167118265,0.42958878223887747572,-0.44371617585420608521,-0.84945221741994225706,
    0.63907705081833732219,-0.70754766008920144671,0.48411194566223358926,-0.12378847102324168350,
    0.15848264263735878377]

# Load into a fresh graph and close the session when done, instead of leaving
# a session on the default graph open for the rest of the process.
with tf.Session(graph=tf.Graph()) as sess:
    meta_graph_def = tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        export_path)
    signature = meta_graph_def.signature_def
    x_tensor_name = signature[signature_key].inputs[input_key].name
    y_tensor_name = signature[signature_key].outputs[output_key].name
    x = sess.graph.get_tensor_by_name(x_tensor_name)
    y = sess.graph.get_tensor_by_name(y_tensor_name)
    # NOTE: `y` is produced by a tf.py_func op. Its Python callback is not
    # stored in the SavedModel, so in a process that has not registered the
    # same callback this run fails with "callback pyfunc_0 not found".
    y_out = sess.run(y, {x: sample_input})

print(y_out)
答案 0（得分：0）
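tf.py_func 不支持以 .pb（SavedModel）格式保存，请改用检查点格式。补充说明：tf.py_func 包装的 Python 函数体不会被序列化到 GraphDef 中，图里只记录一个回调令牌（如 pyfunc_0），函数本身只注册在创建它的 Python 进程里；因此在另一个进程中加载模型时会找不到该回调。检查点同样只保存变量（以及图结构），不会保存 Python 函数，所以仅改用检查点格式并不能解决问题——还需要在加载进程中重新定义并注册同一个 py_func，或者把 pickle 模型的预测逻辑改写为 TensorFlow 运算。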