TensorFlow RNN 的损失(cost)居高不下

时间:2017-08-14 17:58:22

标签: machine-learning tensorflow computer-vision rnn

以下 RNN 模型的损失只在前一两个 epoch 中有所下降，之后便一直在 6 左右波动，看起来模型的输出近乎随机，根本没有在学习。我把学习率从 0.1 调到 0.0001 也无济于事。数据由输入管道提供，该管道在其他模型上工作正常，所以这里没有贴出提取标签和图像的函数。我已经反复检查了很多遍，仍然找不到问题所在。代码如下:

n_steps = 224  # time steps fed to the RNN (image rows)
n_inputs = 224  # features per time step (image columns)
learning_rate = 0.00015
batch_size = 256  # NOTE: RNN() also uses this as the hidden size (n_neurons) by default
epochs = 100
# Number of full batches per epoch (integer division; same as int(a / b) for sizes >= 0).
num_batch = len(trainnames) // batch_size
keep_prob = tf.placeholder(tf.float32)

# TRAIN QUEUE
# Holds (filename, label-vector) pairs; the label vector has length num_labels.
# `capacity` is an integer attribute of the queue op, so cast explicitly rather than
# relying on TF to truncate the float 1.5 * N. The extra 0.5 * N of headroom is
# presumably there so the next epoch's enqueue_many does not block on leftovers.
train_queue = tf.RandomShuffleQueue(
    int(len(trainnames) * 1.5),
    0,
    [tf.string, tf.float32],
    shapes=[[], [num_labels]],
)

enqueue_train = train_queue.enqueue_many([trainnames, train_label])

train_image, train_image_label = train_queue.dequeue()

# read_image_file is defined elsewhere; its result must fit the
# [n_steps, n_inputs] shape declared for the batch below.
train_image = read_image_file(train_image)

train_batch, train_label_batch = tf.train.batch(
    [train_image, train_image_label],
    batch_size=batch_size,
    num_threads=1,
    capacity=10 * batch_size,
    enqueue_many=False,
    # Reuse the hyperparameters instead of repeating the literal 224.
    shapes=[[n_steps, n_inputs], [num_labels]],
    allow_smaller_final_batch=True,
)

train_close = train_queue.close()



def RNN(inputs, reuse, num_units=None):
    """Build a single-layer vanilla RNN classifier and return its logits.

    Args:
        inputs: float32 tensor passed straight to ``tf.nn.dynamic_rnn``
            (batch-major by default, so presumably ``(batch, n_steps, n_inputs)``
            -- TODO confirm against ``read_image_file``).
        reuse: ``None`` on the first call; ``True`` to share every variable
            (RNN cell and output layer) with an earlier call, e.g. an eval graph.
        num_units: hidden size of the RNN cell. Defaults to the module-level
            ``batch_size``, which is what the original code used.

    Returns:
        Logits tensor with ``num_labels`` columns (no activation applied).

    NOTE(review): if the loss still plateaus, two things are worth checking.
    First, a plain tanh RNN unrolled over ~224 steps is prone to vanishing
    gradients. Second, if the images arrive as raw 0-255 pixels, the tanh units
    will saturate. Normalising the inputs and/or switching to an LSTM/GRU cell
    are the usual remedies. This is unverified for this data.
    """
    if num_units is None:
        num_units = batch_size

    # Constructing the cell creates no variables; they are created when
    # dynamic_rnn calls it below.
    with tf.variable_scope('cells', reuse=reuse):
        basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=num_units, reuse=reuse)

    # This is the scope in which the cell's variables are actually created, so
    # it must honour `reuse` (the original ignored it here).
    with tf.variable_scope('rnn', reuse=reuse):
        _, states = tf.nn.dynamic_rnn(basic_cell, inputs, dtype=tf.float32)

    fc_drop = tf.nn.dropout(states, keep_prob)

    # fully_connected only reuses weights when given both `reuse` and an
    # explicit `scope`; 'fully_connected' is its default scope name, so the
    # first call creates exactly the same variable names as before.
    logits = tf.contrib.layers.fully_connected(
        fc_drop,
        num_labels,
        activation_fn=None,
        reuse=reuse,
        scope='fully_connected',
    )

    return logits

# Training graph: mean softmax cross-entropy over the batch, minimised with
# momentum SGD (momentum 0.9). All ops live under the "cost_function" name scope.
with tf.name_scope("cost_function"):
    train_logits = RNN(train_batch, reuse=None)
    per_example_xent = tf.nn.softmax_cross_entropy_with_logits(
        labels=train_label_batch,
        logits=train_logits,
    )
    cost = tf.reduce_mean(per_example_xent)
    optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)
    train_step = optimizer.minimize(cost)


# Scalar summary of the batch cost plus a writer targeting `logdir` (TensorBoard).
cost_summary = tf.summary.scalar("cost_function", cost)
file_writer = tf.summary.FileWriter(logdir)

#Session
# Runs the training loop: refill the filename queue each epoch, step through
# num_batch batches, and log the cost summary every 100 steps.
with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord, start=True)

    try:
        step = 0
        for _ in range(epochs):
            # Push every training example into the shuffle queue for this epoch.
            sess.run(enqueue_train)
            for _ in range(num_batch):
                if step % 100 == 0:
                    # Each sess.run dequeues a fresh batch. The original evaluated
                    # ONLY the summary here, which consumed a batch without training
                    # on it and skipped every 100th update. Fetching both in one run
                    # trains on the batch and logs its (pre-update, dropout-on) cost.
                    _, summary_str = sess.run(
                        [train_step, cost_summary], feed_dict={keep_prob: 0.5})
                    file_writer.add_summary(summary_str, step)
                else:
                    sess.run(train_step, feed_dict={keep_prob: 0.5})
                step += 1
        sess.run(train_close)
    finally:
        # Always stop the queue-runner threads and flush the event file,
        # even if training raises.
        coord.request_stop()
        coord.join(threads)
        file_writer.close()

0 个答案:

没有答案