如何使用TBTT处理批处理中的可变序列长度

时间:2018-06-18 21:42:12

标签: python tensorflow lstm rnn

当我想在时间上使用截断反向传播时,我如何处理批次中具有不同序列长度的样本?
我的意思是当我用新的Timesteps提供我的TF会话并且批处理中的样本具有不同的长度时,一个样本可能没有任何时间步长可以进给。
所以我是零填充它还是让我用另一个替换该样本或者是否可以在序列长度向量中提供序列长度0?

我也不完全确定我的代码是否具有Truncated Backprop Through Time是正确的。所以如果任何人都可以看看它并告诉我TBBT是否正确实施,那就太好了。
我的实际代码目前用于MNIST的培训,TBTT的准确率下降了约7%,这是正常情况还是我犯了错误?我不能完全确定何时必须将输出状态作为输入。

我的最后一个问题是如何预测TBBT的准确性?它是序列精度的平均值吗?
后来我想在大型数据集上使用该代码,其中包含大约20000-52000个时间步长和3200个特征

#Base code: https://jasdeep06.github.io/posts/Understanding-LSTM-in-Tensorflow-MNIST/

import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np

from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("/tmp/data/",one_hot=True)



#define constants
#unrolled through 28 time steps
time_steps=28

#hidden LSTM units
num_units=128
num_layers = 4

#rows of 28 pixels
n_input=28

#learning rate for adam
learning_rate=0.001

#mnist is meant to be classified in 10 classes(0-9).
n_classes=10

#size of batch
batch_size=128

num_steps = 14


#weights and biases of appropriate shape to accomplish above task
out_weights=tf.Variable(tf.random_normal([num_units,n_classes]))
out_bias=tf.Variable(tf.random_normal([n_classes]))

#defining placeholders
#input image placeholder
x=tf.placeholder("float",[None,num_steps,n_input])
#input label placeholder
y=tf.placeholder("float",[None,n_classes])

seqlen = tf.placeholder(tf.int32,[None])

#processing the input tensor from [batch_size,n_steps,n_input] to "time_steps" number of [batch_size,n_input] tensors
input=tf.unstack(x ,num_steps,1)

#State Tuple
state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, num_units])

l = tf.unstack(state_placeholder, axis=0)
rnn_tuple_state = tuple(
    [rnn.LSTMStateTuple(l[idx][0], l[idx][1])
     for idx in range(num_layers)]
)

def lstm():
    lstm = rnn.LSTMCell(num_units, state_is_tuple=True)
    return lstm

cell = tf.nn.rnn_cell.MultiRNNCell([lstm() for _ in range(num_layers)], state_is_tuple=True)

outputs, state = rnn.static_rnn(cell, input,initial_state=rnn_tuple_state,dtype="float32")

#converting last output of dimension [batch_size,num_units] to [batch_size,n_classes] by out_weight multiplication
prediction=tf.matmul(outputs[-1],out_weights)+out_bias

#loss_function
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))
#optimization
opt=tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

#model evaluation
correct_prediction=tf.equal(tf.argmax(prediction,1),tf.argmax(y,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

#initialize variables
init=tf.global_variables_initializer()

final_state = np.zeros((num_layers, 2, batch_size, num_units))
with tf.Session() as sess:

    sess.run(init)
    iter=0
    while iter<2000: #Training Steps
        batch_x,batch_y=mnist.train.next_batch(batch_size=batch_size)

        batch_x=batch_x.reshape((batch_size,time_steps,n_input))

        for i in range(0,int(time_steps/num_steps)): #Iterating through Timesteps for TBTT
            down = (num_steps*i)
            up = down + num_steps

            slen = []
            for sl in range(0,batch_size):
                slen.append(up-down)
            slen = np.asarray(slen,dtype="int")

            ostate,_ = sess.run([state,opt], feed_dict={state_placeholder: final_state,x: batch_x[:,down:up,:], y: batch_y,seqlen:slen})

            if iter %10==0:
                acc=sess.run(accuracy,feed_dict={state_placeholder: final_state,x:batch_x[:,down:up,:],y:batch_y,seqlen:slen})
                los=sess.run(loss,feed_dict={state_placeholder: final_state,x:batch_x[:,down:up,:],y:batch_y,seqlen:slen})
                print("For iter ",iter)
                print("Accuracy ",acc)
                print("Loss ",los)
                print("__________________")

        iter=iter+1
        final_state = np.array(ostate)

    #calculating test accuracy
    test_data = mnist.test.images[:128].reshape((-1, time_steps, n_input))
    test_label = mnist.test.labels[:128]
    print("Testing Accuracy:")
    for i in range(0,int(time_steps/num_steps)):
        down = (num_steps*i)
        up = down + num_steps
        ostate,acc = sess.run([state,accuracy], feed_dict={state_placeholder: final_state,x: test_data[:,down:up,:], y: test_label})
        print(acc)

0 个答案:

没有答案