无法获得简单的二进制分类器

时间:2017-08-05 18:37:47

标签: machine-learning tensorflow

我使用TensorFlow编写了一个简单的二元分类器。但我得到的优化变量的唯一结果是NaN。这是代码:

import tensorflow as tf

# Input values
x = tf.range(0., 40.)
y = tf.constant([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
                 1., 0., 0., 1., 0., 1., 0., 1., 1., 1.,
                 1., 1., 0., 1., 1., 1., 0., 1., 1., 1.,
                 1., 1., 1., 0., 1., 1., 1., 1., 1., 1.])

# Variables
m = tf.Variable(tf.random_normal([]))
b = tf.Variable(tf.random_normal([]))

# Model and cost
model = tf.nn.sigmoid(tf.add(tf.multiply(x, m), b))
cost = -1. * tf.reduce_sum(y * tf.log(model) + (1. - y) * (1. - tf.log(model)))

# Optimizer
learn_rate = 0.05
num_epochs = 20000
optimizer = tf.train.GradientDescentOptimizer(learn_rate).minimize(cost)

# Initialize variables
init = tf.global_variables_initializer()

# Launch session
with tf.Session() as sess:
    sess.run(init)

    # Fit all training data
    for epoch in range(num_epochs):
        sess.run(optimizer)

    # Display results
    print("m =", sess.run(m))
    print("b =", sess.run(b))

我尝试了不同的优化器,学习速率和测试大小。但似乎没有任何效果。有什么想法吗?

1 个答案:

答案 0 :(得分:1)

您使用标准偏差1初始化mb,但对于您的数据xy,您可以预期m明显小于b 1.您可以将m初始化为零(这对于偏差项而言非常流行)和import tensorflow as tf import matplotlib.pyplot as plt # Input values x = tf.range(0., 40.) y = tf.constant([0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1.]) # Variables m = tf.Variable(tf.random_normal([], mean=0.0, stddev=0.0005)) b = tf.Variable(tf.zeros([])) # Model and cost model = tf.nn.sigmoid(tf.add(tf.multiply(x, m), b)) cost = -1. * tf.reduce_sum(y * tf.log(model) + (1. - y) * (1. - tf.log(model))) # Optimizer learn_rate = 0.00000005 num_epochs = 20000 optimizer = tf.train.GradientDescentOptimizer(learn_rate).minimize(cost) # Initialize variables init = tf.global_variables_initializer() # Launch session with tf.Session() as sess: sess.run(init) # Fit all training data for epoch in range(num_epochs): _, xs, ys = sess.run([optimizer, x, y]) ms = sess.run(m) bs = sess.run(b) print(ms, bs) plt.plot(xs,ys) plt.plot(xs, ms * xs + bs) plt.savefig('tf_test.png') plt.show() plt.clf() 以及更小的标准差(例如0.0005)并同时降低学习率(例如至0.00000005)。你可以延迟NaN值来改变这些值,但它们最终可能会发生,因为在我看来你的数据并没有被线性函数很好地描述。 plot

public abstract class Base<T> where T : Base<T>
{
  protected static Environment Environment { get; private set; }

  public static void Init(Environment newEnvironment)
  {
    Environment = newEnvironment;
  }
}