Tensor data type to string?

Date: 2017-02-27 13:45:29

Tags: tensorflow

I have the following TensorFlow code:

import datetime
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import tensorflow as tf

image_width  = 202
image_height = 180
num_channels = 3

filenames = tf.train.match_filenames_once("./train/Resized/*.jpg")

def label(label_string):
    if label_string == 'cat': label = [1,0]
    if label_string == 'dog': label = [0,1]

    return label

def read_image(filename_queue):
    image_reader = tf.WholeFileReader()
    key, image_filename = image_reader.read(filename_queue)
    image = tf.image.decode_jpeg(image_filename)
    image.set_shape((image_height, image_width, 3))

    name = os.path.basename(image_filename) # example "dog.2148.jpg"
    s = name.split('.')
    label_string = s[0]
    label = label(label_string)

    return image, label

def input_pipeline(filenames, batch_size, num_epochs=None):
    filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)
    image, label = read_image(filename_queue)
    min_after_dequeue = 1000
    capacity = min_after_dequeue + 3 * batch_size
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    return image_batch, label_batch

image_batch, label_batch = input_pipeline(filenames, 10)

The last statement fails with the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-21-0224ec735c33> in <module>()
----> 1 image_batch, label_batch = input_pipeline(filenames, 10)

<ipython-input-20-277e29dc1ae3> in input_pipeline(filenames, batch_size, num_epochs)
      1 def input_pipeline(filenames, batch_size, num_epochs=None):
      2     filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)
----> 3     image, label = read_image(filename_queue)
      4     min_after_dequeue = 1000
      5     capacity = min_after_dequeue + 3 * batch_size

<ipython-input-19-ffe4ec8c3e25> in read_image(filename_queue)
      5     image.set_shape((image_height, image_width, 3))
      6 
----> 7     name = os.path.basename(image_filename) # example "dog.2148.jpg"
      8     s = name.split('.')
      9     label_string = s[0]

C:\local\Anaconda3-4.1.1-Windows-x86_64\envs\cntk-py35\lib\ntpath.py in basename(p)
    230 def basename(p):
    231     """Returns the final component of a pathname"""
--> 232     return split(p)[1]
    233 
    234 

C:\local\Anaconda3-4.1.1-Windows-x86_64\envs\cntk-py35\lib\ntpath.py in split(p)
    202 
    203     seps = _get_bothseps(p)
--> 204     d, p = splitdrive(p)
    205     # set i to index beyond p's last slash
    206     i = len(p)

C:\local\Anaconda3-4.1.1-Windows-x86_64\envs\cntk-py35\lib\ntpath.py in splitdrive(p)
    137 
    138     """
--> 139     if len(p) >= 2:
    140         if isinstance(p, bytes):
    141             sep = b'\\'

TypeError: object of type 'Tensor' has no len()

I think the problem is the Tensor data type versus the string data type. How can I correctly tell os.path.basename that image_filename is a string?

1 answer:

Answer 0 (score: 0)

The problem is that match_filenames_once returns

  a variable initialized to the list of files matching the pattern.

(See: https://www.tensorflow.org/api_docs/python/tf/train/match_filenames_once.)

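os.path.basename and str.split are functions that work on Python strings, not on tensors. Everything that flows through the queue, including the image_filename returned by image_reader.read, is a symbolic Tensor that has no value while the graph is being built, which is why ntpath.splitdrive's call to len() fails with the TypeError above.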
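For reference, the parsing logic itself is fine once it receives a real string. A minimal sketch, assuming file names of the form `dog.2148.jpg` as in the question; `label_from_path` and `LABELS` are hypothetical names. It accepts bytes as well as str because, under Python 3, evaluating a string tensor in a session returns bytes. If you do evaluate the reader's output, use the first value: `WholeFileReader.read` returns a (key, value) pair in which key is the file name and value is the file contents, so the variable named `image_filename` in the question actually holds the JPEG data. The helper is also deliberately not named `label`: inside `read_image`, the assignment `label = label(label_string)` makes `label` a local name, so that call would raise UnboundLocalError.

```python
import os

# Hypothetical mapping from file-name prefix to the one-hot labels
# used in the question: cat -> [1, 0], dog -> [0, 1].
LABELS = {"cat": [1, 0], "dog": [0, 1]}


def label_from_path(path):
    """Return the one-hot label for a path like './train/Resized/dog.2148.jpg'.

    Accepts str or bytes: evaluating a string tensor in a session
    yields bytes under Python 3, so bytes are decoded first.
    """
    if isinstance(path, bytes):
        path = path.decode("utf-8")
    name = os.path.basename(path)   # e.g. "dog.2148.jpg"
    prefix = name.split(".")[0]     # e.g. "dog"
    if prefix not in LABELS:
        raise ValueError("unknown label prefix: %r" % prefix)
    return list(LABELS[prefix])     # copy so callers cannot mutate LABELS
```

For example, `label_from_path(b"./train/Resized/dog.2148.jpg")` returns `[0, 1]`.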

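What I would suggest is loading the images outside the TensorFlow pipeline; I find that makes the labeling much easier, because you can then use ordinary Python string functions.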
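A minimal sketch of that idea, assuming the `./train/Resized/*.jpg` layout from the question. It moves only the file listing and labeling out of the graph, where os.path.basename and str.split work normally; `list_images_and_labels` and `LABELS` are hypothetical names, and files whose prefix is not in `LABELS` are skipped.

```python
import glob
import os

# Hypothetical prefix-to-label mapping, same one-hot encoding as the question.
LABELS = {"cat": [1, 0], "dog": [0, 1]}


def list_images_and_labels(pattern):
    """Glob image paths in plain Python and compute their labels up front.

    Returns two parallel lists: file paths (str) and one-hot labels.
    Files whose name prefix is not a key of LABELS are skipped.
    """
    paths, labels = [], []
    for path in sorted(glob.glob(pattern)):
        prefix = os.path.basename(path).split(".")[0]  # "dog.2148.jpg" -> "dog"
        if prefix in LABELS:
            paths.append(path)
            labels.append(list(LABELS[prefix]))
    return paths, labels
```

For example, `paths, labels = list_images_and_labels("./train/Resized/*.jpg")`. The two parallel lists could then be handed to TensorFlow together, for instance with `tf.train.slice_input_producer([paths, labels], shuffle=True)` in the TF 1.x queue API, reading each image with `tf.read_file` and `tf.image.decode_jpeg`. Because each label is computed before the graph is built, it stays paired with its file and no string parsing has to happen on tensors.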