从张量流估计器中获取超参数

时间:2018-06-01 19:43:44

标签: tensorflow linear-regression hyperparameters

此问题之前一直没有运气

Video

我在Stack Overflow线程上尝试了推荐的选项,但没有一个工作

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Declare list of features, we only have one real-valued feature
def model_fn(features, labels, mode):
  # Build a linear model and predict values
  W = tf.get_variable("W", [1], dtype=tf.float64)
  b = tf.get_variable("b", [1], dtype=tf.float64)
  y = W * features['x'] + b
  # Loss sub-graph
  loss = tf.reduce_sum(tf.square(y - labels))
  # Training sub-graph
  global_step = tf.train.get_global_step()
  optimizer = tf.train.GradientDescentOptimizer(0.01)
  train = tf.group(optimizer.minimize(loss), tf.assign_add(global_step, 1))
  # EstimatorSpec connects subgraphs we built to the
  # appropriate functionality.
  return tf.estimator.EstimatorSpec(mode=mode, predictions=y, loss=loss,  train_op=train)


Model_variables = tf.GraphKeys.MODEL_VARIABLES
Global_Variables = tf.GraphKeys.GLOBAL_VARIABLES


estimator = tf.estimator.Estimator(model_fn=model_fn)
# define our data sets
x_train = np.array([1., 2., 3., 4.])
y_train = np.array([0., -1., -2., -3.])
x_eval = np.array([2., 5., 8., 1.])
y_eval = np.array([-1.01, -4.1, -7, 0.])
input_fn = tf.estimator.inputs.numpy_input_fn( {"x": x_train}, y_train, batch_size=4, num_epochs=None, shuffle=True)
train_input_fn = tf.estimator.inputs.numpy_input_fn({"x": x_train}, y_train, batch_size=4, num_epochs=1000, shuffle=False)
eval_input_fn = tf.estimator.inputs.numpy_input_fn( {"x": x_eval}, y_eval, batch_size=4, num_epochs=1000, shuffle=False)

# train
estimator.train(input_fn=input_fn, steps=1000)
# Here we evaluate how well our model did.
train_metrics = estimator.evaluate(input_fn=train_input_fn)
eval_metrics = estimator.evaluate(input_fn=eval_input_fn)
print("train metrics: %r"% train_metrics)
print("eval metrics: %r"% eval_metrics)


model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

print (len(model_vars))

all_vars = tf.get_collection(Model_variables)
# print (all_vars)
for i in all_vars:
    print (str(i) + '  -->  '+ str(i.eval()))

0 个答案:

没有答案