我是tensorflow的新手,目前正在tensorflow中编写一个对象检测代码:
所以这就是我正在做的事情
1)读取具有图像名称,标签/类,边界框详细信息的.csv文件。 (read_from_csv)。
2)将标签从csv转换为onehot编码。(to_onehot)在此函数中,label.eval(session = sess)会导致系统挂起。
我无法理解这个问题。我写的代码如下。请帮忙
global data_classes
data_classes = ["class1", "class2", "class3"]
#Model-Parameters
batch_size = 32
image_width = 640
image_height = 480
num_channels = 3
num_iters = 200000
validation_ratio = 0.1
num_classes = 7
learning_rate = 0.001
lr_deacay = 0.9
def num_img(csv_name):
with open(csv_name) as f:
for i, l in enumerate(f):
pass
return int(i + 1)
def read_from_csv(filename_queue):
reader = tf.TextLineReader(skip_header_lines=1)
_, csv_row = reader.read(filename_queue)
#train, test = dataset_split(csv_row)
record_defaults = [[' '],[' '],[' '],[' '],[' '],[' '], [' '], [' '], [' '], [' ']]
col_Image, col_label, col_xmin, col_ymin, col_xmax, col_ymax, misc1, misc2, misc3, misc4 = tf.decode_csv(csv_row, record_defaults=record_defaults)
wd=getcwd()
/* Some Processing for Images */
# stacked values should be of same datatype
label = tf.stack([col_label])
onehot = to_onehot(label)
print('image',image.get_shape())
print('label', label.get_shape())
return image, label
def input_pipeline(batch_size, num_epochs):
filename_queue = tf.train.string_input_producer([args.dataset], num_epochs=num_epochs, shuffle=True)
image, label = read_from_csv(filename_queue)
min_after_dequeue = 100
capacity = min_after_dequeue + 3 * batch_size
image_batch, label_batch = tf.train.shuffle_batch([image, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue)
print('image',label_batch.get_shape())
return image_batch, onehot
def to_onehot(label,nclasses = num_classes):
assert tf.get_default_session() is sess
global data_classes
print('to_onehot', label)
**l = label.eval(session=sess) /* This eval is giving error */**
print(l)
indices = data_classes.index(l[0])
print(indices)
outlabels = tf.zeros([label.get_shape().as_list()[0],nclasses])
/* Some Processing */
return outlabels
with tf.device('/cpu:0'):
with sess.as_default():
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
sess.run(init_op)
global file_length
file_length = num_img(args.dataset) - 1
print(file_length)
tf.report_uninitialized_variables(name='uninitialized_variable')
images, labels = input_pipeline(batch_size, 1)
# start populating filename queue
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
while not coord.should_stop():
image_batch, label_batch = sess.run([images, labels])
print(label_batch)
except tf.errors.OutOfRangeError:
print('Done training, epoch reached')
finally:
coord.request_stop()
coord.join(threads)
print("Network Architecture")
/*Some Processing */
答案 0 :(得分:-1)
您可能想通过调用tf.train.start_queue_runners(sess)开始训练队列