Tensorflow eval挂起

时间:2017-05-29 09:13:21

标签: python csv tensorflow

我是tensorflow的新手,目前正在tensorflow中编写一个对象检测代码:

所以这就是我正在做的事情

1)读取具有图像名称,标签/类,边界框详细信息的.csv文件。 (read_from_csv)。

2)将标签从csv转换为onehot编码。(to_onehot)在此函数中,label.eval(session = sess)会导致系统挂起。

我无法理解这个问题。我写的代码如下。请帮忙

global data_classes
data_classes = ["class1", "class2", "class3"]

#Model-Parameters
batch_size = 32
image_width = 640
image_height = 480
num_channels = 3
num_iters = 200000
validation_ratio = 0.1
num_classes = 7
learning_rate = 0.001
lr_deacay = 0.9


def num_img(csv_name):
    with open(csv_name) as f:
        for i, l in enumerate(f):
            pass
    return int(i + 1)


def read_from_csv(filename_queue):
    reader = tf.TextLineReader(skip_header_lines=1)
    _, csv_row = reader.read(filename_queue)
    #train, test = dataset_split(csv_row)
    record_defaults = [[' '],[' '],[' '],[' '],[' '],[' '], [' '], [' '], [' '], [' ']]
    col_Image, col_label, col_xmin, col_ymin, col_xmax, col_ymax, misc1, misc2, misc3, misc4 = tf.decode_csv(csv_row, record_defaults=record_defaults)
    wd=getcwd()
    /* Some Processing for Images */
    # stacked values should be of same datatype
    label = tf.stack([col_label])
    onehot = to_onehot(label)
    print('image',image.get_shape())
    print('label', label.get_shape())
    return image, label


def input_pipeline(batch_size, num_epochs):
    filename_queue = tf.train.string_input_producer([args.dataset], num_epochs=num_epochs, shuffle=True) 
    image, label = read_from_csv(filename_queue)
    min_after_dequeue = 100
    capacity = min_after_dequeue + 3 * batch_size
    image_batch, label_batch = tf.train.shuffle_batch([image, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue)
    print('image',label_batch.get_shape())
    return image_batch, onehot 

def to_onehot(label,nclasses = num_classes):
    assert tf.get_default_session() is sess
    global data_classes
    print('to_onehot', label)
    **l = label.eval(session=sess) /* This eval is giving error */**
    print(l)
    indices = data_classes.index(l[0]) 
    print(indices)
    outlabels = tf.zeros([label.get_shape().as_list()[0],nclasses])
    /* Some Processing */
    return outlabels



with tf.device('/cpu:0'):
    with sess.as_default():
        init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())
        sess.run(init_op)
        global file_length
    file_length = num_img(args.dataset) - 1
        print(file_length)
        tf.report_uninitialized_variables(name='uninitialized_variable')
        images, labels = input_pipeline(batch_size, 1)


    # start populating filename queue
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            while not coord.should_stop():
                image_batch, label_batch = sess.run([images, labels])

        print(label_batch)
        except tf.errors.OutOfRangeError:
            print('Done training, epoch reached')
        finally:
            coord.request_stop()

    coord.join(threads)


    print("Network Architecture")
/*Some Processing */

1 个答案:

答案 0 :(得分:-1)

您可能想通过调用tf.train.start_queue_runners(sess)开始训练队列