Cannot convert to tensor proto: sending an input file

Asked: 2018-04-02 09:56:25

Tags: python tensorflow protocol-buffers tensorflow-serving werkzeug

1) I wrote a simple program that uses TensorFlow to read a text file, and I want to deploy it on a server with TensorFlow Serving. Here is the program:

import os
import tensorflow as tf

tf.app.flags.DEFINE_integer('model_version', 2, 'version number of the model.')
tf.app.flags.DEFINE_string('work_dir', '', 'Working directory.')
FLAGS = tf.app.flags.FLAGS

sess = tf.InteractiveSession()
# define the tensorflow network and do some trains
x = tf.placeholder(tf.string, name="x")

sess.run(tf.global_variables_initializer())
y = tf.read_file(x, name="y")


export_path_base = FLAGS.work_dir
export_path = os.path.join(tf.compat.as_bytes(export_path_base),
  tf.compat.as_bytes(str(FLAGS.model_version)))
print('Exporting trained model to', export_path)
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

prediction_signature = (
  tf.saved_model.signature_def_utils.build_signature_def(
  inputs={'input': tensor_info_x},
  outputs={'output': tensor_info_y},
  method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

builder.add_meta_graph_and_variables(
  sess, [tf.saved_model.tag_constants.SERVING],
  signature_def_map={
  'prediction':
  prediction_signature,
  },
  legacy_init_op=legacy_init_op)

builder.save()
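As a side note, the export path built above is just a join of the base directory and the version number; a minimal stand-alone sketch (using a hypothetical base path, no TensorFlow required):

```python
import os

# Hypothetical work_dir; tf.compat.as_bytes only converts str to bytes,
# so plain strings behave the same way for path construction.
export_path_base = "/tmp/modelX"
model_version = 2

export_path = os.path.join(export_path_base, str(model_version))
print(export_path)  # /tmp/modelX/2
```

TensorFlow Serving picks up the numeric subdirectory (here "2") as the model version.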

2) I created the protobuf of this model and got it running on the server. Then I wrote a client program that sends the input text file and produces the output. Here is a simple client file that reads it:

import tensorflow as tf
from grpc.beta import implementations
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2

tf.app.flags.DEFINE_string('server', 'localhost:9000', 'PredictionService host:port')
tf.app.flags.DEFINE_string('input','','input for the model')
FLAGS = tf.app.flags.FLAGS

def do_inference(hostport,no):

  # create connection
  host, port = hostport.split(':')
  channel = implementations.insecure_channel(host, int(port))
  stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

  # initialize a request
  data = no
  request = predict_pb2.PredictRequest()
  request.model_spec.name = 'modelX'
  request.model_spec.signature_name = 'prediction'

  request.inputs['input'].CopyFrom(tf.contrib.util.make_tensor_proto(data))
  # predict
  result = stub.Predict(request, 5.0) # 5 seconds
  return result

def main(_):
    result = do_inference(FLAGS.server,FLAGS.input)
    print('Result is: ', result)


if __name__ == '__main__':
  tf.app.run()

So when I run this code:

python client.py --server=172.17.0.2:9000 --input=hello.txt

it produces the output:

*Hello!*

3) Now I have written a client file with the Flask framework to create a REST API:

import tensorflow as tf
from flask import Flask, request, jsonify
from google.protobuf import json_format
from grpc.beta import implementations
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2

tf.app.flags.DEFINE_string('server', 'localhost:9000', 'PredictionService host:port')
FLAGS = tf.app.flags.FLAGS

app = Flask(__name__)

class mainSessRunning():

    def __init__(self):
        host, port = FLAGS.server.split(':')
        channel = implementations.insecure_channel(host, int(port))
        self.stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

        self.request = predict_pb2.PredictRequest()
        self.request.model_spec.name = 'modelX'
        self.request.model_spec.signature_name = 'prediction'

    def inference(self, val_x):
        data = val_x
        self.request.inputs['input'].CopyFrom(tf.contrib.util.make_tensor_proto(data))
        result = self.stub.Predict(self.request, 5.0)
        return result

run = mainSessRunning()

# Define a route for the default URL, which loads the form
@app.route('/pred', methods=['POST'])
def pred():
    request_data = request.files['file']
    result = run.inference(request_data)
    rs = json_format.MessageToJson(result)
    return jsonify({'result':rs})

Using Postman, when I supply the same input file 'hello.txt', it throws this error:

TypeError: Failed to convert object of type <class 'werkzeug.datastructures.FileStorage'> to Tensor. Contents: FileStorage: u'hello.txt' ('text/plain'). Consider casting elements to a supported type.

I have posted this here. It works fine with the plain client.py file, but not with the Flask-based client.py. I followed this official TensorFlow documentation and found that make_tensor_proto accepts a "values" argument that is a python scalar, python list, numpy ndarray, or numpy scalar.
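To make that list concrete: the *content* of an uploaded file, once read, is a plain bytes value (a supported scalar type), while the FileStorage wrapper itself is not. A minimal sketch of the distinction, using an in-memory stand-in instead of a real werkzeug FileStorage:

```python
import io

# Stand-in for the uploaded file; werkzeug wraps the real upload in a
# FileStorage object that is likewise file-like (it has .read()).
upload = io.BytesIO(b"contents of hello.txt")

data = upload.read()       # bytes - something make_tensor_proto can accept
print(type(data).__name__) # bytes
```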

So my question is: how do I send this werkzeug FileStorage object so that it is accepted as a tensor proto? Or is this a bug?

1 Answer:

Answer 0: (score: 0)

Take a look at this URL: http://werkzeug.pocoo.org/docs/0.14/datastructures/

If I look at "def inference", the local "data" variable holds a reference to an object of type "werkzeug.datastructures.FileStorage".

When you get a file from a POST [via Flask], the file is actually wrapped in a "werkzeug.datastructures.FileStorage" object. So request.files['file'] is not a file but an object of type "werkzeug.datastructures.FileStorage", and you need to find a way to reach the underlying file.

If you look at the URL provided, we can do this:

request_data = request.files['file']
request_data.save(destination_file_path)
#adapt your "inference" to get a file path instead
result = run.inference(destination_file_path)
and in inference:

def inference(destination_file_path):
    with open(destination_file_path) as f:
        # here handle the content of "f" however you want... hope this helps
        data = f.read()