TensorFlow linear regression: loss blows up

Posted: 2017-05-28 04:15:14

Tags: python machine-learning tensorflow linear-regression gradient-descent

I am trying to fit a very simple linear regression model with TensorFlow. However, the loss (mean squared error) blows up instead of decreasing.

First, I generate my data:

import numpy as np
x_data = np.random.uniform(low=0, high=10, size=100)
y_data = 3.5 * x_data - 4 + np.random.normal(loc=0, scale=2, size=100)  # true line y = 3.5x - 4, noise std 2

Then I define the computation graph:

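import tensorflow as tf  # TensorFlow 1.x graph API (placeholders, tf.train optimizers)

X = tf.placeholder(dtype=tf.float32, shape=100)  # inputs x
Y = tf.placeholder(dtype=tf.float32, shape=100)  # targets y
m = tf.Variable(1.0)  # slope, initialized to 1
c = tf.Variable(1.0)  # intercept, initialized to 1
Ypred = m*X + c
loss = tf.reduce_mean(tf.square(Ypred - Y))  # mean squared error
optimizer = tf.train.GradientDescentOptimizer(learning_rate=.1)
train = optimizer.minimize(loss)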
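For reference, this is the update that plain gradient descent applies to this loss, written out by hand. It is the standard derivation, not part of the original post. Here N = 100 points and η is the learning rate (0.1 in the code above):

```latex
\begin{aligned}
L(m, c) &= \frac{1}{N}\sum_{i=1}^{N} \left(m x_i + c - y_i\right)^2 \\
\frac{\partial L}{\partial m} &= \frac{2}{N}\sum_{i=1}^{N} \left(m x_i + c - y_i\right) x_i,
\qquad
\frac{\partial L}{\partial c} = \frac{2}{N}\sum_{i=1}^{N} \left(m x_i + c - y_i\right) \\
m &\leftarrow m - \eta\,\frac{\partial L}{\partial m},
\qquad
c \leftarrow c - \eta\,\frac{\partial L}{\partial c}
\end{aligned}
```

The gradient with respect to m carries an extra factor x_i, and x ranges up to 10 here. As a result, m receives much larger updates than c for the same η.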

Finally, I run 100 epochs:

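# create a session and initialize m and c (needed before any session.run)
session = tf.Session()
session.run(tf.global_variables_initializer())

steps = {'m': [], 'c': []}
losses = []

# one full-batch gradient-descent step per epoch; record m, c and the loss before each step
for k in range(100):
    _m, _c, _l = session.run([m, c, loss], feed_dict={X: x_data, Y: y_data})
    session.run(train, feed_dict={X: x_data, Y: y_data})
    steps['m'].append(_m)
    steps['c'].append(_c)
    losses.append(_l)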

However, when I plot the loss, I get:

[Image: plot of the loss per epoch]

The full code can also be found here.
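To see the blow-up in isolation, the question's training loop can be reproduced without TensorFlow. The sketch below is plain Python with no NumPy or TensorFlow. It assumes the same data-generating process, the same initial values m = c = 1, and learning rate 0.1. `make_data` and `gradient_descent` are hypothetical helper names, and the RNG is seeded, so the numbers will not match the original plot exactly.

```python
import random


def make_data(n=100, seed=0):
    """Hypothetical helper: x ~ U(0, 10), y = 3.5x - 4 + N(0, 2) noise, as in the question."""
    rng = random.Random(seed)
    xs = [rng.uniform(0, 10) for _ in range(n)]
    ys = [3.5 * x - 4 + rng.gauss(0, 2) for x in xs]
    return xs, ys


def gradient_descent(xs, ys, lr, epochs=100, m=1.0, c=1.0):
    """Full-batch gradient descent on the MSE of y = m*x + c; returns the loss before each step."""
    n = len(xs)
    losses = []
    for _ in range(epochs):
        errors = [m * x + c - y for x, y in zip(xs, ys)]
        losses.append(sum(e * e for e in errors) / n)
        grad_m = 2 / n * sum(e * x for e, x in zip(errors, xs))
        grad_c = 2 / n * sum(errors)
        # both updates use the same errors, i.e. a simultaneous step
        m -= lr * grad_m
        c -= lr * grad_c
    return losses


xs, ys = make_data()
losses = gradient_descent(xs, ys, lr=0.1)
print(losses[0], losses[-1])  # the last loss is many orders of magnitude larger than the first
```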

2 Answers:

Answer 0 (score: 3)

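Whenever you see your cost increase monotonically with the number of epochs, that is a sure sign that your learning rate is too high. Retrain repeatedly, dividing the learning rate by 10 each time, until the cost clearly decreases as the number of epochs grows.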
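Here is a minimal sketch of that procedure. It assumes "clearly decreases" means the loss at the last of 100 epochs is below the loss at the first. It reimplements the question's gradient descent in plain Python rather than TensorFlow. `make_data`, `mse_history` and `pick_learning_rate` are hypothetical helper names.

```python
import random


def make_data(n=100, seed=0):
    """Hypothetical helper: x ~ U(0, 10), y = 3.5x - 4 + N(0, 2) noise, as in the question."""
    rng = random.Random(seed)
    xs = [rng.uniform(0, 10) for _ in range(n)]
    ys = [3.5 * x - 4 + rng.gauss(0, 2) for x in xs]
    return xs, ys


def mse_history(xs, ys, lr, epochs=100, m=1.0, c=1.0):
    """Full-batch gradient descent on the MSE; returns the loss before each step."""
    n = len(xs)
    losses = []
    for _ in range(epochs):
        errors = [m * x + c - y for x, y in zip(xs, ys)]
        losses.append(sum(e * e for e in errors) / n)
        # c's update below still uses the errors computed before m changed
        m -= lr * 2 / n * sum(e * x for e, x in zip(errors, xs))
        c -= lr * 2 / n * sum(errors)
    return losses


def pick_learning_rate(xs, ys, lr=0.1, max_tries=6):
    """Divide lr by 10 until the loss at the last epoch is below the loss at the first."""
    for _ in range(max_tries):
        losses = mse_history(xs, ys, lr)
        if losses[-1] < losses[0]:  # False for inf/nan too, so a blown-up run is rejected
            return lr
        lr /= 10
    raise ValueError("loss did not decrease for any learning rate tried")


xs, ys = make_data()
lr = pick_learning_rate(xs, ys)
print("chosen learning rate:", lr)
```

For data like the question's, 0.1 is rejected and 0.01 is kept. The factor-of-10 schedule is coarse, so a finer search between the last rejected and the first accepted rate may find a faster setting.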

Answer 1 (score: 2)

The learning rate is too high; 0.001 works well:

[Image: plot of losses]

(Potentially useful reference: https://gist.github.com/fuglede/ad04ce38e80887ddcbeb6b81e97bbfbc)
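Why 0.1 blows up while smaller rates work can be checked with the standard stability condition for gradient descent on a quadratic loss. This is background, not part of either answer. Full-batch gradient descent on the MSE converges only if η < 2/λ_max, where λ_max is the largest eigenvalue of the Hessian H = 2·[[mean(x²), mean(x)], [mean(x), 1]]. With x uniform on [0, 10], mean(x²) is about 33, so λ_max is roughly 68 and the bound is around 0.03. The sketch below computes it in plain Python for x generated as in the question. `max_stable_lr` is a hypothetical helper name.

```python
import math
import random


def max_stable_lr(xs):
    """Largest learning rate for which full-batch GD on the MSE of y = m*x + c still converges.

    The Hessian of the loss w.r.t. (m, c) is 2 * [[mean(x^2), mean(x)], [mean(x), 1]];
    it does not depend on y, so only the inputs are needed.
    """
    n = len(xs)
    a = 2 * sum(x * x for x in xs) / n  # d2L/dm2
    b = 2 * sum(xs) / n                 # d2L/(dm dc)
    d = 2.0                             # d2L/dc2
    # largest eigenvalue of the symmetric 2x2 matrix [[a, b], [b, d]]
    lam_max = (a + d) / 2 + math.sqrt(((a - d) / 2) ** 2 + b * b)
    return 2 / lam_max


rng = random.Random(0)
xs = [rng.uniform(0, 10) for _ in range(100)]  # x ~ U(0, 10), as in the question
bound = max_stable_lr(xs)
print("stable if learning rate <", bound)
for lr in (0.1, 0.01, 0.001):
    print(lr, "converges" if lr < bound else "diverges")
```

The same calculation suggests why very small rates are slow. The smallest eigenvalue here is only around 0.5, so with η = 0.001 the error along the corresponding (mostly intercept) direction shrinks by well under 1% per step.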