TensorFlow LSTM memory usage

Time: 2017-02-13 12:12:37

Tags: tensorflow lstm

I am experimenting with LSTM sequence labeling. Before adding an embedding layer, I feed the input data directly into the LSTM layer. But even with a batch size of 1, it gives me a GPU out-of-memory error.

max_length is 330. Do I need to change the model, or add an embedding layer, for this to work? I am using a Titan X GPU with 12 GB of memory.

import tensorflow as tf

# tf Graph input (max_length, num_table, n_hidden, n_classes are set elsewhere)
x = tf.placeholder(tf.float32, [None, max_length, num_table])
y = tf.placeholder(tf.float32, [None, max_length, n_classes])
seqlen = tf.placeholder(tf.int32, [None])

# Define weights
weights = {
    'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([n_classes]))
}

def LSTM(x, seqlen, weights, biases):
    # given input: (batch_size, n_step, num_table)
    # required:    (n_step, batch_size, num_table)
    x = tf.unpack(tf.transpose(x, perm=[1, 0, 2]))
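    # NOTE: tf.unpack turns x into a Python list of max_length (= 330) tensors,
    # and tf.nn.rnn below statically unrolls the cell once per time step, so
    # the graph ends up containing 330 copies of the LSTM step ops.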

    lstm_cell = tf.nn.rnn_cell.LSTMCell(n_hidden)
    #lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell,keep_prob)

    outputs, states = tf.nn.rnn(cell=lstm_cell,
                                dtype=tf.float32,
                                sequence_length=seqlen,
                                inputs=x)

    # project each time step: a list of n_step tensors, each (batch_size, n_classes)
    temp = [tf.matmul(output, weights['out']) + biases['out'] for output in outputs]

    # convert to (batch_size, n_step, n_classes)
    temp = tf.transpose(tf.pack(temp), perm=[1, 0, 2])
    return temp
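
For reference, here is a minimal sketch of the same model using tf.nn.dynamic_rnn instead of tf.unpack + tf.nn.rnn (untested; it assumes the same placeholders, weights, and hyperparameters as above). dynamic_rnn runs the recurrence as a loop at execution time rather than unrolling all 330 steps into the graph, which usually reduces graph size and memory use, and its swap_memory option can additionally spill activations to host RAM:

def LSTM_dynamic(x, seqlen, weights, biases):
    lstm_cell = tf.nn.rnn_cell.LSTMCell(n_hidden)

    # dynamic_rnn takes batch-major input directly: (batch_size, n_step, num_table)
    outputs, states = tf.nn.dynamic_rnn(cell=lstm_cell,
                                        inputs=x,
                                        sequence_length=seqlen,
                                        dtype=tf.float32,
                                        swap_memory=True)

    # apply the output projection to all time steps at once
    outputs = tf.reshape(outputs, [-1, n_hidden])
    logits = tf.matmul(outputs, weights['out']) + biases['out']

    # back to (batch_size, n_step, n_classes)
    return tf.reshape(logits, [-1, max_length, n_classes])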

0 Answers:

No answers