Tensorflow奇怪的标签尺寸

时间:2016-07-28 23:09:55

标签: tensorflow

我已经构建了一个tensorflow转换网并将其保存到检查点文件中。在eval程序中解压缩后,我收到一个错误:targets [0]超出范围

但我设法通过打印标签和前向传播结果来调试程序。打印后我发现标签的尺寸为[0,40],与前向传播结果相比,这不应该发生[1,4]。我绝对肯定数据没有被篡改或损坏。它与cifar 10的文件格式相同。网络似乎很好。我认为这是我拯救他们的方式。我没有放入检查点文件或其中任何一个。

这是该计划主要部分的代码:

def train():
    with tf.Session() as sess:
        images, labels = Process.inputs()

        forward_propgation_results = Process.forward_propagation(images)

        train_loss, cost = Process.error(forward_propgation_results, labels)

        image_summary_t = tf.image_summary(images.name, images, max_images=1)

        summary_op = tf.merge_all_summaries()

        init = tf.initialize_all_variables()

        saver = tf.train.Saver()

        sess.run(init)

        saver = tf.train.Saver(tf.all_variables())

        tf.train.start_queue_runners(sess = sess)

        train_dir = "/Users/Zanhuang/Desktop/NNP/model.ckpt"

        summary_writer = tf.train.SummaryWriter(train_dir, sess.graph)

        for step in range(50):
            start_time = time.time()
            print(sess.run([train_loss, cost]))
            duration = time.time() - start_time
            if step % 1 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = ('%s: step %d, (%.1f examples/sec; %.3f ''sec/batch)')
                print (format_str % (datetime.now(), step, examples_per_sec, sec_per_batch))

                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

                if step % 30 == 0:
                    checkpoint_path = train_dir
                    saver.save(sess, checkpoint_path, global_step=step)


def main(argv = None):
    train()

if __name__ == '__main__':
  tf.app.run()

另外:在该计划的早期阶段,我对一个标签进行了热门编码并且没有问题进行过培训。所以这不应该出现。

更新

我在标签大小上运行了一个for循环50次迭代:

我收到了:

[40]
[120]
[120]
[92]
[48]
[92]
[120]
[120]
[92]
[48]
[92]
[120]
[92]
[92]
[48]
[100]
[120]
[48]
[48]
[120]
[48]
[92]
[48]
[48]
[48]
[48]
[120]
[120]
[48]
[48]
[48]
[48]
[92]
[120]
[48]
[48]
[92]
[120]
[48]
[92]
[48]
[48]
[120]
[48]
[120]
[48]
[92]
[92]
[120]
[48]

以下是我如何生成培训数据:

import numpy as np
import cPickle as pk
from PIL import Image
import os

def loadImage(filename):
    return Image.open(filename)


def pickle_data(filename, data, mode='wb'):
    with open(filename, mode) as file:
        pk.dump(data, file)


def unpickle_data(filename, mode = 'rb'):
    with open(filename, mode) as file:
        data = pk.load(file)
    return data


def loadAllPic():
    dict = {}
    imgdata = []
    imglabel = []

    for file in glob.glob('*.jpg'):
        print file
        img = loadImage(file)
        rawdata = img.load()
        redchannel = [rawdata[x, y][0] for x in range(img.width) for y in range(img.height)]
        greenchannel = [rawdata[x, y][1] for x in range(img.width) for y in range(img.height)]
        bluechannel = [rawdata[x, y][2] for x in range(img.width) for y in range(img.height)]
        nparray = np.array(redchannel + greenchannel + bluechannel)
        imgdata.append(nparray)
        imglabel.append("Gleason_4")
        dict['data'] = imgdata
        dict['labels'] = imglabel
        return dict


def main():
    dict = loadAllPic()
    pickle_data('Prostate_Cancer_Data2.binary', dict)

    data = unpickle_data('Prostate_Cancer_Data2.binary')
    print data.viewkeys()
    print data['labels'][:100]
    print data['data'][0]



if __name__ == '__main__':
    main()

以上是针对单个标签的。有四个程序。每个都为不同的标签定制。

0 个答案:

没有答案