Is it wrong to pass the session as a parameter in TensorFlow?

Time: 2018-09-07 12:06:06

Tags: python tensorflow machine-learning

There is probably a simple answer to this.

I am training for several epochs on the same data. While iterating over the batches within an epoch, the loss decreases smoothly; but on the second, third, and fourth pass over the same data, the loss drops sharply.

# train_epoch(epoch=1, session)
[batch 0] loss = 0.97
[batch 1] loss = 0.96
[batch 2] loss = 0.95
...
# return 

# train_epoch(epoch=2, session)
[batch 9] loss = 0.51
[batch 1] loss = 0.50
[batch 7] loss = 0.49
...
# return 

Here is what I have found about when the problem occurs:

(1) The loss can increase significantly after a single train_epoch call.

(2) When I multiply the batches:

def train_epoch(data):
    data = data * 1000  # repeat the same 10 batches 1000 times within one epoch

the sharp drop still occurs at the start of every new epoch, even though within that epoch the same (n = 10) batches are seen 1,000 times over (with the loss decreasing smoothly throughout). The only conclusion I can come up with is that maybe I should be returning the session object?
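To be clear about the pattern I mean, here is a minimal sketch (TF 1.x API; counter, increment, and bump are made-up names for illustration). As far as I understand, variables live in the session's state, so updates made inside a function should persist for the caller without returning the session:

import tensorflow as tf

counter = tf.Variable(0, name='counter')
increment = tf.assign_add(counter, 1)

def bump(session):
    # The update happens inside the session's variable state,
    # so nothing needs to be returned to the caller.
    session.run(increment)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    bump(session)
    bump(session)
    print(session.run(counter))  # prints 2: state persisted across calls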

This is how the model is organized:

class Model():
    def __init__(self, mode):
        self._loss = xent()  # cross-entropy loss (sketch)
        self._train_op = self._optimizer.apply_gradients(...)  # gradients elided in this sketch
        self._global_step = tf.Variable(0, name='global_step', trainable=False)

def load_model(session, mode):
    with tf.variable_scope("model", reuse=None, initializer=None):
        model = Model(mode)

    saver = tf.train.Saver()
    tf.global_variables_initializer().run()  # initializes every variable first...
    saver.restore(session, ckpt_path)        # ...then the restore overwrites them from ckpt_path
    return model, saver
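As an aside on that init-then-restore sequence: a common alternative (sketched below; init_or_restore and ckpt_dir are made-up names) is to restore only when a checkpoint actually exists and initialize otherwise, so one never silently overwrites the other:

def init_or_restore(session, saver, ckpt_dir):
    # Look for the most recent checkpoint in ckpt_dir (hypothetical path).
    ckpt = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt is not None:
        saver.restore(session, ckpt)  # resume from saved weights
    else:
        session.run(tf.global_variables_initializer())  # fresh start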

def train_epoch(session, model, data):
    losses = []
    for batch in data:
        feed_dict = {model._field: batch[i]}  # one placeholder/value pair per input field (sketch)
        _, loss, _ = session.run([model._train_op, model._loss, model._global_step], feed_dict)
        losses.append(loss)
    return sum(losses) / len(losses)  # avg_loss

with tf.Graph().as_default(), tf.Session(config=config) as session:

    mode = ["train", "decode"][0]       
    model, saver = load_model(session, mode)
    data = load_data()

    if mode == "decode":
        decode_epoch(session, model, data)
    else:
        while something:
            train_epoch(session, model, data)
            shuffle(data)
            if save:
                save_model(session, saver) # -> saver.save(session, ckpt_path)
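To make the question concrete, here is a stripped-down, runnable toy version with the same structure (TF 1.x; the linear model, the data, and this train_epoch are all made up, just to exercise the pass-the-session-around pattern). If passing the session by parameter were the problem, I would expect the loss here to also jump at epoch boundaries:

import numpy as np
import tensorflow as tf

# Toy linear model: every name here is invented for the sketch.
x = tf.placeholder(tf.float32, [None, 1])
y = tf.placeholder(tf.float32, [None, 1])
w = tf.Variable(tf.zeros([1, 1]))
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

def train_epoch(session, data):
    # Same shape as the real code: the session comes in as a parameter
    # and only the average loss is returned.
    losses = []
    for xb, yb in data:
        _, batch_loss = session.run([train_op, loss], {x: xb, y: yb})
        losses.append(batch_loss)
    return np.mean(losses)

# Ten fixed batches of y = 3x, reused every epoch.
rng = np.random.RandomState(0)
data = [(xb, 3.0 * xb) for xb in rng.rand(10, 32, 1).astype(np.float32)]

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    for epoch in range(4):
        print('epoch %d: avg loss %.4f' % (epoch, train_epoch(session, data)))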

0 Answers:

No answers yet.