我有一个很大的张量流模型(不使用Keras)。我知道我可以保存和恢复张量流模型,但我需要的是不同的。
我想将模型的所有可训练参数存储在内存中,将所有可训练参数设置为我已有的另一组值,并进行预测,然后将先前存储的参数恢复回模型。 / p>
伪代码:
original_variables = tf.trainable_variables(scope=None)
new_variables = original_variables - lr * gradients
set_trainable_variables(new_variables)
y = model.predict(X)
set_trainable_variables(original_variables)
请注意,set_trainable_variables(vars)函数仅假设,
我的问题是:如何实现set_trainable_variables(vars)?