我试图使用 TensorFlow 拟合一个非常简单的线性回归模型，但是损失（均方误差）不断爆炸，而不是逐渐降低到零。
首先，生成数据：
# 100 inputs drawn uniformly from [0, 10).
x_data = np.random.uniform(low=0, high=10, size=100)
# Targets lie on the line y = 3.5x - 4, plus Gaussian noise with standard deviation 2.
y_data = (3.5 * x_data - 4) + np.random.normal(size=100, loc=0, scale=2)
然后，定义计算图：
# Graph for a one-feature linear model Ypred = m*X + c, trained on the full
# 100-sample dataset at every step (placeholders are sized for all 100 points).
X = tf.placeholder(dtype=tf.float32, shape=100)
Y = tf.placeholder(dtype=tf.float32, shape=100)
# Slope and intercept, both initialised to 1.0.
m = tf.Variable(1.0)
c = tf.Variable(1.0)
Ypred = m*X + c
# Mean squared error over the batch.
loss = tf.reduce_mean(tf.square(Ypred - Y))
# NOTE: learning_rate=.1 is why the loss explodes. With x ~ U[0, 10) the MSE's
# curvature along m is about 2*mean(x**2) ~ 2*100/3 ~ 67, and plain gradient
# descent is only stable for learning_rate < 2/(largest curvature) ~ 0.03.
# A value such as 0.001 converges (alternatively, rescale x_data to a smaller range).
optimizer = tf.train.GradientDescentOptimizer(learning_rate=.1)
train = optimizer.minimize(loss)
最后，训练 100 个 epoch（每个 epoch 即一次全批量梯度下降步）：
# Run 100 full-batch gradient-descent steps, recording m, c and the loss as they
# were *before* each update so the trajectory can be plotted afterwards.
# NOTE(review): `session` is not created in this snippet; presumably a tf.Session
# whose variables were already initialised — confirm against the full code.
steps = {'m': [], 'c': []}
losses = []
for _ in range(100):
    # One run fetches all three values (no state changes between them),
    # instead of three separate session round trips.
    _m, _c, _l = session.run([m, c, loss], feed_dict={X: x_data, Y: y_data})
    session.run(train, feed_dict={X: x_data, Y: y_data})
    steps['m'].append(_m)
    steps['c'].append(_c)
    losses.append(_l)
然而，当我绘制损失曲线时，得到的结果如下（原帖此处为一张损失曲线图，图片已缺失）：
完整代码可以在原帖提供的链接中找到（链接已缺失）。
答案 0（得分：3）
每当你看到代价（损失）随 epoch 数单调增加时，这几乎可以肯定表明学习率过高。反复重新训练，每次把学习率乘以 1/10，直到代价函数明显随 epoch 数减小为止。
答案 1（得分：2）
学习率太高；改用 0.001 效果很好。（注：下面的 Java 井字棋代码与本问题无关，似乎是页面抓取时混入的。）
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
/**
 * Evaluates a randomly filled 3x3 noughts-and-crosses board: fills every cell with a random
 * {@link CellState}, prints the board and reports whether a player owns a complete line.
 *
 * <p>Each line (3 columns, 3 rows, 2 diagonals) is scored by summing per-cell scores
 * (X = {@value #X_SCORE}, O = {@value #O_SCORE}, empty = 0). These values make every mix of up
 * to three X/O cells produce a distinct sum, so a line's contents can be read from its score.
 */
public class CrossAndZeros {
    /** Score contributed to a line by a cell occupied by X. */
    private static final int X_SCORE = 1;
    /** Score contributed to a line by a cell occupied by O. */
    private static final int O_SCORE = -4;
    /** Number of cells in one row, column or diagonal. */
    private static final int LINE_LENGTH = 3;
    /** Number of lines on the board: LINE_LENGTH columns, LINE_LENGTH rows and 2 diagonals. */
    private static final int LINE_COUNT = 2 * LINE_LENGTH + 2;

    /** Player owning a complete line after {@link #isWinnerFound()}, or {@code null} if none. */
    private static CellState winner;
    /** The board, indexed as {@code field[row][column]}. */
    private static final CellState[][] field = new CellState[LINE_LENGTH][LINE_LENGTH];

    public static void main(String[] args) throws IOException {
        // One generator for the whole board rather than a new Random per cell.
        Random random = new Random();
        CellState[] states = CellState.values();
        for (int i = 0; i < field.length; i++) {
            for (int j = 0; j < field[i].length; j++) {
                field[i][j] = states[random.nextInt(states.length)];
            }
        }
        for (CellState[] row : field) {
            System.out.println(Arrays.toString(row));
        }
        System.out.println();
        System.out.println("Winner is found: " + isWinnerFound());
        // NOTE(review): this message is also printed when the board is still open
        // (isWinnerFound() returned false), not only when the game is over.
        System.out.println(winner == null ? "No winner, GAME OVER" : winner);
    }

    /**
     * Checks whether the current board is decided.
     *
     * <p>A line holding three X (or three O) cells makes that player the {@link #winner}. A line
     * holding at least one X and at least one O can never be won and is counted as spoilt.
     *
     * @return {@code true} if some line is complete (and {@link #winner} is set) or if every line
     *     is spoilt (game over, no winner possible); {@code false} otherwise
     */
    private static boolean isWinnerFound() {
        winner = null; // never report a stale winner left over from a previous board
        int spoiltLines = 0;
        for (int score : calculate()) {
            if (score == LINE_LENGTH * X_SCORE) {
                winner = CellState.OCCUPIED_BY_X;
                return true;
            } else if (score == LINE_LENGTH * O_SCORE) {
                winner = CellState.OCCUPIED_BY_O;
                return true;
            } else if (isSpoilt(score)) {
                spoiltLines++;
            }
        }
        return spoiltLines == LINE_COUNT; // every line spoilt: the game is over without a winner
    }

    /**
     * Returns whether a line score means the line contains both an X and an O.
     *
     * <p>With the current scores the only such sums are X+X+O = -2, X+O+empty = -3 and
     * X+O+O = -7 (a sum of -9 cannot occur, so it must not be used as the X-O-O marker).
     */
    private static boolean isSpoilt(int score) {
        return score == 2 * X_SCORE + O_SCORE
                || score == X_SCORE + O_SCORE
                || score == X_SCORE + 2 * O_SCORE;
    }

    /**
     * Scores every line of the board and prints the scores.
     *
     * @return LINE_COUNT sums: indices 0..n-1 are columns, n..2n-1 are rows, 2n is the main
     *     diagonal and 2n+1 the anti-diagonal, where n is the board size
     */
    private static int[] calculate() {
        int[] result = new int[LINE_COUNT];
        int n = field.length;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < field[i].length; j++) {
                result[i] += getCellOwner(field[j][i]); // column i
                result[i + n] += getCellOwner(field[i][j]); // row i
            }
            result[2 * n] += getCellOwner(field[i][i]); // main diagonal
            result[2 * n + 1] += getCellOwner(field[i][n - i - 1]); // anti-diagonal
        }
        System.out.println(Arrays.toString(result));
        return result;
    }

    /**
     * Returns the line-score contribution of a single cell.
     *
     * @param cell the cell state; must not be {@code null}
     * @return {@link #O_SCORE} for O, {@link #X_SCORE} for X, 0 for an empty cell
     */
    private static int getCellOwner(CellState cell) {
        switch (cell) {
            case OCCUPIED_BY_O:
                return O_SCORE;
            case OCCUPIED_BY_X:
                return X_SCORE;
            case EMPTY:
            default:
                return 0;
        }
    }

    public enum CellState {
        /**
         * this cell is occupied by player X
         */
        OCCUPIED_BY_X,
        /**
         * this cell is occupied by player O
         */
        OCCUPIED_BY_O,
        /**
         * this cell is Empty
         */
        EMPTY
    }
}
（可能有用的参考：https://gist.github.com/fuglede/ad04ce38e80887ddcbeb6b81e97bbfbc）