构建和训练RNN

时间:2019-08-08 12:39:14

标签: python tensorflow optimization recurrent-neural-network

我在tensorflow中构建了一个非标准的RNN,当我尝试对其进行训练时,它会使用过多的内存。仅构建网络就使用1GB,训练时的内存使用量高达5GB。这也很慢。

网络使用200个浮点数作为其内部状态,但是在每个步骤中,只有100个浮点数用作网络的输入(选择这些浮点数的规则不是网络的一部分),因此我通过以下方式对此进行建模从张量中获取100个值,然后使用tf.stack重新组合它们。然后,使用另一个tf.stack将网络的输出与未用作输入的100个值进行组合。

import tensorflow as tf
from random import sample, random


def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial, dtype=tf.float32)


def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial, dtype=tf.float32)


def run():
    sess = tf.Session()
    W0 = weight_variable((100, 200))
    B0 = bias_variable([200])
    W1 = weight_variable((200, 200))
    B1 = bias_variable([200])
    W2 = weight_variable((200, 100))
    B2 = bias_variable([100])

    memory = tf.constant(0, tf.float32, (200,))
    outputs = []
    correct_outputs = []

    for i in range(40):
        indexes = sample(range(200), 100)
        memory_selection = [memory[z] for z in indexes]
        S = tf.stack(memory_selection, axis=0)
        S = tf.stack([S], axis=0)
        Input = tf.nn.relu(tf.matmul(S, W0) + B0)
        Hidden = tf.nn.relu(tf.matmul(Input, W1) + B1)
        Output = tf.nn.relu(tf.matmul(Hidden, W2) + B2)
        memory_output = []
        used = 0
        for z in range(200):
            if z in indexes:
                memory_output.append(Output[0][used])
                used += 1
            else:
                memory_output.append(memory[z])
        memory = tf.stack(memory_output, axis=0)
        outputs.append(Output)
        correct_outputs.append([random() for _ in range(100)])

    print("Network Built")

    outputs = tf.stack(outputs, axis=0)
    correct_outputs = tf.constant(correct_outputs)
    loss = tf.nn.l2_loss(outputs-correct_outputs)
    optimiser = tf.contrib.optimizer_v2.AdamOptimizer(0.0001)
    train = optimiser.minimize(loss)
    sess.run(tf.global_variables_initializer())
    print("Variables Initialized")
    sess.run(train)
    print("Trained")


run()

我最终获得了40步长的展开网络。 堆栈和堆栈可能引起内存问题吗? 大部分时间都花在gradients_impl.py和渐变中。

编辑:我已经在PyTorch中重写了上面的代码,与张量流所花费的时间相比,它的运行速度更快,不到一秒钟。

0 个答案:

没有答案