如何重新初始化张量流变量?

时间:2019-12-26 11:38:50

标签: tensorflow tensorflow2.0

在tensorflow 1中,我曾经运行tf.global_variables_initializer()以便(重新)初始化图中的所有变量。我离开tensorflow已有一段时间了,我没有看到可以重新初始化变量的方法。我正在运行constexpr,并且我希望能够多次调用model.fit,而不会在开始之前正在进行的当前进展中保持温暖

1 个答案:

答案 0 :(得分:2)

tf.keras API中,您可以使用get_weights()set_weights()方法获取/设置Model对象的权重值。

请参见下面的简单示例:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(1, input_shape=(3,)))
model.compile(loss="mse", optimizer="sgd")
initial_weight_values = model.get_weights()  # <-- Remember the initial weight values.

num_examples = 100
xs = np.ones([num_examples, 3])
ys = np.zeros([num_examples, 1])
model.fit(xs, ys, epochs=10)

model.set_weights(initial_weight_values)  # <-- Restore the weight values to the initial ones.
model.fit(xs, ys, epochs=10)  # <-- The new fit() call should start from a high loss value like the one in the previous fit() call.

注意:该方法每次将权重值重新初始化为完全相同的值。如果每次都需要不同的值,则可以使用类似numpy.random的值来生成与initial_weight_values相同的dtype和形状的值。