有关如何在TensorFlow

时间:2017-06-04 18:51:09

标签: model tensorflow

我知道我们可以编写自定义模型并使用tf.estimator封装它。但我似乎无法找到任何带有示例的文档。

我知道你必须在'model_fn'中定义你的模型,但我应该从这个函数返回什么。我也应该把损失和训练步骤放在'model_fn'或仅仅是网络中。我应该如何修改下面给出的代码,使其与tf.estimator一起使用。真的很感激一些帮助。

def test_model(features,labels):
    X = tf.placeholder(tf.float32,shape=(None,1),name="Data_Input")
    #Output
    Y = tf.placeholder(tf.float32,shape=(None,1),name="Target_Labels")
    W =  tf.Variable(tf.random_normal([0],stddev=stddev0)) 
    b = tf.Variable(tf.random_normal([0],stddev=stddev0))

    Ypredict = W*X + b
    return Ypredict

 estimator = tf.estimator.Estimator(model_fn = test_model)

1 个答案:

答案 0 :(得分:1)

您应该返回tf.estimator.EstimatorSpec个对象。有什么影响:

def model_fn(features, labels, mode, params):
    /*
    Your marvelous model
    */
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels_onehot, logits=logits)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
    train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())  
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

还有更多内容,因此要获得更好的演练,请参阅here