How to reuse an RNN cell for inference

Posted: 2018-10-25 10:07:43

Tags: tensorflow

Part of my graph definition, used for training, looks like this:

with tf.variable_scope('RNN', initializer=tf.contrib.layers.xavier_initializer()):
    self.rnn_cell = tf.nn.rnn_cell.GRUCell(HID_SZ, name='rnncell')
    self.init_state = tf.get_variable('init', [1, HID_SZ], tf.float32)
    self.init_state_train = tf.tile(self.init_state, [SZ_BATCH, 1])
    outputs, state = tf.nn.dynamic_rnn(self.rnn_cell, emb, initial_state=self.init_state_train, dtype=tf.float32, time_major=True)

Then I defined a separate part of the graph for inference. It currently looks like this:

with tf.variable_scope("", reuse=True):
    [...]
    self.rnn_infer = tf.get_variable('RNN/rnncell')
    inputs_single = tf.expand_dims(emb_single, 0)
    input_state_ = tf.expand_dims(self.input_state, 0)
    output, hidden = self.rnn_infer(inputs_single, input_state_, name='rnncall')

But this fails with an error.

I am trying to reuse the variables behind tf.get_variable('RNN/rnncell') for inference. How can I do that?

1 answer:

Answer 0 (score: 1)

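The key point is that when you create a cell and run it (directly, or through an RNN function such as dynamic_rnn), its weights and ops are created in the graph just like any others. So you can restore the weights the usual way, with tf.train.Saver. Note that the cell object itself is not a variable, which is why tf.get_variable('RNN/rnncell') cannot return it; to run another step with the same weights in the same graph, call the same cell object again.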
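To make the "one set of weights, two uses" idea concrete without TensorFlow, here is a minimal NumPy sketch. NumpyGRUCell and unroll are hypothetical names; the cell follows the standard GRU equations, but the gate layout and the gate bias of 1.0 are assumptions for illustration, not TensorFlow's code. The sketch unrolls the cell over a time-major batch (as in the training graph) and then feeds one sequence step by step with a batch of 1 (as in the inference graph), using the same cell object both times.

```python
import numpy as np


class NumpyGRUCell:
    """Hypothetical minimal GRU cell; NOT TensorFlow's implementation.

    All weights live on this one object, so every call, whether inside an
    unrolled training loop or a single inference step, uses the same parameters.
    """

    def __init__(self, input_size, hidden_size, seed=0):
        rng = np.random.default_rng(seed)
        n = input_size + hidden_size
        # Reset and update gates share one matrix (assumed layout).
        self.w_gates = rng.normal(scale=0.1, size=(n, 2 * hidden_size))
        self.b_gates = np.ones(2 * hidden_size)  # gate bias of 1.0 (assumption)
        self.w_cand = rng.normal(scale=0.1, size=(n, hidden_size))
        self.b_cand = np.zeros(hidden_size)

    def __call__(self, x, h):
        """One step: x is [batch, input_size], h is [batch, hidden_size]."""
        xh = np.concatenate([x, h], axis=1)
        gates = 1.0 / (1.0 + np.exp(-(xh @ self.w_gates + self.b_gates)))
        r, u = np.split(gates, 2, axis=1)
        cand = np.tanh(np.concatenate([x, r * h], axis=1) @ self.w_cand + self.b_cand)
        return u * h + (1.0 - u) * cand  # new state; a GRU's output is its state


def unroll(cell, inputs, h0):
    """Time-major unroll over inputs of shape [time, batch, input_size]."""
    h = h0
    outputs = []
    for x_t in inputs:
        h = cell(x_t, h)
        outputs.append(h)
    return np.stack(outputs), h


cell = NumpyGRUCell(input_size=3, hidden_size=4)
seq = np.random.default_rng(1).normal(size=(5, 2, 3))  # [time, batch, input]
init_state = np.zeros((1, 4))
train_init = np.tile(init_state, (2, 1))  # like tf.tile(init_state, [SZ_BATCH, 1])
train_out, _ = unroll(cell, seq, train_init)

# Inference: feed sequence 0 one step at a time (batch of 1), same cell object.
h = init_state
step_out = []
for t in range(seq.shape[0]):
    h = cell(seq[t, 0:1], h)
    step_out.append(h[0])
step_out = np.stack(step_out)

print(np.allclose(step_out, train_out[:, 0]))  # True
```

Because batch rows never interact inside the cell, the step-by-step outputs match the first column of the unrolled outputs; the only requirement is that both paths call the same object and therefore read the same weights.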

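import os

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)


def build_and_train():
    HID_SZ = 1
    graph = tf.Graph()
    sess = tf.Session(graph=graph)
    ones = np.ones([2, 3])

    with graph.as_default():
        in_ = tf.placeholder(tf.float32, [2, 3])
        # Creating and calling the cell adds its weights to the graph as
        # ordinary variables, named after the cell ('rnncell').
        cell = tf.nn.rnn_cell.GRUCell(HID_SZ, name='rnncell')
        state = tf.zeros([2, HID_SZ])
        out, state = cell(in_, state)
        # No real training step here: the initialized weights stand in for
        # trained ones.
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

    # Portable path instead of the Windows-only os.getcwd() + '\\model.ckpt'.
    saver.save(sess, os.path.join(os.getcwd(), 'model.ckpt'))
    print('Cell output after training')
    print(sess.run(out, feed_dict={in_: ones}))


def infer():
    HID_SZ = 1
    graph = tf.Graph()
    sess = tf.Session(graph=graph)
    ones = np.ones([2, 3])

    with graph.as_default():
        # Rebuild the same structure with the same cell name so the variable
        # names match the ones stored in the checkpoint.
        in_ = tf.placeholder(tf.float32, [2, 3])
        cell = tf.nn.rnn_cell.GRUCell(HID_SZ, name='rnncell')
        state = tf.zeros([2, HID_SZ])
        out, state = cell(in_, state)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

    print('random cell output')
    print(sess.run(out, feed_dict={in_: ones}))

    # Overwrite the random initial values with the saved weights.
    saver.restore(sess, os.path.join(os.getcwd(), 'model.ckpt'))

    print('Trained cell output')
    print(sess.run(out, feed_dict={in_: ones}))


build_and_train()
infer()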

This prints:

Cell output after training
[[0.02710133]
 [0.02710133]]
random cell output
[[0.2458247]
 [0.2458247]]
Trained cell output
[[0.02710133]
 [0.02710133]]
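The same save/rebuild/restore pattern can be sketched framework-free. The NumPy code below is only an analogy for what the TensorFlow example does: init_params and forward are hypothetical stand-ins for the cell, and np.savez/np.load stand in for Saver.save/Saver.restore. The point it shows is that restoring works by matching parameter names into an identically built structure, which is why infer() recreates the cell with the same name='rnncell'.

```python
import os
import tempfile

import numpy as np


def init_params(seed, input_size=3, hidden_size=1):
    """Hypothetical parameter set standing in for a cell's weights."""
    rng = np.random.default_rng(seed)
    return {
        "kernel": rng.normal(size=(input_size, hidden_size)),
        "bias": np.zeros(hidden_size),
    }


def forward(params, x):
    """Hypothetical one-layer step; stands in for calling the cell."""
    return np.tanh(x @ params["kernel"] + params["bias"])


ones = np.ones((2, 3))
ckpt = os.path.join(tempfile.mkdtemp(), "model.npz")

# Like build_and_train(): create weights, save them, record the output.
trained = init_params(seed=0)
np.savez(ckpt, **trained)
trained_out = forward(trained, ones)

# Like infer(): rebuild the same structure with fresh random weights...
fresh = init_params(seed=123)
random_out = forward(fresh, ones)

# ...then replace them from the checkpoint, matching entries by name.
with np.load(ckpt) as data:
    restored = {name: data[name] for name in fresh}
restored_out = forward(restored, ones)

print(np.allclose(restored_out, trained_out))  # True
```

If the rebuilt structure used different names, the lookup by name would fail, much as a TensorFlow restore fails when the graph's variable names do not match the checkpoint.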