Tensorflow LinearRegressor评估方法挂起

时间:2017-03-17 08:10:45

标签: tensorflow

考虑以下玩具TensorFlow代码。 LinearRegressor fit 方法正常工作并找到正确的系数(即y = x1 + x2),但计算(参见最后一个打印语句) )挂起。知道什么是错的吗?

import tensorflow as tf

x1 = [1, 3, 4, 5, 1, 6, -1, -3]
x2 = [5, 2, 1, 5, 0, 2, 4, 2]
y = [6, 5,5, 10, 1, 8, 3, -1]

def train_fn():
  return {'x1': tf.constant(x1), 'x2':tf.constant(x2)}, tf.constant(y)


features = [tf.contrib.layers.real_valued_column('x1', dimension=1),
            tf.contrib.layers.real_valued_column('x2', dimension=1)]
estimator = tf.contrib.learn.LinearRegressor(feature_columns=features)
estimator.fit(input_fn=train_fn, steps=10000)

for vn in estimator.get_variable_names():
  print('variable name', vn, estimator.get_variable_value(vn))
print(estimator.evaluate(input_fn=train_fn))

1 个答案:

答案 0 :(得分:4)

estimator.evaluate()采用参数steps,默认为None,其被解释为“无穷大”。因此它永远不会结束。

要结束,请明确传递steps=1

estimator.evaluate(input_fn=your_input_fn, steps=1)