使用另一个图形中的导入张量初始化变量

时间:2017-09-08 08:21:35

标签: python-3.x tensorflow deep-learning

我在python3中使用tensorflow(版本:v1.1.0-13-g8ddd727 1.1.0)(Python 3.4.3(默认,2016年11月17日,01:08:31)[GCC 4.8.4]在linux上) ,它是从源和基于GPU安装的。

我想知道是否可以使用导入的张量从另一个会话初始化变量,因为tensorflow文档没有提及它,我在stackoverflow上找到了它。

train_dir = './gan/train_logs'
    ckpt = tf.train.latest_checkpoint(train_dir)
    filename = ".".join([ckpt, 'meta'])
    print(filename)
    saver = tf.train.import_meta_graph(filename)
    saver.restore(sess, ckpt)
    test = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')

这里成功导入了张量,我想用它们来初始化相同的生成器。

感谢您的帮助!

1 个答案:

答案 0 :(得分:0)

您所要做的就是创建tf.assign操作。

所以你这样做:

old_weights = .... # your loading

new_weights = tf.Variable( ... ) # any initialisation here!

initialise_new_weights = tf.assign(new_weights, old_weights)

with tf.train.MonitoredSession() as sess:
  # at this point new_weights are randomly initialised
  sess.run(initialise_new_weight) # now they are initialised to your values

或者您可以直接传递初始化参数

old_weights = .... # your loading

new_weights = tf.Variable( ..., initializer = tf.constant_initialiser(old_weights) ) 

with tf.train.MonitoredSession() as sess:
  # they are initialised to your values