我正在尝试为我的训练数据集图像创建tf.data.Dataset
对象,但出现错误。数据集有5个类别,每个类别都有大约50张图像。目录结构如下:
directory dataset:
class1 class2 class3 class4 class5
抛出错误的代码片段为:
>>> def get_label(file_path):
... # convert the path to the list of path components
... parts = tf.strings.split(file_path, os.path.sep)
... # the second to the last is the class-directory
... return parts[-2]
>>> list_ds = tf.data.Dataset.list_files('./dataset/*/*.jpg'))
>>> for f in list_ds.take(1):
... print(get_label(f))
在行parts = tf.strings.split(file_path, os.path.sep)
上引发以下错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError: input must be a vector, got shape: [] [Op:StringSplitV2]
我正在使用的 tf版本是2.0
。我正在遵循tensorflow文档示例,但仍然出现此错误。我是否将错误的对象传递给tf.strings.split
?