Are momentum and (mini-)batch training compatible?

Date: 2017-06-16 12:33:41

Tags: neural-network backpropagation

I have a backpropagation model that works very well, but I want to implement batch training.

The code before batch training (inside the backpropagation function), as pseudocode:

for (const connection of this.connections.in) {
  // Adjust weight
  var deltaWeight = rate * gradient + momentum * connection.previousDeltaWeight;
  connection.weight += deltaWeight;
  connection.previousDeltaWeight = deltaWeight;
}

// Adjust bias
var deltaBias = rate * this.error.responsibility + momentum * this.previousDeltaBias;
this.bias += deltaBias;

this.previousDeltaBias = deltaBias;
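For reference, here is a minimal sketch of the same per-sample momentum rule applied to a single weight. The setup is hypothetical: the loss L(w) = 0.5 * (w - 3)^2 and the function name `trainPerSample` are assumptions made for illustration. Following the `+=` convention in the code above, `gradient` is taken to be the descent direction, i.e. -(dL/dw).

```javascript
// Hypothetical single-weight sketch of per-sample momentum (one update per sample).
// Assumed loss: L(w) = 0.5 * (w - 3)^2, so dL/dw = w - 3.
// "gradient" is the descent direction -(dL/dw), matching "weight += rate * gradient".
function trainPerSample(steps, rate, momentum) {
  let weight = 0;
  let previousDeltaWeight = 0;
  for (let t = 0; t < steps; t++) {
    const gradient = -(weight - 3);
    // New step = gradient step + a fraction of the previous step.
    const deltaWeight = rate * gradient + momentum * previousDeltaWeight;
    weight += deltaWeight;
    previousDeltaWeight = deltaWeight;
  }
  return weight;
}

console.log(trainPerSample(200, 0.1, 0.9)); // converges toward 3
```

Because the momentum term is added exactly once per applied step, old steps decay by a factor of `momentum` (0.9 here) and the weight settles.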

The new code is:

for (const connection of this.connections.in) {
  // Adjust weight
  var deltaWeight = rate * gradient * this.mask + momentum * connection.previousDeltaWeight;
  connection.totalDeltaWeight += deltaWeight;
  if(update){
    connection.weight += connection.totalDeltaWeight;
    connection.previousDeltaWeight = connection.totalDeltaWeight;
    connection.totalDeltaWeight = 0;
  }
}

// Adjust bias
var deltaBias = rate * this.error.responsibility + momentum * this.previousDeltaBias;
this.totalDeltaBias += deltaBias;
if(update){
  this.bias += this.totalDeltaBias;
  this.previousDeltaBias = this.totalDeltaBias;
  this.totalDeltaBias = 0;
}

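So with a batch size of 4, the backpropagation function is called 3 times with update=false and a 4th time with update=true. Batch training works fine, but as soon as I turn on momentum (= 0.9), all the values start to overflow. What could be the problem?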
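The symptom can be reproduced with a hypothetical single-weight version of the batch code above. Assumptions: the loss is L(w) = 0.5 * (w - 3)^2, all samples in a batch share that loss, `mask` is 1, and `gradient` is the descent direction -(dL/dw); the function name `trainBuggyBatch` is made up for this sketch.

```javascript
// Hypothetical reproduction of the question's batch update for ONE weight.
// Assumed loss: L(w) = 0.5 * (w - 3)^2; gradient = -(dL/dw); mask = 1.
function trainBuggyBatch(updates, batchSize, rate, momentum) {
  let weight = 0;
  let previousDeltaWeight = 0;
  let totalDeltaWeight = 0;
  for (let u = 0; u < updates; u++) {
    for (let i = 0; i < batchSize; i++) {
      const gradient = -(weight - 3);
      // As in the question: the momentum term is added on EVERY call,
      // so it ends up in totalDeltaWeight batchSize times per update.
      const deltaWeight = rate * gradient + momentum * previousDeltaWeight;
      totalDeltaWeight += deltaWeight;
      const update = i === batchSize - 1; // true on the last call of the batch
      if (update) {
        weight += totalDeltaWeight;
        previousDeltaWeight = totalDeltaWeight;
        totalDeltaWeight = 0;
      }
    }
  }
  return weight;
}

console.log(trainBuggyBatch(50, 4, 0.1, 0));     // momentum off: close to 3
console.log(trainBuggyBatch(1000, 4, 0.1, 0.9)); // momentum on: not a finite number
```

With momentum off the weight converges, while with momentum 0.9 the steps grow geometrically until the value overflows to Infinity and then NaN, matching the behaviour described above.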

1 answer:

Answer 0 (score: 0)

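Whoops. I was accumulating the momentum incorrectly. I added the momentum term on every backpropagation call, i.e. batch_size times per update, which effectively multiplied the momentum by batch_size. That is wrong, so now it is added only once, when the accumulated delta is applied.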
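To see why this overflows, write B for the batch size, η for `rate`, μ for `momentum`, g_i for the per-sample gradient term (including `mask` for weights, or `error.responsibility` for the bias) and v_k for `previousDeltaWeight` after the k-th update. These symbols are introduced here for illustration. Since `previousDeltaWeight` only changes when `update` is true, v_k is the same on all B calls of a batch:

```latex
\begin{aligned}
\text{question:}\quad v_{k+1} &= \sum_{i=1}^{B}\bigl(\eta\,g_i + \mu\,v_k\bigr)
  = \eta\sum_{i=1}^{B} g_i + B\mu\,v_k \\
\text{answer:}\quad v_{k+1} &= \eta\sum_{i=1}^{B} g_i + \mu\,v_k
\end{aligned}
```

Ignoring the gradient term, the question's version multiplies the previous step by Bμ = 4 × 0.9 = 3.6 on every update, so the steps grow geometrically until they overflow. The answer's version multiplies it by μ = 0.9 < 1, so old steps decay as in ordinary momentum.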

for (const connection of this.connections.in) {
  // Adjust weight
  var deltaWeight = rate * gradient * this.mask;
  connection.totalDeltaWeight += deltaWeight;
  if(update){
    connection.totalDeltaWeight += momentum * connection.previousDeltaWeight;
    connection.weight += connection.totalDeltaWeight;
    connection.previousDeltaWeight = connection.totalDeltaWeight;
    connection.totalDeltaWeight = 0;
  }
}

// note: MINI_BATCH SHALL BE OPTIMIZED SOON

// Adjust bias
var deltaBias = rate * this.error.responsibility;
this.totalDeltaBias += deltaBias;
if(update){
  this.totalDeltaBias += momentum * this.previousDeltaBias;
  this.bias += this.totalDeltaBias;
  this.previousDeltaBias = this.totalDeltaBias;
  this.totalDeltaBias = 0;
}
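A minimal single-weight sketch of the corrected update follows, under hypothetical assumptions: the loss is L(w) = 0.5 * (w - 3)^2, all samples in a batch share that loss, `mask` is 1, and `gradient` is the descent direction -(dL/dw). The function name `trainFixedBatch` is made up for illustration.

```javascript
// Hypothetical single-weight version of the corrected batch update.
// Assumed loss: L(w) = 0.5 * (w - 3)^2; gradient = -(dL/dw); mask = 1.
function trainFixedBatch(updates, batchSize, rate, momentum) {
  let weight = 0;
  let previousDeltaWeight = 0;
  let totalDeltaWeight = 0;
  for (let u = 0; u < updates; u++) {
    for (let i = 0; i < batchSize; i++) {
      const gradient = -(weight - 3);
      // Accumulate only the gradient part for each sample.
      totalDeltaWeight += rate * gradient;
      const update = i === batchSize - 1; // true on the last call of the batch
      if (update) {
        // Add the momentum term once per mini-batch, then apply the step.
        totalDeltaWeight += momentum * previousDeltaWeight;
        weight += totalDeltaWeight;
        previousDeltaWeight = totalDeltaWeight;
        totalDeltaWeight = 0;
      }
    }
  }
  return weight;
}

console.log(trainFixedBatch(300, 4, 0.1, 0.9)); // close to 3, no overflow
```

Note that this code sums the per-sample deltas rather than averaging them, so with a fixed `rate` the step size grows with the batch size. Dividing `totalDeltaWeight` (and `totalDeltaBias`) by the batch size before applying it is a common alternative that keeps `rate` comparable to the per-sample version.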