在tensorflow C ++中解析tf.nn.rnn_cell.GRUCell并了解GRUCell内部的操作是什么?

时间:2018-10-26 14:29:01

标签: python c++ tensorflow rnn

我创建了一个示例脚本,在其中创建了用于前进和后退方向的GRUCells,并将其传递给bidirectional_rnn层。

import tensorflow as tf
import numpy as np
from tensorflow.python.framework import graph_util
import os

X_batch = np.array([
                    [[0., 1., 2.], [8., 2., 1.], [9., 8., 7.]],
                    [[3., 4., 5.], [9., 7., 4.], [0., 0., 0.]],
                    [[6., 7., 8.], [3., 6., 7.], [6., 5., 4.]],
                    [[9., 0., 1.], [0., 0., 0.], [0., 0., 0.]]
                   ])

batch = 4
n_steps = 3
input_size = 3
inputs = tf.placeholder(tf.float32, [batch, n_steps, input_size])

def biGRU(inputs, n_hidden, batch_size):
    gru_fw = tf.nn.rnn_cell.GRUCell(n_hidden)
    gru_bw = tf.nn.rnn_cell.GRUCell(n_hidden)
    output, _states = tf.nn.bidirectional_dynamic_rnn(gru_fw, gru_bw, inputs, dtype=tf.float32)
    final_outputs = tf.concat([output[0], output[1]], 2)
    return final_outputs

biGRU_model = biGRU(inputs, 4, batch)

with tf.Session() as sess:
    tb_dir = 'test_tensorboard/'
    if not os.path.exists(tb_dir):
      os.makedirs(tb_dir)
    check_write = tf.summary.FileWriter(tb_dir, sess.graph)
    init = tf.global_variables_initializer()
    init.run()
    sess.run(biGRU_model, feed_dict={inputs: X_batch})
    variables_names =[v.name for v in tf.trainable_variables()]
    values = sess.run(variables_names)
    for k,v in zip(variables_names, values):
        print(k, v)

    output_node_name = "concat"
    # remove training nodes
    removed_train_graph_def = graph_util.remove_training_nodes(sess.graph_def, protected_nodes=None)
    # convert variable nodes to const nodes
    const_graph_def = graph_util.convert_variables_to_constants(sess, removed_train_graph_def, output_node_name.split(","))
    gd_dir = 'test_graphdef/'
    if not os.path.exists(gd_dir):
      os.makedirs(gd_dir)
    tf.train.write_graph(const_graph_def, gd_dir, 'test_gru.pbtxt', as_text=True) 
    tf.train.write_graph(const_graph_def, gd_dir, 'test_gru.pb', as_text=False)

此脚本将打印可训练的变量并将日志转储到test_tensorboard文件夹以在tensorboard中查看,并将pb和pbtxt文件转储到test_graphdef文件夹。请查看图片 gru_cells_in_fw_bw.png在这里可以看到grucell,请查看kernal_bias_in_grucell.png在这里可以看到内核和偏差。示例脚本将打印此gru单元内部的权重和偏差。

我的目标是解析tf.nn.rnn_cell.GRUCell并获取所有参数。这里的参数可能是内核和偏差,还有其他参数。 您可以在https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/GRUCell

上看到除内核和偏差以外的其他参数。

当我在c ++中解析“ test_gru.pb”时,没有得到op:“ GRUCell”作为一个opnode。例如,如果在模型中创建tf.nn.conv2d,我们将获得op:“ Conv2D”作为一个OpNode,这样我就很容易获得参数。

所以最后我的问题是GRUCell内的操作是什么?换句话说,如果我查看Netron中的test_gru.pb,我会看到很多opnode都执行GRUCell操作,那么我如何识别那些设置的opnode?

谢谢!

0 个答案:

没有答案