TensorFlow: Invalid argument error in DNNRegressor.fit()

Date: 2016-12-28 02:45:25

Tags: python machine-learning tensorflow regression data-science

I am trying to follow the "Deep Neural Network Regression with Boston Data" example to run a regression on my own data.

Here is my code.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from sklearn import cross_validation
from sklearn import metrics
from sklearn import preprocessing
import tensorflow as tf
from tensorflow.contrib import learn
from numpy import genfromtxt


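def main(unused_argv=None):
  # tf.app.run() calls main() with the remaining command-line arguments,
  # so main must accept one argument (defaulted so main() also works directly).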
  x_test = genfromtxt('ARM4mDec2002Jul2015OklahomaV2_mar_apr_may_date_time_normalized_16000_test_data_x.csv', delimiter=',')
  y_test = genfromtxt('ARM4mDec2002Jul2015OklahomaV2_mar_apr_may_date_time_normalized_16000_test_data_y.csv', delimiter=',')

  x_train = genfromtxt('ARM4mDec2002Jul2015OklahomaV2_mar_apr_may_date_time_normalized_16000_training_data_x.csv', delimiter=',')
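  y_train = genfromtxt('ARM4mDec2002Jul2015OklahomaV2_mar_apr_may_date_time_normalized_16000_training_data_y.csv', delimiter=',')

  # Scale data (training set) to 0 mean and unit standard deviation,
  # as in the Boston example; `scaler` is reused below on x_test.
  scaler = preprocessing.StandardScaler()
  x_train = scaler.fit_transform(x_train)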

  # Build 2 layer fully connected DNN with 10, 10 units respectively.
  feature_columns = learn.infer_real_valued_columns_from_input(x_train)
  regressor = learn.DNNRegressor(
      feature_columns=feature_columns, hidden_units=[10, 10])

  # Fit
  regressor.fit(x_train, y_train, steps=10, batch_size=1)

  # Predict and score
  y_predicted = list(
      regressor.predict(scaler.transform(x_test), as_iterable=True))
  # mean_squared_error expects (y_true, y_pred).
  score = metrics.mean_squared_error(y_test, y_predicted)

  print('MSE: {0:f}'.format(score))


if __name__ == '__main__':
  tf.app.run()

I get a lot of warnings and an invalid argument error. The full console output is available here. Below are the key parts of the error messages.

W tensorflow/core/framework/op_kernel.cc:975] Invalid argument: Nan in summary histogram for: dnn/hiddenlayer_0_activation
     [[Node: dnn/hiddenlayer_0_activation = HistogramSummary[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"](dnn/hiddenlayer_0_activation/tag, dnn/hiddenlayer_0/hiddenlayer_0/Relu)]]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "arm_data_regression.py", line 39, in main
    regressor.fit(x_train, y_train, steps=10, batch_size=1)
  File "/home/shehab/.local/lib/python2.7/site-packages/tensorflow/python/util/deprecation.py", line 191, in new_func
    return func(*args, **kwargs)
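The log reports a NaN in the activations of the first hidden layer (dnn/hiddenlayer_0), so one thing worth ruling out is NaN already present in the CSV inputs: with its default float dtype, numpy.genfromtxt silently turns fields it cannot parse, such as a text header row or an empty field, into nan. The sketch below is a hedged diagnostic, not a confirmed fix for this error; report_nans and drop_nonfinite_rows are hypothetical helpers, and the three sample lines are made up for illustration.

```python
import numpy as np


def report_nans(name, arr):
    """Print a short NaN summary for `arr` and return (NaN count, row indices with NaN)."""
    arr = np.asarray(arr, dtype=float)
    mask = np.isnan(arr)
    # For 2-D feature matrices, flag a row if any of its columns is NaN.
    row_mask = mask.reshape(mask.shape[0], -1).any(axis=1)
    bad_rows = np.flatnonzero(row_mask)
    nan_count = int(mask.sum())
    print('{}: shape={}, NaN count={}, first rows with NaN={}'.format(
        name, arr.shape, nan_count, bad_rows[:10].tolist()))
    return nan_count, bad_rows


def drop_nonfinite_rows(x, y):
    """Return copies of x and y keeping only rows whose features and target are all finite."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.shape[0] != y.shape[0]:
        raise ValueError('x has {} rows but y has {}'.format(x.shape[0], y.shape[0]))
    keep = (np.isfinite(x.reshape(x.shape[0], -1)).all(axis=1)
            & np.isfinite(y.reshape(y.shape[0], -1)).all(axis=1))
    return x[keep], y[keep]


# Made-up CSV lines (hypothetical): a text header row and a row with an empty field.
# genfromtxt turns both into nan instead of raising an error.
sample_x = np.genfromtxt(['a,b', '1.0,2.0', '3.0,'], delimiter=',')
sample_y = np.array([0.5, 1.5, 2.5])

nan_count, bad_rows = report_nans('sample_x', sample_x)
x_clean, y_clean = drop_nonfinite_rows(sample_x, sample_y)
print(x_clean.shape, y_clean.tolist())
```

Calling report_nans on x_train, y_train, x_test and y_test right after loading would show whether any NaN reaches fit(); if the arrays are clean, the NaN is being produced during training rather than coming from the files. If a header row is the culprit, genfromtxt's skip_header=1 argument avoids reading it at all.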

Is this because I am using some deprecated TensorFlow API?

0 answers