我创建了一个示例脚本,在其中创建了用于前进和后退方向的GRUCells,并将其传递给bidirectional_rnn层。
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import graph_util
import os
X_batch = np.array([
[[0., 1., 2.], [8., 2., 1.], [9., 8., 7.]],
[[3., 4., 5.], [9., 7., 4.], [0., 0., 0.]],
[[6., 7., 8.], [3., 6., 7.], [6., 5., 4.]],
[[9., 0., 1.], [0., 0., 0.], [0., 0., 0.]]
])
batch = 4
n_steps = 3
input_size = 3
inputs = tf.placeholder(tf.float32, [batch, n_steps, input_size])
def biGRU(inputs, n_hidden, batch_size):
gru_fw = tf.nn.rnn_cell.GRUCell(n_hidden)
gru_bw = tf.nn.rnn_cell.GRUCell(n_hidden)
output, _states = tf.nn.bidirectional_dynamic_rnn(gru_fw, gru_bw, inputs, dtype=tf.float32)
final_outputs = tf.concat([output[0], output[1]], 2)
return final_outputs
biGRU_model = biGRU(inputs, 4, batch)
with tf.Session() as sess:
tb_dir = 'test_tensorboard/'
if not os.path.exists(tb_dir):
os.makedirs(tb_dir)
check_write = tf.summary.FileWriter(tb_dir, sess.graph)
init = tf.global_variables_initializer()
init.run()
sess.run(biGRU_model, feed_dict={inputs: X_batch})
variables_names =[v.name for v in tf.trainable_variables()]
values = sess.run(variables_names)
for k,v in zip(variables_names, values):
print(k, v)
output_node_name = "concat"
# remove training nodes
removed_train_graph_def = graph_util.remove_training_nodes(sess.graph_def, protected_nodes=None)
# convert variable nodes to const nodes
const_graph_def = graph_util.convert_variables_to_constants(sess, removed_train_graph_def, output_node_name.split(","))
gd_dir = 'test_graphdef/'
if not os.path.exists(gd_dir):
os.makedirs(gd_dir)
tf.train.write_graph(const_graph_def, gd_dir, 'test_gru.pbtxt', as_text=True)
tf.train.write_graph(const_graph_def, gd_dir, 'test_gru.pb', as_text=False)
此脚本将打印可训练的变量并将日志转储到test_tensorboard文件夹以在tensorboard中查看,并将pb和pbtxt文件转储到test_graphdef文件夹。请查看图片 gru_cells_in_fw_bw.png在这里可以看到grucell,请查看kernal_bias_in_grucell.png在这里可以看到内核和偏差。示例脚本将打印此gru单元内部的权重和偏差。
我的目标是解析tf.nn.rnn_cell.GRUCell并获取所有参数。这里的参数可能是内核和偏差,还有其他参数。 您可以在https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/GRUCell
上看到除内核和偏差以外的其他参数。当我在c ++中解析“ test_gru.pb”时,没有得到op:“ GRUCell”作为一个opnode。例如,如果在模型中创建tf.nn.conv2d,我们将获得op:“ Conv2D”作为一个OpNode,这样我就很容易获得参数。
所以最后我的问题是GRUCell内的操作是什么?换句话说,如果我查看Netron中的test_gru.pb,我会看到很多opnode都执行GRUCell操作,那么我如何识别那些设置的opnode? p>
谢谢!