我对TensorFlow很新。我已经从教程中吸收了我所能做到的事情,但我却陷入了制造自己的困境。我通过做而不是阅读来学习,所以我希望能有更好的东西,我可以更好地理解TensorFlow。
我收到此错误:“标签的形状(收到(32,32))应该等于logits的形状,除了最后一个维度(收到(1,32,32))。”我摆弄了代码无济于事。
该程序应该读取图像并吐出“掩码”图像,稍后将用于剪切背景。白色像素表示“保持此”,黑色像素表示“不保留此”。
pm.main是一个文件,其中包含一些用于加载图像的函数,另一个用于创建模型。我试图从train.py(用于训练)和eval.py(用于运行该东西)中提取这些东西以保持我的代码更清洁。
“pm.main”:
import tensorflow as tf
import matplotlib.image as mpimg
from PIL import Image as img
import numpy as np
import os
def rgb2gray(rgb):
#return rgb
gray = np.dot(rgb[...,:3], [0.3, 0.3, 0.3]).astype('int32')
return gray
def load_images(dataDirectory):
images = []
for i in range(1, 2):
filePrefix = f'{i:03}'
fileName = '%s.png' % filePrefix
maskFileName = '%sm.png' % filePrefix
image = mpimg.imread(os.path.join(dataDirectory, fileName))
maskImg = img.open(os.path.join(dataDirectory, maskFileName))
maskImg.thumbnail((32, 32), img.ANTIALIAS)
mask = np.array(maskImg)
fixedMask = rgb2gray(mask)
images.append([image, fixedMask])
return images
class x():
pass
def cnn_model_fn(input):
input_layer = tf.reshape(input, [-1, 256, 256, 3])
def addNewLayer(prevLayer):
conv = tf.layers.conv2d(inputs=prevLayer, filters=32, kernel_size=[3,3], padding='same', activation=tf.nn.relu)
pool = tf.layers.max_pooling2d(inputs=conv, pool_size=[2, 2], strides=2)
return pool
layer1 = addNewLayer(input_layer)
layer2 = addNewLayer(layer1)
layer3 = addNewLayer(layer2)
# layer3 = 64 x 64 x 32
#return layer3
#flat = tf.reshape(layer3, [-1, 32 * 32 * 32])
#dense = tf.layers.dense(inputs=flat, units=64 * 64, activation=tf.nn.relu)
#dense = tf.layers.dense(inputs=flat, units=64, activation=tf.nn.relu)
dense = tf.layers.dense(inputs=layer3, units=1, activation=tf.nn.relu)
result = x()
result.logits = dense
#result.logits = flat
#flat = tf.reshape(layer3, [-1, 32, 32, 1])
#result.logits = flat
#print('layer3: %s' % layer3.shape)
print('logits: %s' % result.logits.shape)
return result
和“train.py”:
import tensorflow as tf
from pm import main as pm
from PIL import Image as img
import matplotlib.pyplot as pyplot
import numpy as np
import os
currentDirectory = os.path.dirname(os.path.realpath(__file__))
dataDirectory = os.path.join(currentDirectory, 'data')
images = pm.load_images(dataDirectory)
imagePlaceholder = tf.placeholder(tf.int32, shape=[256, 256, 3])
#maskPlaceholder = tf.placeholder(tf.int32, shape=[1, 256, 256, 1])
maskPlaceholder = tf.placeholder(tf.int32, shape=[32, 32, 1])
pair = images[0]
image = pair[0] # inputs
mask = pair[1] # labels
model = pm.cnn_model_fn(image)
logits = model.logits
print(mask.shape)
print(logits.shape)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())
#saver.restore(sess, 'network.ckpt')
loss_op = tf.losses.sparse_softmax_cross_entropy(labels=mask, logits=logits)
result = sess.run(loss_op, feed_dict={ imagePlaceholder: image, maskPlaceholder: logits })
print(result)
resultData = result.reshape(64, 64).astype('uint8') * 255
print(resultData)
imageData = img.fromarray(resultData)
imageData.save('output.png')
saver.save(sess, 'network.ckpt')
非常感谢任何见解。
答案 0 :(得分:0)
您没有提供错误堆栈跟踪,但这是我最好的猜测。
最有可能的错误来自交叉熵损失:link = "http://urbantoronto.ca/database/"
driver = webdriver.Chrome()
driver.get(link)
wait = WebDriverWait(driver, 10)
# For readability
condition = (By.CSS_SELECTOR, "#project_list table tr[id^='project']")
tr = wait.until(EC.presence_of_all_elements_located(condition))
# Get links, this will take a few seconds with Selenium
selector = "a[href^='//urbantoronto']"
links = [x.find_element_by_css_selector(selector).get_attribute('href') for x in tr]
for nlink in links:
driver.get(nlink)
sitem = wait.until(EC.presence_of_element_located((By.CSS_SELECTOR, "h1.title")))
title = sitem.text
try:
desc = wait.until(EC.presence_of_element_located((By.CSS_SELECTOR, ".project-description p"))).text
except Exception: desc = ""
print("Title: {}\nDescription: {}\n".format(title,desc))
driver.quit()
。如果您查看spec,您会看到此函数希望tf.losses.sparse_softmax_cross_entropy
比labels
小1等级,并且所有尺寸预期最后都必须相同(以便它知道哪个标签对应到哪一组logits。)
在您的情况下,标签似乎具有logits
的形状。假设你只想摆脱(1, 32, 32)
,你可以1