Tensorflow:Python中的有序Probit模型应用程序

时间:2016-11-09 10:59:49

标签: python python-2.7 tensorflow

我尝试使用Tensorflow构建有序Probit机器学习算法。虽然我的实际应用程序相当复杂,但我重写了我的代码更为通用。它显示在这篇文章中。

我无法比维基百科更好地解释有序Probit:https://en.wikipedia.org/wiki/Ordered_probit

然而,代码的问题非常简单。我试图优化的函数包含一些条件,我强烈怀疑它是在其中一个语句中

  

ValueError :(形状(100,)和()不兼容)

生成。我在Tensorflow文档的这个页面上阅读了所有有关形状的内容:https://www.tensorflow.org/versions/r0.11/resources/dims_types.html我尝试使用变量的形状,但我无法使其工作。

提前感谢您花时间试图帮助我。非常感谢!

import numpy as np
import tensorflow as tf


# Tensorflow application for an Ordered Probit Model
# https://en.wikipedia.org/wiki/Ordered_probit

# Create 100 phony x, y_star data points in NumPy, y_star = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_star = x_data * 0.5 + 0.3

# Set true values for mu1 and mu2
mu1 = 0.45
mu2 = 0.60

# Build the 'true' y-data (the observed y values in the ordered probit model)
y_data = np.zeros(100).astype(np.float32)
for i, val in enumerate(y_star):
    if val < mu1:
        y_data[i] = 0.0
    elif val > mu2:
        y_data[i] = 1.0
    else:
        y_data[i] = 0.5

# Initialize all variables that should be estimated
W = tf.Variable(tf.zeros([1]))
b = tf.Variable(tf.zeros([1]))
mu1 = tf.Variable(tf.zeros([1]))
mu2 = tf.Variable(tf.zeros([1]))


# The actual model, it is an Ordered Probit Model
# https://en.wikipedia.org/wiki/Ordered_probit
is_less_than_mu1 = tf.less(W * x_data + b, mu1, name="is_less_than_mu1")
is_more_than_mu2 = tf.greater(W * x_data + b, mu2, name="is_more_than_mu2")

pred = tf.cond(is_less_than_mu1,
               lambda: 0.0,
               tf.cond(is_more_than_mu2,
                       lambda: 1.0,
                       lambda: 0.5
                       ))

# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(pred - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

# Before starting, initialize the variables.  We will 'run' this first.
init = tf.initialize_all_variables()

# Launch the graph.
sess = tf.Session()
sess.run(init)

# Fit the line.
for step in range(201):
    sess.run(train)
    if step % 20 == 0:
        print(step, sess.run(W), sess.run(b), sess.run(mu1), sess.run(mu2))

# Should learn best fit is W: [0.1], b: [0.3], mu1: [0.45], mu2: [0.60]

0 个答案:

没有答案