我无法理解为什么这不起作用,基本上我正在尝试检索m和q的值只是为了打印它们但我总是得到[nan,nan]
import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
m = tf.Variable(tf.random_uniform(shape=()), dtype=tf.float32)
q = tf.Variable(tf.random_uniform(shape=()), dtype=tf.float32)
X = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
linear_model = m * X + q
cost = tf.reduce_sum(tf.square(linear_model - y))
optimizer = tf.train.GradientDescentOptimizer(0.01)
gdescent = optimizer.minimize(cost)
train, test = [pd.read_csv(file) for file in ["train.csv", "test.csv"]]
if False: # Scatter training points
plt.scatter(train['x'], train['y'], s=1)
plt.show()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
_, m, q = sess.run([gdescent, m, q], feed_dict={X: train['x'].values, y: train['y'].values})
print(m, q)
train.csv和test.csv都是带有标题行和两列值x和y的文件
前10行train.csv
x,y
24,21.54945196
50,47.46446305
15,17.21865634
38,36.58639803
87,87.28898389
36,32.46387493
12,10.78089683
81,80.7633986
25,24.61215147
答案 0 :(得分:0)
解决了,错误是train.csv错过了第215行的y值