Tensorflow服务:如何迭代输入占位符

时间:2018-05-03 03:59:20

标签: python tensorflow rnn tensorflow-serving

我正在使用Tensorflow构建RNN(GRU)模型,现在模型已经过培训,我需要通过Tensorflow服务部署它来进行预测服务。
所以我创建了一个signature_def函数,如下所示,该函数需要将id作为输入,其长度不固定(所以我将形状设置为[None]),在函数内部,id中的id被逐一挑选出来GRU细胞。问题是,对于None形状,我无法弄清楚如何迭代所有id

def signature_def(self):
    ids = tf.placeholder(tf.int32, [None], name='input')

    state = [np.zeros([1, self.rnn_size], dtype=np.float32) for _ in range(self.layers)]

    for i in range(<length_of_ids>):
        id = [ids[i]]
        inputs = tf.nn.embedding_lookup(self.embedding, id)
        output, state = self.stacked_cell(inputs, tuple(state))
    logits = tf.matmul(output, self.softmax_W, transpose_b=True) + self.softmax_b
    outputs = self.final_activation(logits)

    tensor_info_x = tf.saved_model.utils.build_tensor_info(ids)
    tensor_info_y = tf.saved_model.utils.build_tensor_info(outputs)

    return tf.saved_model.signature_def_utils.build_signature_def(
            inputs={'ids': tensor_info_x},
            outputs={'preds': tensor_info_y},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

我曾尝试过tf.map_fn,它报告错误为&#34; ... / dropout / mul在while循环中#34; 我也尝试添加另一个输入参数来传递id长度,如下所示,但似乎长度参数不能改变for循环,它仍然是默认值3:

def signature_def(self):
    ids = tf.placeholder(tf.int32, [None], name='input')
    length = tf.placeholder_with_default([3], [1], name='length')

    state = [np.zeros([1, self.rnn_size], dtype=np.float32) for _ in range(self.layers)]

    for i in range(length.eval()[0]):
        id = [ids[i]]
        inputs = tf.nn.embedding_lookup(self.embedding, id)
        output, state = self.stacked_cell(inputs, tuple(state))
    logits = tf.matmul(output, self.softmax_W, transpose_b=True) + self.softmax_b
    outputs = self.final_activation(logits)

    tensor_info_x = tf.saved_model.utils.build_tensor_info(ids)
    tensor_info_l = tf.saved_model.utils.build_tensor_info(length)
    tensor_info_y = tf.saved_model.utils.build_tensor_info(outputs)

    return tf.saved_model.signature_def_utils.build_signature_def(
            inputs={'ids': tensor_info_x, 'length': tensor_info_l},
            outputs={'preds': tensor_info_y},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

非常感谢任何建议或指导

谢谢!

1 个答案:

答案 0 :(得分:0)

可能是我必须要解决的类似情况。见这里:tensorflow serving input function for sliding window over timeseries data。 此问题中的代码使用tf.while_loop迭代输入时间轴(滑动窗口)。但是,它遇到了与不匹配服务输入的示例数量相关的另一个问题。我还没有找到解决这个后续问题的方法。