我有一些用于训练的计算图定义,看起来像这样:
android-build
然后我为推理定义了另一部分,目前看起来像这样:
*.so
但是
with tf.variable_scope('RNN', initializer=tf.contrib.layers.xavier_initializer()):
self.rnn_cell = tf.nn.rnn_cell.GRUCell(HID_SZ, name='rnncell')
self.init_state = tf.get_variable('init', [1, HID_SZ], tf.float32)
self.init_state_train = tf.tile(self.init_state, [SZ_BATCH, 1])
outputs, state = tf.nn.dynamic_rnn(self.rnn_cell, emb, initial_state=self.init_state_train, dtype=tf.float32, time_major=True)
导致错误:
with tf.variable_scope("", reuse=True):
[...]
self.rnn_infer = tf.get_variable('RNN/rnncell')
inputs_single = tf.expand_dims(emb_single, 0)
input_state_ = tf.expand_dims(self.input_state, 0)
output, hidden = self.rnn_infer(inputs_single, input_state_, name='rnncall')
我想在推理时重用 tf.get_variable('RNN/rnncell') 所对应的变量,应该怎么做?
答案 0 :(得分:1)
关键在于:当你创建一个 RNN 单元(cell)并将其放入 RNN 时,权重和操作会照常创建在计算图上。因此,你可以像平常一样恢复权重。
import tensorflow as tf
import numpy as np
import os
def build_and_train():
    """Build a one-unit GRU cell, initialize it, and checkpoint its weights.

    A fresh graph is created holding a ``[2, 3]`` float placeholder and a
    ``GRUCell`` named ``'rnncell'``. The variables are initialized, saved to
    ``model.ckpt`` in the current working directory, and the cell output for
    an all-ones input is printed.

    No optimization step is run here: the freshly initialized weights stand
    in for "trained" weights so that a later restore can be compared
    against them.
    """
    HID_SZ = 1
    graph = tf.Graph()
    ones = np.ones([2, 3])
    # Build the path with os.path.join so the checkpoint lands inside the
    # working directory on every platform (a hard-coded '\\' separator only
    # works on Windows).
    ckpt_path = os.path.join(os.getcwd(), 'model.ckpt')

    with graph.as_default():
        in_ = tf.placeholder(tf.float32, [2, 3])
        cell = tf.nn.rnn_cell.GRUCell(HID_SZ, name='rnncell')
        state = tf.zeros([2, HID_SZ])
        # Calling the cell is what creates its weight variables in the graph.
        out, _ = cell(in_, state)
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()

    # The context manager closes the session and releases its resources.
    with tf.Session(graph=graph) as sess:
        sess.run(init_op)
        saver.save(sess, ckpt_path)
        print('Cell output after training')
        print(sess.run(out, feed_dict={in_: ones}))
def infer():
    """Rebuild the same GRU cell in a new graph and restore saved weights.

    The graph is constructed exactly as when the checkpoint was written (a
    ``GRUCell`` named ``'rnncell'``), so the variable names match those
    stored in ``model.ckpt``. The cell output for an all-ones input is
    printed twice: once with randomly initialized weights and once after
    the checkpoint in the current working directory has been restored.
    """
    HID_SZ = 1
    graph = tf.Graph()
    ones = np.ones([2, 3])
    # Resolve the checkpoint against the working directory explicitly so the
    # lookup does not depend on how a bare relative name is interpreted.
    ckpt_path = os.path.join(os.getcwd(), 'model.ckpt')

    with graph.as_default():
        in_ = tf.placeholder(tf.float32, [2, 3])
        cell = tf.nn.rnn_cell.GRUCell(HID_SZ, name='rnncell')
        state = tf.zeros([2, HID_SZ])
        # Calling the cell creates the variables the saver will restore into.
        out, _ = cell(in_, state)
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()

    # The context manager closes the session and releases its resources.
    with tf.Session(graph=graph) as sess:
        sess.run(init_op)
        print('random cell output')
        print(sess.run(out, feed_dict={in_: ones}))
        # Restoring overwrites the random initial values with the saved ones.
        saver.restore(sess, ckpt_path)
        print('Trained cell output')
        print(sess.run(out, feed_dict={in_: ones}))
if __name__ == '__main__':
    # Write the checkpoint first, then rebuild the graph and restore from it.
    build_and_train()
    infer()
这将输出:
Cell output after training
[[0.02710133]
[0.02710133]]
random cell output
[[0.2458247]
[0.2458247]]
Trained cell output
[[0.02710133]
[0.02710133]]