Tensorflow tf.data.Dataset将字符串张量转换为浮点张量

时间:2018-08-20 13:13:21

标签: python tensorflow

我想建立一个数据输入管道来读取边界框信息。 因此,我有多个.txt文件,它们分别在每行中存储有关xywidthheight的信息,例如:

952  607    9   18
947 1176   14   12
937  228   17   22
895 1118   66   53
804  596   12   13
651  722   13    8
667  306   28   51
586 1148   20   32
231  280   33   31
859  629  102  172
806  486  155  111
487  506   55   69
263  476  372  339
4  589  114  106
273  724  164  192
4    4  350  292

所有文件名都保存在list

filenames_bb = input_tools.get_required_filenames(args.dataset_dir, "train", params)

我使用tf.data.Dataset构建输入管道

dataset = tf.data.Dataset.from_tensor_slices(filenames_bb)
dataset = dataset.map(parse_fnc, params.num_parallel_calls)

现在我的问题是如何实施parse_fnc?我想要的是形状为[batch_size, number_of_bounding_boxes, 4]的张量。当前parse_fnc如下所示:

def parse_fnc(filenames):
    bb = tf.read_file(filenames)

    return bb

它返回张量shape=(?,)dtype=string,但是如何将其转换为具有所需尺寸的float张量?

1 个答案:

答案 0 :(得分:0)

您可以通过调用tf.TextLineDataset创建包含一个或多个文本文件行的数据集。然后,您可以使用tf.string_splittf.string_to_number获得浮点值。

def parse_fnc(line):
    string_vals = tf.string_split([line]).values
    return tf.string_to_number(string_vals, tf.float32)

string_ds = tf.data.TextLineDataset('./data.txt')
float_ds = string_ds.map(map_func=parse_fnc)

此示例从一个文件创建数据集,但是您可以提供多个文本文件作为输入。