例如,我想恢复一个名为“ W1”的砝码,该砝码保存在ckpt文件中。
接下来,我定义一个名为add_layer
的函数,如下所述。
并使用此add_layer函数构建我的网络。
我的代码的一行是:
...
...
layer_n = add_layer( i, 2048, 2048, N, trainable=True )
...
问题是,如何将权重'W1'恢复为add_layer中定义的'WN'。
def add_layer(input_x, input_dim, output_dim, layer_num, trainable=True):
name_W = 'W' + layer_num
name_B = 'B' + layer_num
with tf.variable_scope( 'layer_' + layer_num ):
W1 = tf.get_variable( name_W, shape=[input_dim, output_dim],
initializer=tf.keras.initializers.lecun_normal(),
trainable=trainable )
B1 = tf.get_variable( name_B, shape=[1, output_dim],
initializer=tf.constant_initializer( value=0,
dtype=tf.float32 ),
trainable=trainable )
output = tf.nn.selu( tf.add( tf.matmul( input_x, W1 ), B1 ) )
return output