TensorFlow Serving model returns the same output for different inputs

Date: 2017-11-01 02:54:56

Tags: tensorflow deep-learning tensorflow-serving

I recently exported a text summarization model for TensorFlow Serving. I can run a client and successfully make gRPC calls against the model, but no matter what I send as input, I get the same output every time. My assumption is that this has to do with how I exported the model, perhaps it is using some static entry rather than the input I send, but I am not sure where I went wrong. I have verified that the data passed along in the client file is correct, yet whatever is sent to self.stub.Predict, the output is always the same. Has anyone seen this before, or can anyone spot what I am doing wrong?
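One way to check whether the exported signature ended up pointing at a constant rather than at the input placeholder is to load the SavedModel and print its signature. A minimal sketch, assuming TF 1.x; the export path is hypothetical:

import tensorflow as tf

export_path = '/tmp/export/1'  # hypothetical path to the exported SavedModel

with tf.Session(graph=tf.Graph()) as sess:
    meta_graph_def = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], export_path)
    # Shows which graph tensors the 'predict' signature is wired to.
    print(meta_graph_def.signature_def['predict'])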

Client file snippet:

def decode(self, passed_data):
    now = datetime.now()

    # Copy the passed string into the 'inputs' tensor of the PredictRequest.
    self.request.inputs['inputs'].CopyFrom(
        #tf.contrib.util.make_tensor_proto(test_data_set[0], shape=[1]))
        tf.contrib.util.make_tensor_proto(passed_data, shape=[1]))
    result = self.stub.Predict(self.request, 5.0)  # 5-second timeout
    waiting = datetime.now() - now
    return result, waiting.microseconds
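For context, here is a minimal sketch of how the stub and request might be constructed with the beta gRPC API of that TF Serving era; the host, port, and model name are hypothetical:

from grpc.beta import implementations
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2

channel = implementations.insecure_channel('localhost', 9000)  # hypothetical host/port
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

request = predict_pb2.PredictRequest()
request.model_spec.name = 'summarizer'         # hypothetical model name
request.model_spec.signature_name = 'predict'  # must match the signature_def_map key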

Export code snippet: ....

decode_mdl_hps = hps._replace(dec_timesteps=1)
model = seq2seq_attention_model.Seq2SeqAttentionModel(
    decode_mdl_hps, vocab, num_gpus=FLAGS.num_gpus)
decoder = seq2seq_attention_decode.BSDecoder(model, batcher, hps, vocab)

serialized_output = tf.placeholder(tf.string, name='tf_output')

serialized_tf_example = tf.placeholder(tf.string, name='tf_example')
feature_configs = {
    'inputs': tf.FixedLenFeature(shape=[1], dtype=tf.string),
}
tf_example = tf.parse_example(serialized_tf_example, feature_configs)
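# For reference, a client payload matching the 'inputs' feature spec above
# would be a serialized tf.train.Example (hypothetical sketch):
#   example = tf.train.Example(features=tf.train.Features(feature={
#       'inputs': tf.train.Feature(
#           bytes_list=tf.train.BytesList(value=[b'article text']))}))
#   serialized = example.SerializeToString()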

saver = tf.train.Saver()  # sharded=True

config = tf.ConfigProto(allow_soft_placement=True)

with tf.Session(config=config) as sess:

    # Restore variables from training checkpoints.
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        print('Successfully loaded model from %s at step=%s.' %
              (ckpt.model_checkpoint_path, global_step))
        # Run the decoder once here and wrap its Python result in a tensor.
        res = decoder._Decode(saver, sess)
        res_tensor = tf.convert_to_tensor(res)

        print("Decoder value {}".format(type(res)))
    else:
        print('No checkpoint file found at %s' % FLAGS.checkpoint_dir)
        return

    # Export model.
    export_path = os.path.join(FLAGS.export_dir, str(FLAGS.export_version))
    print('Exporting trained model to %s' % export_path)

    # -------------------------------------------

    tensor_info_inputs = tf.saved_model.utils.build_tensor_info(serialized_tf_example)
    tensor_info_outputs = tf.saved_model.utils.build_tensor_info(res_tensor)

    prediction_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(
            inputs={tf.saved_model.signature_constants.PREDICT_INPUTS: tensor_info_inputs},
            outputs={tf.saved_model.signature_constants.PREDICT_OUTPUTS: tensor_info_outputs},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

    # ----------------------------------

    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
    builder = saved_model_builder.SavedModelBuilder(export_path)

    builder.add_meta_graph_and_variables(
        sess=sess,
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            'predict': prediction_signature,
        },
        legacy_init_op=legacy_init_op)
    builder.save()

    print('Successfully exported model to %s' % export_path)
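For comparison, in a serving-style export the output tensor handed to build_tensor_info normally depends on the input placeholder, so each request is recomputed inside the graph. A minimal sketch with a stand-in op where the real decoder ops would go; nothing below is from the original post:

import tensorflow as tf

serialized_tf_example = tf.placeholder(tf.string, name='tf_example')
feature_configs = {'inputs': tf.FixedLenFeature(shape=[1], dtype=tf.string)}
tf_example = tf.parse_example(serialized_tf_example, feature_configs)

# Stand-in for the model's decode ops: the output is computed from the
# parsed input tensor, not from a Python value baked in at export time.
output_tensor = tf.identity(tf_example['inputs'], name='outputs')

tensor_info_inputs = tf.saved_model.utils.build_tensor_info(serialized_tf_example)
tensor_info_outputs = tf.saved_model.utils.build_tensor_info(output_tensor)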

0 Answers:

No answers yet.