我有以下Tensorflow代码:
import datetime
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import tensorflow as tf
image_width = 202
image_height = 180
num_channels = 3
filenames = tf.train.match_filenames_once("./train/Resized/*.jpg")
def label(label_string):
if label_string == 'cat': label = [1,0]
if label_string == 'dog': label = [0,1]
return label
def read_image(filename_queue):
image_reader = tf.WholeFileReader()
key, image_filename = image_reader.read(filename_queue)
image = tf.image.decode_jpeg(image_filename)
image.set_shape((image_height, image_width, 3))
name = os.path.basename(image_filename) # example "dog.2148.jpg"
s = name.split('.')
label_string = s[0]
label = label(label_string)
return image, label
def input_pipeline(filenames, batch_size, num_epochs=None):
filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)
image, label = read_image(filename_queue)
min_after_dequeue = 1000
capacity = min_after_dequeue + 3 * batch_size
image_batch, label_batch = tf.train.shuffle_batch(
[image, label], batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return image_batch, label_batch
image_batch, label_batch = input_pipeline(filenames, 10)
最后一条语句失败,出现以下错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-21-0224ec735c33> in <module>()
----> 1 image_batch, label_batch = input_pipeline(filenames, 10)
<ipython-input-20-277e29dc1ae3> in input_pipeline(filenames, batch_size, num_epochs)
1 def input_pipeline(filenames, batch_size, num_epochs=None):
2 filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)
----> 3 image, label = read_image(filename_queue)
4 min_after_dequeue = 1000
5 capacity = min_after_dequeue + 3 * batch_size
<ipython-input-19-ffe4ec8c3e25> in read_image(filename_queue)
5 image.set_shape((image_height, image_width, 3))
6
----> 7 name = os.path.basename(image_filename) # example "dog.2148.jpg"
8 s = name.split('.')
9 label_string = s[0]
C:\local\Anaconda3-4.1.1-Windows-x86_64\envs\cntk-py35\lib\ntpath.py in basename(p)
230 def basename(p):
231 """Returns the final component of a pathname"""
--> 232 return split(p)[1]
233
234
C:\local\Anaconda3-4.1.1-Windows-x86_64\envs\cntk-py35\lib\ntpath.py in split(p)
202
203 seps = _get_bothseps(p)
--> 204 d, p = splitdrive(p)
205 # set i to index beyond p's last slash
206 i = len(p)
C:\local\Anaconda3-4.1.1-Windows-x86_64\envs\cntk-py35\lib\ntpath.py in splitdrive(p)
137
138 """
--> 139 if len(p) >= 2:
140 if isinstance(p, bytes):
141 sep = b'\\'
TypeError: object of type 'Tensor' has no len()
我认为问题与Tensor数据类型与字符串数据类型有关。我怎样才能正确地向os.path.basename函数指出image_filename是一个字符串?
答案 0 :(得分:0)
问题是match_filenames_once返回
初始化为匹配模式的文件列表的变量。
(见这里:https://www.tensorflow.org/api_docs/python/tf/train/match_filenames_once)。
os.path.basename和string.split是对字符串有效的函数,而不是在张量上。
我建议你做的是将图像加载到张量流管道之外,这样我觉得标签更容易。