如果我将输入管道从feed_dict {...}更改为tf.data.dataset,那么在每次迭代后的训练过程中,如何更改网络中参数的值。为了澄清,旧代码看起来像这样:
# Define Training Step:
# model is some class that defines graph.
def train_step(x_batch, y_batch, var):
feed_dict = {
model.input : x_batch,
model.labels : y_batch,
model.var_to_change : var,
}
_, step, summaries, loss, accuracy = sess.run(
[train_op, global_step, model.cross_entropy, model.accuracy],
feed_dict)
# Training:
var_new = 0
for i in range(num_epochs):
batch = mnist.train.next_batch(batch_size)
train_step(batch[0], batch[1], var_new)
var_new = something_new_for_each_iteration
对于新内容,它看起来像这样:
model = create_model(dataset.inputs, dataset.outputs)
# where model.train returns tf.group(update_losses, train_op, global_step)
# Training
for step in range(num_epochs):
fetches = {"train": model.train}
results = sess.run(fetches, options=options)
谢谢!
答案 0 :(得分:0)
取决于您要实现的目标,也许可以使用以下选项之一:
您可以为此参数创建单独的模型功能,并在训练期间使用Dataset.from_generator()
如果可以从上一步中计算出变量,则可以在图形中创建一个变量,并使用tf.assign()操作对其进行更新。在下一批期间,您可以阅读和使用更新后的值。
答案 1 :(得分:0)
我想我解决了这个问题:我将参数初始化为具有特定名称的tf.Variable()。然后,在运行会话后,我循环遍历tf.global_variables()来查找它,并使用variable.load(new_variable,sess)更改其值。我将变量包括在摘要中,并且其值正在更改。同样,此插值参数添加到网络的图层权重也开始更新。