Tensorflow RNN误差为20%

时间:2017-04-06 08:57:39

标签: python-3.x tensorflow


我创建了我的第一个张量流神经元网络,最初用于生成序列。它产生了奇怪的输出所以我简化了很多,看它是否只有5个输入和5个输出类别才能达到0%的错误率。
不知何故它似乎根本没有反向传播,因为它停留在20没有移动的%错误率。所以,如果有人能指出我的错误,我提前谢谢你:)
干杯

import numpy as np
import tensorflow as tf
import sys

trainingInputs = [
[[0],[0],[0],[0]],
[[1],[0],[0],[0]],
[[0],[1],[0],[0]],
[[0],[0],[1],[0]],
[[0],[0],[0],[1]]]
trainingOutputs = [
[1,0,0,0],
[0,1,0,0],
[0,0,1,0],
[0,0,0,1],
[0,0,0,0]]

data = tf.placeholder(tf.float32, [None, len(trainingInputs[0]),1])
target = tf.placeholder(tf.float32, [None, len(trainingOutputs[0])])
num_hidden = 24
cell = tf.contrib.rnn.LSTMCell(num_hidden,state_is_tuple=True)
val, _ = tf.nn.dynamic_rnn(cell, data, dtype=tf.float32)
val = tf.transpose(val, [1, 0, 2])
last = tf.gather(val, int(val.get_shape()[0]) - 1)
weight = tf.Variable(tf.truncated_normal([num_hidden, int(target.get_shape()[1])]))
bias = tf.Variable(tf.constant(0.1, shape=[target.get_shape()[1]]))
prediction = tf.nn.softmax(tf.matmul(last, weight) + bias)
cross_entropy = -tf.reduce_sum(target * tf.log(tf.clip_by_value(prediction,1e-10,1.0)))
optimizer = tf.train.GradientDescentOptimizer(0.01)
minimize = optimizer.minimize(cross_entropy)
mistakes = tf.not_equal(tf.argmax(target, 1), tf.argmax(prediction, 1))
error = tf.reduce_mean(tf.cast(mistakes, tf.float32))

init_op = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init_op)

batch_size = 1
no_of_batches = int((len(trainingInputs)) / batch_size)

def trainNetwork():
    epoch = 1000
    for i in range(epoch):
        ptr = 0
        for j in range(no_of_batches):
            inp, out = trainingInputs[ptr:ptr+batch_size], trainingOutputs[ptr:ptr+batch_size]
            ptr+=batch_size
            sess.run(minimize, feed_dict={data: inp, target: out})


def generateOutput():
    incorrect = sess.run(error,{data: trainingInputs, target: trainingOutputs})
    sys.stdout.write('error {:3.1f}%'.format(100 * incorrect) + "\n")
    sys.stdout.flush()

for i in range(200):
    trainNetwork()
    generateOutput()

sess.close()

0 个答案:

没有答案