这些南方来自何处?

时间:2016-04-04 01:38:50

标签: tensorflow

编辑:
我的数据中有一些纳米但是anwser是正确的,你必须用一些噪音初始化你的体重! 谢谢!

我用tensorflow做我的第一个脚本。我有一些打印价值的问题,但现在我明白了。 我想尝试一个简单的Logistic回归开始,我正在研究kaggle titanic数据集。

我的问题是我不知道为什么但是我在我的体重和偏见中得到了一些楠,所以在我的y(预测)矢量中......

编辑: 我的体重初始化为0,所以我猜是一个零梯度。 根据我提供的答案添加

W = tf.truncated_normal([5, 1], stddev=0.1)

而不是

 W = tf.Variable(tf.zeros([5, 1])) #weight for softmax

但我仍有一些问题。我的b变量和y变量仍然是nan,当我为b尝试相同的事情时,我得到以下错误: ValueError:没有要优化的变量
我尝试了几种方法来分配我的张量[1,1],但看起来我错过了一些东西
看起来y是nan,因为交叉熵是nan,因为b是nan ... :( 结束 - 编辑

我读了这篇帖子(Why does TensorFlow return [[nan nan]] instead of probabilities from a CSV file?)谁给了我一个提示,在我的交叉熵计算0 * log(0)返回nan所以我应用给出的解决方案,即添加1e-50像: / p>

cross_entropy = -tf.reduce_sum(y_ * tf.log(y + 1e-50))

不幸的是,这不是我想的问题,我到处都是楠:(

这是我非常简单的模型的内插(我猜)部分:

x = tf.placeholder(tf.float32, [None, 5]) #placeholder for input data

W = tf.truncated_normal([5, 1], stddev=0.1)

b = tf.Variable(tf.zeros([1])) # no error but nan
#b = tf.truncated_normal([1, 1], stddev=0.1) Thow the error descript above
#b = [0.1] no error but nan

y = tf.nn.softmax(tf.matmul(x, W) + b) #our model -> pred from model

y_ = tf.placeholder(tf.float32, [None, 1])#placeholder for input 

cross_entropy = -tf.reduce_sum(y_*tf.log(y)) # crossentropy cost function

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.initialize_all_variables() # create variable

sess = tf.InteractiveSession()
sess.run(init)

testacc = []
trainacc = []
for i in range(15):
    batch_xs = train_input[i*50:(i + 1) * 50]
    batch_ys = train_label[i*50:(i + 1) * 50]

    result = sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(y,y_)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run([accuracy, W, y] , feed_dict={x: test_input, y_: test_label}))

当然,在2个数组的纳米之后,它返回0.0精度 我试图在任何地方打印价值,但到处都是纳:'(

有人有想法吗?我可能会忘记某些事情或做错了

问题是我尝试了一个类似的脚本与mnist(谷歌教程)包含数据和它的作品(没有南)。我通过读取csv文件的熊猫获取我的数据。

感谢阅读!

1 个答案:

答案 0 :(得分:2)

由于你的权重矩阵为零,你在tf.nn.softmax得到零除。使用不同的规范化方法,例如MNIST示例中的truncated_normal