我正在尝试使用freeze_graph.py工具保存模型,但遇到了问题。
我的张量流图中有一个变量,我指定使用tf.assign,或者在每次推断之前输入。我需要它保持变量,因为tf.assign需要一个可变张量,你也不能提供给const,但freeze_graph脚本将所有变量转换为常量。
我注意到freeze_graph有白名单和黑名单参数,但我不能在我的生活中找到关于这些是什么或如何使用它们的任何文档。我能在这做什么?
编辑:
NoneType
和single_c
是我要保留的变量:
single_h
因为我使用以下方式分配它们:
single_c = tf.Variable(tf.random_uniform([num_lstm_cells], 0, 1), trainable=True)
expanded_c = tf.reshape(single_c, [1, num_lstm_cells])
batched_c = tf.tile(expanded_c, tiling_shape, name='c')
single_h = tf.Variable(tf.random_uniform([num_lstm_cells], 0, 1), trainable=True)
expanded_h = tf.reshape(single_h, [1, num_lstm_cells])
batched_h = tf.tile(expanded_h, tiling_shape, name='h')
state = tf.contrib.rnn.LSTMStateTuple(batched_c, batched_h)
我使用
提供restore_c = tf.assign(single_c, c_holder)
restore_h = tf.assign(single_h, h_holder)
state
如果_, er, new = sess.run([train_nn_step, error, new_state], feed_dict={batch_ph: batch_size, prob: training_dropout, state: new})
和single_c
成为常量