Keeping a TensorFlow session open in a Kivy application

Time: 2018-03-16 15:42:40

Tags: android python tensorflow kivy python-multithreading

I'm trying to run an app made in Kivy together with a TensorFlow session, without loading the session every time I make a prediction. More precisely, I want to know how to call the function from inside the session.

Here is the code for the session (the function I'm trying to call is defined inside it):

import os
import numpy as np
import tensorflow as tf  # TF 1.x API (tf.Session, tf.ConfigProto)

# create_model, data_utils, gConfig and _buckets come from the surrounding
# seq2seq project and are assumed to be imported elsewhere.

def decode():
    # Only allocate part of the gpu memory when predicting.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
    config = tf.ConfigProto(gpu_options=gpu_options)

    with tf.Session(config=config) as sess:
        # Create model and load parameters.
        model = create_model(sess, True)
        model.batch_size = 1

        enc_vocab_path = os.path.join(gConfig['working_directory'],"vocab%d.enc" % gConfig['enc_vocab_size'])
        dec_vocab_path = os.path.join(gConfig['working_directory'],"vocab%d.dec" % gConfig['dec_vocab_size'])

        enc_vocab, _ = data_utils.initialize_vocabulary(enc_vocab_path)
        _, rev_dec_vocab = data_utils.initialize_vocabulary(dec_vocab_path)

        # !!! This is the function that I'm trying to call. !!!
        def answersqs(sentence):
            token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), enc_vocab)
            bucket_id = min([b for b in range(len(_buckets))
                            if _buckets[b][0] > len(token_ids)])
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)
            _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                            target_weights, bucket_id, True)
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            if data_utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.EOS_ID)]

            return " ".join([tf.compat.as_str(rev_dec_vocab[output]) for output in outputs])

And here is the last part:

def resp(self, msg):
    def p():
        if len(msg) > 0:
            # If I try to do decode().answersqs(msg), it starts a new session.
            ansr = answersqs(msg)
            ansrbox = Message()
            ansrbox.ids.mlab.text = str(ansr)
            ansrbox.ids.mlab.color = (1, 1, 1)
            ansrbox.pos_hint = {'x': 0}
            ansrbox.source = './icons/ansr_box.png'
            self.root.ids.chatbox.add_widget(ansrbox)
            self.root.ids.scrlv.scroll_to(ansrbox)

    threading.Thread(target=p).start()

Also, should I switch the session from GPU to CPU before I port it to Android?

Thanks in advance!

1 Answer:

Answer 0 (score: 1)

You should have two variables, graph and session.

When loading the model, you do something like this:

graph = tf.Graph()
session = tf.Session(config=config)
with graph.as_default(), session.as_default():
  # The rest of your model-loading code.

And when you need to make a prediction:

with graph.as_default(), session.as_default():
  return session.run([your_result_tensor])

The session stays loaded and in memory; you are just telling the system which context you want to run in.
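The load-once, call-many idea behind this does not depend on TensorFlow. Here is a minimal sketch, with a hypothetical ExpensiveModel standing in for the (graph, session) pair, showing that the expensive step happens only once:

```python
class ExpensiveModel:
    """Stand-in for the (graph, session) pair: costly to build, cheap to call."""
    load_count = 0

    def __init__(self):
        # Pretend this is the slow part: building the graph, loading weights.
        ExpensiveModel.load_count += 1

    def predict(self, x):
        return x * 2


_model = None  # module level: created once, then reused by every prediction


def get_model():
    global _model
    if _model is None:
        _model = ExpensiveModel()
    return _model


def answer(x):
    # Analogous to answersqs(): uses the already-loaded model, never reloads it.
    return get_model().predict(x)
```

Calling answer() repeatedly still constructs ExpensiveModel only once, which is exactly what keeping graph and session at module level buys you with TensorFlow.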

Also, move def answersqs outside of the with block in your code. It should bind automatically to the graph and session of the surrounding function (but you need to make them available outside the with).
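Applied to the question's code, the restructuring could look like the sketch below. The TensorFlow calls are replaced by stand-ins so the sketch is self-contained; in the real app, load_model() would run tf.Session(config=config), create_model(...) and the vocabulary loading from decode(), and answersqs() would keep its original body minus the `with tf.Session(...)` wrapper:

```python
# Module-level handles, populated once by load_model() and reused by answersqs().
graph = None
session = None
model = None


def load_model():
    """Build everything once (e.g. from Kivy's App.build())."""
    global graph, session, model
    graph = object()    # stand-in for tf.Graph()
    session = object()  # stand-in for tf.Session(config=config)
    # Stand-in for create_model(session, True) + model.step(...):
    model = lambda sentence: sentence.upper()


def answersqs(sentence):
    # No `with tf.Session(...)` here: the session opened by load_model()
    # stays open, so each call is a single run, not a full reload.
    return model(sentence)
```

With this shape, load_model() runs once at startup, and resp() can call answersqs(msg) from its worker thread without ever touching the session lifecycle.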

For the second part: generally, if you follow the guides, the exported model should not contain any hardware-binding information, and when you load it TensorFlow will figure out a good placement (possibly the GPU, if one is available and usable).
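If you do want to force the session onto the CPU anyway before porting, one common TF 1.x configuration is to hide the GPUs via device_count. This is a sketch of that config, not taken from the question's code:

```python
import tensorflow as tf  # TF 1.x API

# device_count={'GPU': 0} makes the session see no GPUs, so all ops run on CPU.
config = tf.ConfigProto(device_count={'GPU': 0})
sess = tf.Session(config=config)
```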