Joining a ragged tensor of characters into strings

Posted: 2020-09-29 12:27:19

Tags: python tensorflow2.0

I have a ragged tensor of characters (copy/pasteable code to reproduce):

import tensorflow as tf

flat_values = [
    b'7', b'2', b'1', b'0', b'4', b'1', b'4', b'5', b'0', b'6', b'0',
    b'1', b'5', b'7', b'8', b'4', b'6', b'6', b'5', b'4', b'0', b'7',
    b'4', b'0', b'1', b'3', b'1', b'3', b'4', b'7', b'2', b'7', b'1',
    b'2', b'1', b'1', b'7', b'4', b'2', b'3', b'5', b'1', b'2', b'4',
    b'4', b'6', b'3', b'5', b'5', b'6', b'0', b'4', b'1', b'5', b'7',
    b'8', b'3', b'7', b'4', b'6', b'4', b'3', b'0', b'7', b'0', b'2',
    b'1', b'7', b'3', b'2', b'7', b'7', b'6', b'2', b'7', b'8', b'4',
    b'7', b'3', b'6', b'1', b'3', b'6', b'3', b'1', b'4', b'1', b'7',
    b'6', b'6', b'0', b'5', b'4', b'2', b'1', b'4', b'8', b'7', b'3',
    b'7', b'4', b'4', b'2', b'5', b'4', b'7', b'6', b'7', b'0', b'5',
    b'8', b'5', b'6', b'6', b'5', b'7', b'8', b'1', b'0', b'1', b'6',
    b'4', b'6', b'7', b'3', b'1', b'7', b'1', b'8', b'2', b'0', b'2',
]

row_lengths = [
    6, 4, 4, 5, 6, 6, 6, 6, 6, 5, 5, 6,
    5, 5, 6, 5, 5, 4, 4, 4, 5, 6, 6, 6, 6,
]

x = tf.RaggedTensor.from_nested_row_lengths(
    flat_values,
    (row_lengths,),
)
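Since there is only one ragged dimension here, the tuple passed to from_nested_row_lengths has a single entry. The sketch below, using a small hypothetical five-character input rather than the data above, checks that this builds the same tensor as the single-level constructor tf.RaggedTensor.from_row_lengths.

```python
import tensorflow as tf

# Hypothetical small input: 5 characters split into rows of length 3 and 2.
flat_values = [b'7', b'2', b'1', b'0', b'4']
row_lengths = [3, 2]

# One level of nesting, passed as a tuple of row-length lists (as in the question).
nested = tf.RaggedTensor.from_nested_row_lengths(flat_values, (row_lengths,))

# Single-level constructor; should produce the same 2-row ragged tensor.
single = tf.RaggedTensor.from_row_lengths(flat_values, row_lengths)

print(nested.to_list())
print(single.to_list())
```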

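I want to concatenate each row into a single string, and I want to do it in the graph. Outside the graph, I can do it by evaluating the tensor and joining the rows in Python: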

>>> [''.join([c.decode() for c in i]) for i in x.to_list()]
['721041',
 '4506',
 '0157',
 '84665',
 '407401',
 '313472',
 '712117',
 '423512',
 '446355',
 '60415',
 '78374',
 '643070',
 '21732',
 '77627',
 '847361',
 '36314',
 '17660',
 '5421',
 '4873',
 '7442',
 '54767',
 '058566',
 '578101',
 '646731',
 '718202']

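But because this is the output of my network (trained in graph mode), I'd like to express this operation in TensorFlow so that it can be evaluated during the validation step. Two things I tried that don't work: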

>>> tf.strings.join(x)
InvalidArgumentError: Input shapes do not match: [6] vs. [4] [Op:StringJoin]

>>> tf.ragged.map_flat_values(tf.strings.join, x)
ValueError: Shape () must have rank at least 1

Frustratingly, the documentation for tf.strings.join mentions ragged tensors but gives no example. What am I missing? It seems like there should be an obvious solution.
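One plausible reading of the first error: tf.strings.join concatenates element-wise across a list of input tensors rather than along an axis of one tensor, so passing x seems to hand it the 25 rows as separate inputs of lengths 6, 4, and so on, which cannot be paired up. A minimal sketch with two hypothetical dense inputs of equal shape shows the element-wise behaviour:

```python
import tensorflow as tf

# Two hypothetical inputs with the same shape.
a = tf.constant([b'a', b'b'])
b = tf.constant([b'c', b'd'])

# Element-wise: position i of the result is a[i] + separator + b[i].
joined = tf.strings.join([a, b])
joined_dash = tf.strings.join([a, b], separator='-')

print(joined.numpy())
print(joined_dash.numpy())
```

Joining within each row (along an axis) is a different operation from this element-wise join across inputs.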

1 Answer:

Answer 0 (score: 0)

Figured it out... there is tf.strings.reduce_join, which does exactly this:

>>> tf.strings.reduce_join(x, axis=1)
<tf.Tensor: shape=(25,), dtype=string, numpy=
array([b'721041', b'4506', b'0157', b'84665', b'407401', b'313472',
       b'712117', b'423512', b'446355', b'60415', b'78374', b'643070',
       b'21732', b'77627', b'847361', b'36314', b'17660', b'5421',
       b'4873', b'7442', b'54767', b'058566', b'578101', b'646731',
       b'718202'], dtype=object)>
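Because the original goal was to run this in graph mode, here is a minimal sketch that wraps the same call in a tf.function with a ragged input signature. The function name join_rows and the small input are hypothetical; the optional separator argument of tf.strings.reduce_join is shown as well.

```python
import tensorflow as tf

# Graph-mode wrapper: accepts any 2-D ragged string tensor (ragged_rank 1).
@tf.function(input_signature=[tf.RaggedTensorSpec(shape=[None, None], dtype=tf.string)])
def join_rows(chars):
    # Concatenate the characters of each row into one string, inside the graph.
    return tf.strings.reduce_join(chars, axis=1)

x = tf.ragged.constant([[b'7', b'2', b'1'], [b'0', b'4']])

rows = join_rows(x)
# The same reduction with a separator placed between characters.
spaced = tf.strings.reduce_join(x, axis=1, separator=' ')

print(rows.numpy())
print(spaced.numpy())
```

The result is a dense 1-D string tensor with one entry per row, so it can be compared directly against dense label strings in a validation step.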