我正在尝试对加载的图表进行推理:
ds_graph = load_graph(model)
graph_input = ds_graph.get_tensor_by_name('prefix/input_node:0')
graph_seqlength = ds_graph.get_tensor_by_name('prefix/input_lengths:0')
graph_output = ds_graph.get_tensor_by_name('prefix/output_node:0')
我正在迭代的变量是
inp[i]
sl[i]
在循环中
for i in range(num):
with tf.Session(graph=ds_graph) as sess:
logits = sess.run(graph_output,feed_dict={graph_input:inp[i],graph_seqlength:sl[i]})
logits = tf.nn.softmax(logits, dim=-1, name=None)
logits = sess.run(logits)
output_length=np.array([logits.shape[0]])
tf_greedy_path, _ = tf.nn.ctc_greedy_decoder(logits,output_length,merge_repeated=True)
tf_greedy_path = tf.convert_to_tensor([tf.sparse_tensor_to_dense(sparse_tensor) for sparse_tensor in tf_greedy_path])
greed_out = ndarray_to_text(sess.run(tf_greedy_path)[0][0])
return greed_out
我知道这个片段会在每次迭代时不断向图表中添加元素。但我不确定如何具体解决这个问题。
我的有限理解告诉我在循环之外创建图元素:
logits = tf.nn.softmax(graph_output, dim=-1, name=None)
tf_greedy_path, _ = tf.nn.ctc_greedy_decoder(logits,output_length,merge_repeated=True)
tf_greedy_path = tf.convert_to_tensor([tf.sparse_tensor_to_dense(sparse_tensor) for sparse_tensor in tf_greedy_path])
for i in range(num):
with tf.Session(graph=ds_graph) as sess:
sess.run(graph_output,feed_dict={graph_input:inp[i],graph_seqlength:sl[i]})
sess.run(logits)
output_length=np.array([logits.shape[0]])
greed_out = ndarray_to_text(sess.run(tf_greedy_path)[0][0])
但是我仍然需要处理在执行期间计算output_length的事实。不幸的是,ctc_greedy_decoder
并未将output_length
作为张量。或者我会通过tf.shape(logits)
答案 0 :(得分:1)
如果没有整个代码,很难回答,但是你是对的,你应该在进入循环之前将所有操作添加到图表中。似乎没有任何东西阻止你使用tensor shape
张量的graph_output
(顺便说一下,不需要中间调用,只需评估你感兴趣的张量,任何中级张量将自动计算):
import tensorflow as tf
graph_output = tf.placeholder(tf.float32, shape=[None, 1, 2]) # graph_output has a dynamic shape
logits = tf.nn.softmax(graph_output, dim=-1, name=None)
tf_greedy_path, _ = tf.nn.ctc_greedy_decoder(logits,[graph_output.shape[0]],merge_repeated=True)
tf_greedy_path = tf.convert_to_tensor([tf.sparse_tensor_to_dense(sparse_tensor) for sparse_tensor in tf_greedy_path])
for i in range(10):
with tf.Session() as sess:
print(sess.run(tf_greedy_path, feed_dict={graph_output:[[[1., 2.]], [[3., 4.]]]})))