Using tf.estimator.Estimator

Date: 2017-08-01 22:26:04

Tags: tensorflow

I've been trying to learn the layers and estimator frameworks that were recently moved from contrib into the main API, and I've run into a rather strange problem. I wrote a simple autoencoder for MNIST, but somehow when I train it, the loss keeps decreasing, so I assume the model is actually training, yet the log keeps saying I am at step 0. And since the step count never advances, no checkpoints are saved and no summaries are written. I'm not sure what I'm doing wrong; all the documentation points to the old `tf.contrib.learn` framework, and many of those APIs appear to be deprecated. How do I make this work? Here is my code:

import shutil

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from IPython import display
def encoder(x):
    l1 = tf.layers.dense(x, 256, activation=tf.nn.relu, name='encode1')
    l2 = tf.layers.dense(l1, 128, activation=tf.nn.relu, name='encode2')
    return l2

def decoder(x):
    l1 = tf.layers.dense(x, 256, activation=tf.nn.relu, name='decode1')
    l2 = tf.layers.dense(l1, 784, activation=tf.nn.relu, name='decode2')
    return l2

def loss(labels, preds):
    return tf.losses.huber_loss(labels, preds)

def train(loss):
    optimizer = tf.train.AdamOptimizer()
    return optimizer.minimize(loss)

def model_fn(features, labels, mode):
    _encoder = encoder(features)
    _decoder = decoder(_encoder)
    _loss = loss(labels, _decoder)
    _train = train(_loss)
    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=_decoder,
                                      loss=_loss,
                                      train_op=_train)

data = input_data.read_data_sets(".", one_hot=True)
display.clear_output()
# remove current log dir
shutil.rmtree('logs', ignore_errors=True)

def input_fn():
    if data.train.epochs_completed <= 10:
        features, labels = data.train.next_batch(100)
        return tf.constant(features), tf.constant(features)
    raise StopIteration

estimator = tf.estimator.Estimator(model_fn, "logs")
estimator.train(input_fn=input_fn)

Here is some sample output:

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'logs', '_tf_random_seed': 1, '_save_summary_steps': 100, '_save_checkpoints_secs': 600, '_save_checkpoints_steps': None, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000}
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Saving checkpoints for 0 into logs/model.ckpt.
INFO:tensorflow:loss = 0.0505481, step = 0
INFO:tensorflow:loss = 0.00319921, step = 0 (1.125 sec)
INFO:tensorflow:loss = 0.00277268, step = 0 (1.094 sec)
INFO:tensorflow:loss = 0.00275822, step = 0 (1.106 sec)
INFO:tensorflow:loss = 0.00275116, step = 0 (1.069 sec)
INFO:tensorflow:loss = 0.00275018, step = 0 (1.130 sec)
INFO:tensorflow:loss = 0.00274921, step = 0 (1.161 sec)
INFO:tensorflow:loss = 0.00274908, step = 0 (1.140 sec)
INFO:tensorflow:loss = 0.00274683, step = 0 (1.105 sec)
INFO:tensorflow:loss = 0.00274397, step = 0 (1.111 sec)

1 Answer:

Answer 0 (score: 3):

In your training op you need to set the global_step parameter; it is the step counter that is incremented once per training step, and it is what the logging and checkpoint hooks read. So change the call to:

optimizer.minimize(loss, global_step=tf.train.get_global_step())
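
To see why this makes the logged step move, here is a toy pure-Python sketch of the mechanics (the class and names below are illustrative stand-ins, not TensorFlow's actual internals): the logging hook reads a shared step counter, and the training op only advances that counter when it is wired in explicitly, which is exactly what the `global_step=` argument does.

```python
class GlobalStep:
    """Stand-in for TF's global_step variable (illustrative only)."""
    def __init__(self):
        self.value = 0

def minimize(loss_fn, global_step=None):
    """Toy train op: runs one update; bumps the counter only if one is passed."""
    loss = loss_fn()
    if global_step is not None:
        global_step.value += 1
    return loss

step = GlobalStep()

# Without wiring the counter in, the "logged" step never moves,
# mirroring the repeated "step = 0" lines in the output above:
minimize(lambda: 0.05)
print(step.value)  # 0

# Passing global_step, as the answer suggests, advances it per update:
for l in [0.003, 0.0028, 0.0027]:
    minimize(lambda l=l: l, global_step=step)
print(step.value)  # 3
```

In the question's code, the same wiring is the one-line change to `train()`: pass `global_step=tf.train.get_global_step()` into `optimizer.minimize`, and the Estimator's hooks will then see the step advance, log it, and save checkpoints and summaries on schedule.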