TensorFlow:训练时更改变量

时间:2018-12-05 15:38:51

标签: python tensorflow input

如果我将输入管道从feed_dict {...}更改为tf.data.dataset,那么在每次迭代后的训练过程中,如何更改网络中参数的值。为了澄清,旧代码看起来像这样:

# Define Training Step:  

# model is some class that defines graph.   

def train_step(x_batch, y_batch, var):

        feed_dict = {
            model.input         : x_batch,
            model.labels        : y_batch,
            model.var_to_change : var,
        }
        _, step, summaries, loss, accuracy = sess.run(
            [train_op, global_step, model.cross_entropy, model.accuracy],
            feed_dict)

# Training:  

var_new = 0 
for i in range(num_epochs):
        batch = mnist.train.next_batch(batch_size)
        train_step(batch[0], batch[1], var_new) 
        var_new = something_new_for_each_iteration

对于新内容,它看起来像这样:

model = create_model(dataset.inputs, dataset.outputs)
# where model.train returns tf.group(update_losses, train_op, global_step)

# Training

for step in range(num_epochs):

    fetches = {"train": model.train}
    results = sess.run(fetches, options=options)

谢谢!

2 个答案:

答案 0 :(得分:0)

取决于您要实现的目标,也许可以使用以下选项之一:

  1. 您可以为此参数创建单独的模型功能,并在训练期间使用Dataset.from_generator()

  2. 生成其值。
  3. 如果可以从上一步中计算出变量,则可以在图形中创建一个变量,并使用tf.assign()操作对其进行更新。在下一批期间,您可以阅读和使用更新后的值。

答案 1 :(得分:0)

我想我解决了这个问题:我将参数初始化为具有特定名称的tf.Variable()。然后,在运行会话后,我循环遍历tf.global_variables()来查找它,并使用variable.load(new_variable,sess)更改其值。我将变量包括在摘要中,并且其值正在更改。同样,此插值参数添加到网络的图层权重也开始更新。