Toy example of training on my own dataset in TensorFlow

Date: 2016-10-31 06:24:48

Tags: tensorflow

I am trying to create a toy example that trains a small network on my own images. The network is the same as the one in https://www.tensorflow.org/versions/r0.11/tutorials/mnist/pros/index.html

Here is the code:

import tensorflow as tf


sess = tf.InteractiveSession() 

#####reading images

filenames=['images/000001.jpg','images/000002.jpg','images/000003.jpg','images/000004.jpg']
labels=[[1.,0.,0.,0., 0., 0., 0., 0., 0., 0.], [1.,0.,0.,0., 0., 0., 0., 0., 0., 0.], [1.,0.,0.,0., 0., 0., 0., 0., 0., 0.], [1.,0.,0.,0., 0., 0., 0., 0., 0., 0.]]

filename_queue=tf.train.string_input_producer(filenames)

reader=tf.WholeFileReader()
filename, content = reader.read(filename_queue)
images=tf.image.decode_jpeg(content, channels=3)

images=tf.cast(images, tf.float32)
resized_images=tf.image.resize_images(images, [28, 28])

image_batch, label_batch=tf.train.batch([resized_images, labels], batch_size=2)

######weights and network

dense_w={
    "w_conv1": tf.Variable(tf.truncated_normal([5,5,1,32],stddev=0.1), name="w_conv1"),
    "b_conv1": tf.Variable(tf.constant(0.1,shape=[32]), name="b_conv1"),
    "w_conv2": tf.Variable(tf.truncated_normal([5,5,32,64],stddev=0.1), name="w_conv2"),
    "b_conv2": tf.Variable(tf.constant(0.1,shape=[64]), name="b_conv2"),
    "w_fc1": tf.Variable(tf.truncated_normal([7*7*64,1024],stddev=0.1), name="w_fc1"),
    "b_fc1": tf.Variable(tf.constant(0.1,shape=[1024]), name="b_fc1"),
    "w_fc2": tf.Variable(tf.truncated_normal([1024,10],stddev=0.1), name="w_fc2"),
    "b_fc2": tf.Variable(tf.constant(0.1,shape=[10]), name="b_fc2")
}



def dense_cnn_model(weights):
    def conv2d(x, W):
        return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

    def max_pool_2x2(x):
        return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                              strides=[1, 2, 2, 1], padding='SAME')

    x_image = tf.reshape(x, [-1,28,28,1])
    h_conv1 = tf.nn.relu(conv2d(x_image, weights["w_conv1"]) + weights["b_conv1"])
    h_pool1 = max_pool_2x2(h_conv1)
    h_conv2 = tf.nn.relu(conv2d(h_pool1, weights["w_conv2"]) + weights["b_conv2"])
    h_pool2 = max_pool_2x2(h_conv2)
    h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, weights["w_fc1"]) + weights["b_fc1"])
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
    y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, weights["w_fc2"]) + weights["b_fc2"])
    return y_conv

# Construct a dense model
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
keep_prob = tf.placeholder("float")

y_conv = dense_cnn_model(dense_w)

####################### Training


cross_entropy = -tf.reduce_sum(y_*tf.log(tf.clip_by_value(y_conv,1e-10,1.0)))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.initialize_all_variables())
image_batch_eval, label_batch_eval=image_batch.eval(), label_batch.eval()
train_step.run(feed_dict={x: image_batch_eval, y_: label_batch_eval, keep_prob: 0.5})

The program freezes at the batch evaluation on the second-to-last line (as if it enters an infinite loop). What is wrong with this evaluation?

1 Answer:

Answer 0 (score: 0)

filename_queue=tf.train.string_input_producer(filenames) builds a queue from your filenames, so you need to start the queue before calling:

image_batch_eval, label_batch_eval=image_batch.eval(), label_batch.eval()

Before it, add:

tf.train.start_queue_runners(sess=sess)

You can find more details about queues here.
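
For completeness, here is a minimal sketch of how this fix slots into the end of the script above (assuming the TF 0.11-era queue API; image_batch, label_batch, and sess come from the question's code, and the Coordinator is optional but lets the runner threads be stopped cleanly):

sess.run(tf.initialize_all_variables())

# Start the threads that feed the filename queue built by
# tf.train.string_input_producer; without them, image_batch.eval()
# blocks forever waiting for data to be enqueued.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# The batch tensors can now be evaluated without hanging.
image_batch_eval, label_batch_eval = sess.run([image_batch, label_batch])

# When done, ask the queue-runner threads to stop and wait for them.
coord.request_stop()
coord.join(threads)

Evaluating both tensors in a single sess.run call also keeps the image and label batches paired; two separate .eval() calls would each trigger their own dequeue.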