我的TensorFlow应用程序的错误,我似乎无法找到它

时间:2018-06-09 03:30:27

标签: python tensorflow

我对TensorFlow很新。我已经从教程中吸收了我所能做到的事情,但我却陷入了制造自己的困境。我通过做而不是阅读来学习,所以我希望能有更好的东西,我可以更好地理解TensorFlow。

我收到此错误:“标签的形状(收到(32,32))应该等于logits的形状,除了最后一个维度(收到(1,32,32))。”我摆弄了代码无济于事。

该程序应该读取图像并吐出“掩码”图像,稍后将用于剪切背景。白色像素表示“保持此”,黑色像素表示“不保留此”。

pm.main是一个文件,其中包含一些用于加载图像的函数,另一个用于创建模型。我试图从train.py(用于训练)和eval.py(用于运行该东西)中提取这些东西以保持我的代码更清洁。

“pm.main”:

import tensorflow as tf
import matplotlib.image as mpimg
from PIL import Image as img
import numpy as np
import os

def rgb2gray(rgb):
    #return rgb
    gray = np.dot(rgb[...,:3], [0.3, 0.3, 0.3]).astype('int32')
    return gray

def load_images(dataDirectory):
    images = []
    for i in range(1, 2):
        filePrefix = f'{i:03}'
        fileName = '%s.png' % filePrefix
        maskFileName = '%sm.png' % filePrefix
        image = mpimg.imread(os.path.join(dataDirectory, fileName))
        maskImg = img.open(os.path.join(dataDirectory, maskFileName))
        maskImg.thumbnail((32, 32), img.ANTIALIAS)
        mask = np.array(maskImg)
        fixedMask = rgb2gray(mask)
        images.append([image, fixedMask])
    return images

class x():
    pass

def cnn_model_fn(input):
    input_layer = tf.reshape(input, [-1, 256, 256, 3])

    def addNewLayer(prevLayer):
        conv = tf.layers.conv2d(inputs=prevLayer, filters=32, kernel_size=[3,3], padding='same', activation=tf.nn.relu)
        pool = tf.layers.max_pooling2d(inputs=conv, pool_size=[2, 2], strides=2)
        return pool

    layer1 = addNewLayer(input_layer)
    layer2 = addNewLayer(layer1)
    layer3 = addNewLayer(layer2)
    # layer3 = 64 x 64 x 32

    #return layer3
    #flat = tf.reshape(layer3, [-1, 32 * 32 * 32])
    #dense = tf.layers.dense(inputs=flat, units=64 * 64, activation=tf.nn.relu)
    #dense = tf.layers.dense(inputs=flat, units=64, activation=tf.nn.relu)
    dense = tf.layers.dense(inputs=layer3, units=1, activation=tf.nn.relu)

    result = x()
    result.logits = dense
    #result.logits = flat
    #flat = tf.reshape(layer3, [-1, 32, 32, 1])
    #result.logits = flat
    #print('layer3: %s' % layer3.shape)

    print('logits: %s' % result.logits.shape)
    return result

和“train.py”:

import tensorflow as tf
from pm import main as pm
from PIL import Image as img
import matplotlib.pyplot as pyplot
import numpy as np
import os

currentDirectory = os.path.dirname(os.path.realpath(__file__))
dataDirectory = os.path.join(currentDirectory, 'data')

images = pm.load_images(dataDirectory)

imagePlaceholder = tf.placeholder(tf.int32, shape=[256, 256, 3])
#maskPlaceholder = tf.placeholder(tf.int32, shape=[1, 256, 256, 1])
maskPlaceholder = tf.placeholder(tf.int32, shape=[32, 32, 1])

pair = images[0]
image = pair[0] # inputs
mask = pair[1] # labels

model = pm.cnn_model_fn(image)
logits = model.logits

print(mask.shape)
print(logits.shape)

saver = tf.train.Saver()

sess = tf.Session()
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())
#saver.restore(sess, 'network.ckpt')


loss_op = tf.losses.sparse_softmax_cross_entropy(labels=mask, logits=logits)

result = sess.run(loss_op, feed_dict={ imagePlaceholder: image, maskPlaceholder: logits })
print(result)

resultData = result.reshape(64, 64).astype('uint8') * 255
print(resultData)

imageData = img.fromarray(resultData)
imageData.save('output.png')

saver.save(sess, 'network.ckpt')

非常感谢任何见解。

1 个答案:

答案 0 :(得分:0)

您没有提供错误堆栈跟踪,但这是我最好的猜测。

最有可能的错误来自交叉熵损失:link = "http://urbantoronto.ca/database/" driver = webdriver.Chrome() driver.get(link) wait = WebDriverWait(driver, 10) # For readability condition = (By.CSS_SELECTOR, "#project_list table tr[id^='project']") tr = wait.until(EC.presence_of_all_elements_located(condition)) # Get links, this will take a few seconds with Selenium selector = "a[href^='//urbantoronto']" links = [x.find_element_by_css_selector(selector).get_attribute('href') for x in tr] for nlink in links: driver.get(nlink) sitem = wait.until(EC.presence_of_element_located((By.CSS_SELECTOR, "h1.title"))) title = sitem.text try: desc = wait.until(EC.presence_of_element_located((By.CSS_SELECTOR, ".project-description p"))).text except Exception: desc = "" print("Title: {}\nDescription: {}\n".format(title,desc)) driver.quit() 。如果您查看spec,您会看到此函数希望tf.losses.sparse_softmax_cross_entropylabels小1等级,并且所有尺寸预期最后都必须相同(以便它知道哪个标签对应到哪一组logits。)

在您的情况下,标签似乎具有logits的形状。假设你只想摆脱(1, 32, 32),你可以1