/创建用于训练图像的数据集对象-张量流

时间:2020-02-07 08:09:00

标签: python tensorflow

我正在尝试为我的训练数据集图像创建tf.data.Dataset对象,但出现错误。数据集有5个类别,每个类别都有大约50张图像。目录结构如下:

directory dataset:
class1    class2    class3    class4    class5

抛出错误的代码片段为:

>>> def get_label(file_path):      
...     # convert the path to the list of path components
...     parts = tf.strings.split(file_path, os.path.sep)
...     # the second to the last is the class-directory
...     return parts[-2]

>>> list_ds = tf.data.Dataset.list_files('./dataset/*/*.jpg'))

>>> for f in list_ds.take(1):
...   print(get_label(f))

在行parts = tf.strings.split(file_path, os.path.sep)上引发以下错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: input must be a vector, got shape: [] [Op:StringSplitV2]
我正在使用的

tf版本是2.0。我正在遵循tensorflow文档示例,但仍然出现此错误。我是否将错误的对象传递给tf.strings.split

0 个答案:

没有答案
相关问题