我正在尝试对character_rnn(https://github.com/tensorflow/tensorflow/blob/671baf080238025da9698ea980cd9504005f727c/tensorflow/examples/learn/text_classification_character_rnn.py)进行文本分类。
如何为它写一个serving_input_fn?我要保存并恢复此模型
扩展了代码以保存但出现错误,请帮助
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
feature_spec = {"feature":tf.FixedLenFeature([100],tf.int64)}
serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec)
然后
classifier.export_savedmodel(export_dir_base='model', serving_input_receiver_fn=serving_input_fn)
并收到此错误
TypeError:无法将类型的对象转换为Tensor。内容:{'feature':}。考虑将元素强制转换为受支持的类型。
请帮助我。
答案 0 :(得分:0)
serving_fn
要求功能必须为tensor
。
请在下面查看示例:
def serving_fn():
day_of_month = tf.Variable([], dtype=tf.int64, name='DAY_OF_MONTH')
day_of_week = tf.Variable([], dtype=tf.int64, name='DAY_OF_WEEK')
tail_num = tf.Variable([], dtype=tf.string,name='TAIL_NUM')
op_carrier_fl_num = tf.Variable([], dtype=tf.int64, name='OP_CARRIER_FL_NUM')
origin_airport_id = tf.Variable([], dtype=tf.int64, name='ORIGIN_AIRPORT_ID')
dest_airport_id = tf.Variable([], dtype=tf.int64, name='DEST_AIRPORT_ID')
dep_time_blk = tf.Variable([], dtype=tf.string,name='DEP_TIME_BLK')
reqd_inputs = {'DAY_OF_MONTH':day_of_month,
'DAY_OF_WEEK':day_of_week,
'TAIL_NUM':tail_num,
'OP_CARRIER_FL_NUM':op_carrier_fl_num,
'ORIGIN_AIRPORT_ID':origin_airport_id,
'DEST_AIRPORT_ID':dest_airport_id,
'DEP_TIME_BLK':dep_time_blk}
fn = tf.estimator.export.build_raw_serving_input_receiver_fn(reqd_inputs)
return fn
根据您的特征和数据类型,您需要将它们转换为相应的张量。
如果有帮助,可以在https://www.kaggle.com/jintolonappan/gbm-tf2-boostedtreesclassifier-export-to-serve
上以Kaggle笔记本的形式获得上述示例。