我正在使用Tensorflow构建RNN(GRU)模型,现在模型已经过培训,我需要通过Tensorflow服务部署它来进行预测服务。
所以我创建了一个signature_def函数,如下所示,该函数需要将id作为输入,其长度不固定(所以我将形状设置为[None]),在函数内部,id中的id被逐一挑选出来GRU细胞。问题是,对于None形状,我无法弄清楚如何迭代所有id
def signature_def(self):
ids = tf.placeholder(tf.int32, [None], name='input')
state = [np.zeros([1, self.rnn_size], dtype=np.float32) for _ in range(self.layers)]
for i in range(<length_of_ids>):
id = [ids[i]]
inputs = tf.nn.embedding_lookup(self.embedding, id)
output, state = self.stacked_cell(inputs, tuple(state))
logits = tf.matmul(output, self.softmax_W, transpose_b=True) + self.softmax_b
outputs = self.final_activation(logits)
tensor_info_x = tf.saved_model.utils.build_tensor_info(ids)
tensor_info_y = tf.saved_model.utils.build_tensor_info(outputs)
return tf.saved_model.signature_def_utils.build_signature_def(
inputs={'ids': tensor_info_x},
outputs={'preds': tensor_info_y},
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
我曾尝试过tf.map_fn,它报告错误为&#34; ... / dropout / mul在while循环中#34; 我也尝试添加另一个输入参数来传递id长度,如下所示,但似乎长度参数不能改变for循环,它仍然是默认值3:
def signature_def(self):
ids = tf.placeholder(tf.int32, [None], name='input')
length = tf.placeholder_with_default([3], [1], name='length')
state = [np.zeros([1, self.rnn_size], dtype=np.float32) for _ in range(self.layers)]
for i in range(length.eval()[0]):
id = [ids[i]]
inputs = tf.nn.embedding_lookup(self.embedding, id)
output, state = self.stacked_cell(inputs, tuple(state))
logits = tf.matmul(output, self.softmax_W, transpose_b=True) + self.softmax_b
outputs = self.final_activation(logits)
tensor_info_x = tf.saved_model.utils.build_tensor_info(ids)
tensor_info_l = tf.saved_model.utils.build_tensor_info(length)
tensor_info_y = tf.saved_model.utils.build_tensor_info(outputs)
return tf.saved_model.signature_def_utils.build_signature_def(
inputs={'ids': tensor_info_x, 'length': tensor_info_l},
outputs={'preds': tensor_info_y},
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
非常感谢任何建议或指导
谢谢!
答案 0 :(得分:0)
可能是我必须要解决的类似情况。见这里:tensorflow serving input function for sliding window over timeseries data。 此问题中的代码使用tf.while_loop迭代输入时间轴(滑动窗口)。但是,它遇到了与不匹配服务输入的示例数量相关的另一个问题。我还没有找到解决这个后续问题的方法。