Tensorflow批处理:将结果保留为字符串

时间:2018-11-13 09:54:01

标签: python tensorflow

这个简单的程序

import tensorflow as tf

input = 'string'
batch = tf.train.batch([tf.constant(input)], batch_size=1)
with tf.Session() as sess:
    tf.train.start_queue_runners()
    output, = sess.run(batch)

    print(1, input, output)
    print(2, str(output, 'utf-8'))
    print(3, input.split('i'))
    print(4, str(output, 'utf-8').split('i'))
    print(5, output.split('i'))

打印

  

1个字符串b'string'
  2串
  3 ['str','ng']
  4 ['str','ng']
  ERROR:tensorflow:QueueRunner中的异常:会话已关闭。
  打印(5,output.split('i'))
  TypeError:需要一个类似字节的对象,而不是'str'

为什么输入的结果不是字符串列表?

好,@ jdehesa explained 为什么,但不是如何“修复”它。我可以应用 bytes.decode()

output, = map(bytes.decode, sess.run(batch))

并且存在 tf.map_fn()应该在张量上执行相同的操作。唯一的问题是如何在我的方案中使用它?


PS :实际上,错误消息也令人困惑。问题在于我们提供的是字节对象,而不是字符串。但是 TypeError 却恰恰相反。

PPS :由于@jdehesa,解释了错误消息:它是关于 split()的参数,而不是对象。 output.split(b'i')效果很好!

1 个答案:

答案 0 :(得分:2)

问题在于outputbytes对象,因为TensorFlow tf.string张量确实由bytes组成。但是,然后您尝试将splitstr分隔符一起使用,这就是它抱怨的原因。试试:

output.split(b'i')

或:

output.decode().split('i')