默认情况下,在tensorflow中进行在线或批量培训

时间:2017-10-20 00:00:40

标签: python optimization tensorflow neural-network training-data

我有以下问题:我正在尝试学习张量流,但我仍然没有找到将培训设置为在线或批量的位置。例如,如果我有以下代码来训练神经网络:

loss_op = tf.reduce_mean(tf.pow(neural_net(X) - Y, 2))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)


sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})

如果我同时提供所有数据(即batch_x包含所有数据),这是否意味着培训是批量培训?或者张量流优化器以不同的方式优化?如果我执行 for 循环,一次只提供一个数据样本,是不是错了?这算是单步(在线)培训吗?谢谢你的帮助。

2 个答案:

答案 0 :(得分:7)

主要有三种渐变下降类型。具体地,

  1. 随机梯度下降
  2. 批量渐变下降
  3. Mini Batch Gradient Descent
  4. 这里是一个很好的教程(https://machinelearningmastery.com/gentle-introduction-mini-batch-gradient-descent-configure-batch-size/),对上面三种方法有好处和缺点。

    对于您的问题,以下是标准样本训练张量流代码,

    N_EPOCHS = #Need to define here
    BATCH_SIZE = # Need to define hare
    
    with tf.Session() as sess:
       train_count = len(train_x)    
    
        for i in range(1, N_EPOCHS + 1):
            for start, end in zip(range(0, train_count, BATCH_SIZE),
                                  range(BATCH_SIZE, train_count + 1,BATCH_SIZE)):
    
                sess.run(train_op, feed_dict={X: train_x[start:end],
                                               Y: train_y[start:end]})
    

    此处N_EPOCHS表示整个训练数据集的传球次数。您可以根据Gradient Descent方法设置BATCH_SIZE。

    • 随机梯度下降,BATCH_SIZE = 1.
    • 对于批量渐变下降,BATCH_SIZE =训练数据集大小。

    • 对于 Mini Batch Gradient Decent ,1<< BATCH_SIZE<<训练数据集大小。

    在三种方法中,最流行的方法是 Mini Batch Gradient Decent 。但是,您需要根据您的要求设置BATCH_SIZE参数。 BATCH_SIZE的良好默认值可能是32。

    希望这有帮助。

答案 1 :(得分:1)

通常,Tensorflow中数据占位符的第一维设置为batch_size,TensorFlow默认不定义(培训策略)。您可以设置第一个维度以确定它是否在线(第一个维度为1)或小批量(通常为十个)。例如:

@Entity
@Table(name = "employee")
@Where(clause = "isDeleted = false")
public class Employee extends BaseEntity {

    @OneToOne
    private Department department;

    private int status;

    public Department getDepartment() {
        return department;
    }

    public void setDepartment(Department department) {
        this.department = department;
    }

    public int getStatus() {
        return status;
    }

    public void setStatus(int status) {
        this.status = status;
    }


}