I have a ragged tensor of characters (copy-pasteable code to reproduce):
import tensorflow as tf

flat_values = [
    b'7', b'2', b'1', b'0', b'4', b'1', b'4', b'5', b'0', b'6', b'0',
    b'1', b'5', b'7', b'8', b'4', b'6', b'6', b'5', b'4', b'0', b'7',
    b'4', b'0', b'1', b'3', b'1', b'3', b'4', b'7', b'2', b'7', b'1',
    b'2', b'1', b'1', b'7', b'4', b'2', b'3', b'5', b'1', b'2', b'4',
    b'4', b'6', b'3', b'5', b'5', b'6', b'0', b'4', b'1', b'5', b'7',
    b'8', b'3', b'7', b'4', b'6', b'4', b'3', b'0', b'7', b'0', b'2',
    b'1', b'7', b'3', b'2', b'7', b'7', b'6', b'2', b'7', b'8', b'4',
    b'7', b'3', b'6', b'1', b'3', b'6', b'3', b'1', b'4', b'1', b'7',
    b'6', b'6', b'0', b'5', b'4', b'2', b'1', b'4', b'8', b'7', b'3',
    b'7', b'4', b'4', b'2', b'5', b'4', b'7', b'6', b'7', b'0', b'5',
    b'8', b'5', b'6', b'6', b'5', b'7', b'8', b'1', b'0', b'1', b'6',
    b'4', b'6', b'7', b'3', b'1', b'7', b'1', b'8', b'2', b'0', b'2',
]
# 25 rows whose lengths sum to len(flat_values) == 132
row_lengths = [
    6, 4, 4, 5, 6, 6, 6, 6, 6, 5, 5, 6,
    5, 5, 6, 5, 5, 4, 4, 4, 5, 6, 6, 6, 6,
]
x = tf.RaggedTensor.from_nested_row_lengths(
    flat_values,
    (row_lengths,),
)
I want to concatenate each row into a single string, but I want to do it in the graph. I can do it by evaluating the tensor eagerly:
>>> [''.join([c.decode() for c in i]) for i in x.to_list()]
['721041',
'4506',
'0157',
'84665',
'407401',
'313472',
'712117',
'423512',
'446355',
'60415',
'78374',
'643070',
'21732',
'77627',
'847361',
'36314',
'17660',
'5421',
'4873',
'7442',
'54767',
'058566',
'578101',
'646731',
'718202']
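But since this is the output of my network (trained in graph mode), I'd like to express this operation in TensorFlow so that it can be evaluated during the validation step. Two things I tried that don't work: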
>>> tf.strings.join(x)
InvalidArgumentError: Input shapes do not match: [6] vs. [4] [Op:StringJoin]
>>> tf.ragged.map_flat_values(tf.strings.join, x)
ValueError: Shape () must have rank at least 1
Frustratingly, the documentation for tf.strings.join mentions ragged tensors but gives no example. What am I missing? It seems like there should be an obvious solution.
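A likely explanation for both failures is that tf.strings.join is an element-wise join across a list of tensors (output i is inputs[0][i] + inputs[1][i] + ...), not a reduction along an axis of one tensor. The sketch below uses small dense tensors with made-up values to contrast it with tf.strings.reduce_join. The error "Input shapes do not match: [6] vs. [4]" is consistent with each ragged row being treated as a separate input, though that reading of the error is an inference, not something the documentation states.

```python
import tensorflow as tf

# tf.strings.join: element-wise concatenation across a LIST of tensors,
# so every input must have the same shape.
a = tf.constant([b'a', b'b'])
b = tf.constant([b'c', b'd'])
elementwise = tf.strings.join([a, b])
print(elementwise.numpy().tolist())  # [b'ac', b'bd']

# tf.strings.reduce_join: concatenation ALONG AN AXIS of a single tensor,
# which is what "join the characters of each row" means.
rows = tf.constant([[b'7', b'2'], [b'1', b'0']])
per_row = tf.strings.reduce_join(rows, axis=1)
print(per_row.numpy().tolist())  # [b'72', b'10']
```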
Answer (score: 0)
Figured it out: there is tf.strings.reduce_join, which does exactly this:
>>> tf.strings.reduce_join(x, axis=1)
<tf.Tensor: shape=(25,), dtype=string, numpy=
array([b'721041', b'4506', b'0157', b'84665', b'407401', b'313472',
b'712117', b'423512', b'446355', b'60415', b'78374', b'643070',
b'21732', b'77627', b'847361', b'36314', b'17660', b'5421',
b'4873', b'7442', b'54767', b'058566', b'578101', b'646731',
b'718202'], dtype=object)>
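Since the original goal was to run this in graph mode, here is a minimal sketch that wraps the same call in tf.function. The function name join_rows and the small input are hypothetical, chosen for illustration; it assumes TensorFlow 2.x, where tf.function accepts a RaggedTensor argument and tf.strings.reduce_join has a ragged implementation (as the eager output above shows).

```python
import tensorflow as tf

@tf.function
def join_rows(ragged_chars: tf.RaggedTensor) -> tf.Tensor:
    # Hypothetical helper: concatenate the characters of each row into one
    # string per row, entirely as a graph op (no .to_list() round-trip).
    return tf.strings.reduce_join(ragged_chars, axis=1)

# Small illustrative input: 3 rows of lengths 3, 1 and 4.
x = tf.RaggedTensor.from_row_lengths(
    [b'7', b'2', b'1', b'0', b'4', b'5', b'0', b'6'],
    row_lengths=[3, 1, 4],
)
print(join_rows(x).numpy().tolist())  # [b'721', b'0', b'4506']
```

For a 2-D ragged tensor, axis=1 and axis=-1 are equivalent here; reduce_join also accepts a separator argument if the characters should be joined with a delimiter.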