I have built a TensorFlow convolutional net and saved it to a checkpoint file. After restoring it in the eval program, I get an error: targets[0] is out of range.
I managed to debug the program by printing the labels and the forward-propagation results. After printing I found that the labels were of size [0, 40], which should not happen given that the forward-propagation results are [1, 4]. I am absolutely sure the data has not been tampered with or corrupted; it is in the same file format as CIFAR-10. The network itself seems fine, so I think the problem is the way I am saving things. I have not included the checkpoint file or any of that here.
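This is the kind of check I mean (a minimal sketch; it assumes Process.inputs() and Process.forward_propagation() build the same pipeline as in the training code below):

import tensorflow as tf
import Process

with tf.Session() as sess:
    images, labels = Process.inputs()
    logits = Process.forward_propagation(images)
    sess.run(tf.initialize_all_variables())
    tf.train.start_queue_runners(sess=sess)
    # Pull one batch and print the raw label values plus both shapes.
    label_vals, logit_vals = sess.run([labels, logits])
    print(label_vals)          # should be class indices in [0, 3]
    print(label_vals.shape)    # comes out as [0, 40] instead
    print(logit_vals.shape)    # (1, 4): one row of 4 class scores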
Here is the code for the main part of the program:
import time
from datetime import datetime

import tensorflow as tf

import Process  # module that builds the input pipeline and the network

FLAGS = tf.app.flags.FLAGS  # the batch_size flag is defined elsewhere (e.g. in Process)


def train():
    with tf.Session() as sess:
        # Build the graph: input pipeline, forward pass, and loss.
        images, labels = Process.inputs()
        forward_propagation_results = Process.forward_propagation(images)
        train_loss, cost = Process.error(forward_propagation_results, labels)

        image_summary_t = tf.image_summary(images.name, images, max_images=1)
        summary_op = tf.merge_all_summaries()

        init = tf.initialize_all_variables()
        saver = tf.train.Saver(tf.all_variables())

        sess.run(init)
        tf.train.start_queue_runners(sess=sess)

        train_dir = "/Users/Zanhuang/Desktop/NNP/model.ckpt"
        summary_writer = tf.train.SummaryWriter(train_dir, sess.graph)

        for step in range(50):
            start_time = time.time()
            print(sess.run([train_loss, cost]))
            duration = time.time() - start_time

            if step % 1 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = '%s: step %d, (%.1f examples/sec; %.3f sec/batch)'
                print(format_str % (datetime.now(), step, examples_per_sec, sec_per_batch))

                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            if step % 30 == 0:
                checkpoint_path = train_dir
                saver.save(sess, checkpoint_path, global_step=step)


def main(argv=None):
    train()


if __name__ == '__main__':
    tf.app.run()
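In the eval program the restore looks roughly like this (a sketch, with the eval graph itself omitted; since I pass global_step to saver.save(), the files on disk are named model.ckpt-0, model.ckpt-30, and so on):

with tf.Session() as sess:
    # Rebuild the same graph here before restoring (omitted).
    saver = tf.train.Saver()
    saver.restore(sess, "/Users/Zanhuang/Desktop/NNP/model.ckpt-30")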
Also: earlier in the program I one-hot encoded the labels and trained with no problems, so this should not be happening.
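By one-hot encoding I mean the usual thing, e.g. in numpy for the 4 classes (just a sketch of the encoding, not the actual code I used):

import numpy as np

def one_hot(label_index, num_classes=4):
    # Turn a class index in [0, 3] into a length-4 indicator vector.
    vec = np.zeros(num_classes, dtype=np.float32)
    vec[label_index] = 1.0
    return vec

print(one_hot(2))  # [ 0.  0.  1.  0.]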
UPDATE
I ran a for loop over the label size for 50 iterations, and this is what I got:
[40] [120] [120] [92] [48] [92] [120] [120] [92] [48]
[92] [120] [92] [92] [48] [100] [120] [48] [48] [120]
[48] [92] [48] [48] [48] [48] [120] [120] [48] [48]
[48] [48] [92] [120] [48] [48] [92] [120] [48] [92]
[48] [48] [120] [48] [120] [48] [92] [92] [120] [48]
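One thing I noticed while sanity-checking these numbers (this is just my guess): none of them are valid class indices in [0, 3], and decoded as ASCII they look like raw file bytes rather than labels:

for b in [40, 48, 92, 100, 120]:
    print(chr(b))  # prints: ( 0 \ d x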
Here is how I generated the training data:
import glob
import os

import numpy as np
import cPickle as pk
from PIL import Image


def loadImage(filename):
    return Image.open(filename)


def pickle_data(filename, data, mode='wb'):
    with open(filename, mode) as file:
        pk.dump(data, file)


def unpickle_data(filename, mode='rb'):
    with open(filename, mode) as file:
        data = pk.load(file)
    return data


def loadAllPic():
    # Collect every .jpg in the current directory into one dict of
    # flattened pixel arrays plus a parallel list of string labels.
    dict = {}
    imgdata = []
    imglabel = []
    for file in glob.glob('*.jpg'):
        print file
        img = loadImage(file)
        rawdata = img.load()
        # Split the image into per-channel planes (R, then G, then B).
        redchannel = [rawdata[x, y][0] for x in range(img.width) for y in range(img.height)]
        greenchannel = [rawdata[x, y][1] for x in range(img.width) for y in range(img.height)]
        bluechannel = [rawdata[x, y][2] for x in range(img.width) for y in range(img.height)]
        nparray = np.array(redchannel + greenchannel + bluechannel)
        imgdata.append(nparray)
        imglabel.append("Gleason_4")
    dict['data'] = imgdata
    dict['labels'] = imglabel
    return dict


def main():
    dict = loadAllPic()
    pickle_data('Prostate_Cancer_Data2.binary', dict)
    data = unpickle_data('Prostate_Cancer_Data2.binary')
    print data.viewkeys()
    print data['labels'][:100]
    print data['data'][0]


if __name__ == '__main__':
    main()
The above is for a single label; there are four such programs, each tailored to a different label.
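For comparison, the CIFAR-10 binary layout that the reader expects is one unsigned label byte followed by 3072 pixel bytes (the R, G, and B planes of a 32x32 image) per record, with no pickle framing. A sketch of writing that layout directly (the filename and the label mapping here are placeholders I made up):

import numpy as np

def write_cifar_record(f, label_index, rgb_planes):
    # One record: 1 label byte, then 3072 uint8 pixel bytes (assumes 32x32 input).
    f.write(chr(label_index))                        # label index in [0, 3]
    f.write(rgb_planes.astype(np.uint8).tostring())  # R plane, G plane, B plane

with open('prostate_batch.bin', 'wb') as f:
    for nparray in loadAllPic()['data']:
        write_cifar_record(f, 3, nparray)  # e.g. 3 = Gleason_4 in my mapping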