使用Tensorflow.js的NewGeneration方法出现内存泄漏问题

时间:2019-07-05 16:48:12

标签: javascript tensorflow.js

我正在使用Tensorflow.js开发JavaScript库,以通过浏览器中的NeuroEvolution训练模型。在我的主类中,我有一个NewGeneration方法,并且我认为它每次都会造成160Tensors的内存泄漏。我在某个地方搞砸了吗?

我在newGen函数,pickOne函数以及它调用的所有函数上添加了tf.tidy()包装,但是没有任何效果。 如果您需要查看更多代码,请告诉我。

这是newGeneration的功能(保存的代理是失败的代理,而代理是将来的代理)

newGen(){
    for(let i = 0; i < this.popSize; i++)
        this.agents.push(this.pickOne(this.savedAgents))

     for(let agent of this.savedAgents)
        tf.dispose(agent.brain.model)

     this.savedAgents = []
     print(tf.memory())
 }

这是pickOne函数

pickOne(oldAgents){
    getTheChildsIndex() //Fake function but no code related to the problem here
    let child = oldAgents[index];
    let agent = new Agent() //Agent class holds a brain, in which is a tf.sequential model
                            //And a body, which has nothing to do with tf

    tf.tidy(()=>{
        agent.brain = child.brain.copy()
        agent.brain.mutate() //Mutate has proved to have another leak, but now it's fixed
    })
    return agent;
}

正如我被问到的,这里是为了以防万一的变异代码

mutate(mutationRate){
    tf.tidy(()=>{   
        let mutatedWeights
        const weights = this.model.getWeights()
        mutatedWeights = []
        for(let i = 0; i < weights.length; i++){
            let tensor = weights[i]
            let shape = tensor.shape
            let values = tensor.dataSync().slice()
            for(let j = 0; j < values.length; j++){
                if(random(1) < mutationRate)
                    values[j] = randomGaussian()
            }
            let newTensor = tf.tensor(values, shape)
            mutatedWeights.push(newTensor)
        }
        this.model.setWeights(mutatedWeights)
    })
    return TheBrain
}

以及复制功能

copy(){
    let modelCopy
    tf.tidy(()=>{
        modelCopy = new NeuralNetwork() // A helper class thatholds the model
        const weights  = this.model.getWeights()
        let clonedWeights = []
        for(let i = 0; i < weights.length; i++)
            clonedWeights.push(weights[i].clone())
        modelCopy.model.setWeights(clonedWeights)    
    })
    return modelCopy
}

我希望该函数不会产生泄漏,因此会销毁它创建的所有张量。

1 个答案:

答案 0 :(得分:0)

我终于自己解决了。创建代理程序时,无需使用pickOne方法启动大脑。