如何从tf.tensor中获取字符串值,其中dtype是字符串

时间:2019-05-14 03:37:36

标签: python tensorflow

我想使用tf.data.Dataset.list_files函数来提供数据集。
但是由于该文件不是图像,因此需要手动加载。
问题是tf.data.Dataset.list_files作为tf.tensor传递变量,而我的python代码无法处理张量。

如何从tf.tensor获取字符串值。 dtype是字符串。

train_dataset = tf.data.Dataset.list_files(PATH+'clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: load_audio_file(x))

def load_audio_file(file_path):
  print("file_path: ", file_path)
  # i want do something like string_path = convert_tensor_to_string(file_path)

file_path为Tensor("arg0:0", shape=(), dtype=string)

我使用tensorflow 1.13.1和eager模式。

预先感谢

2 个答案:

答案 0 :(得分:2)

您可以使用tf.py_func来包装load_audio_file()

import tensorflow as tf

tf.enable_eager_execution()

def load_audio_file(file_path):
    # you should decode bytes type to string type
    print("file_path: ",bytes.decode(file_path),type(bytes.decode(file_path)))
    return file_path

train_dataset = tf.data.Dataset.list_files('clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: tf.py_func(load_audio_file, [x], [tf.string]))

for one_element in train_dataset:
    print(one_element)

file_path:  clean_4s_val/1.wav <class 'str'>
(<tf.Tensor: id=32, shape=(), dtype=string, numpy=b'clean_4s_val/1.wav'>,)
file_path:  clean_4s_val/3.wav <class 'str'>
(<tf.Tensor: id=34, shape=(), dtype=string, numpy=b'clean_4s_val/3.wav'>,)
file_path:  clean_4s_val/2.wav <class 'str'>
(<tf.Tensor: id=36, shape=(), dtype=string, numpy=b'clean_4s_val/2.wav'>,)

答案 1 :(得分:0)

如果您想做一些完全自定义的事情,那么将您的代码包装在 tf.py_function 中是您应该做的。请记住,这将导致性能不佳。请参阅此处的文档和示例:

https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map

另一方面,如果您正在做一些通用的事情,那么您不需要将代码包装在 py_function 中,而是使用 tf.strings 模块中提供的任何方法。这些方法适用于字符串张量,并提供了许多常用方法,如 split、join、len 等。这些方法不会对性能产生负面影响,它们将直接作用于张量并返回修改后的张量。

在此处查看 tf.strings 的文档:https://www.tensorflow.org/api_docs/python/tf/strings

例如,假设您想从文件名中提取标签名称,然后您可以编写如下代码:

ds.map(lambda x: tf.strings.split(x, sep='$')[1])

以上假设标签由 $ 分隔。