我需要在构建后将intput更改为tesor。以简化示例为例:x
(constant=42.0)
,s
(x^2)
和x_new
(constant=4.0)
。
我想将s
的输入从x
更改为x_new
。执行此操作后,我期待s.eval() == 16.0
x = tf.constant(42.0, name='x')
s = tf.square(x, name='s')
x_new = tf.constant(4.0, name='x_new')
tf.get_default_graph().as_graph_def()
Out[6]:
node {
name: "x"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 42.0
}
}
}
}
node {
name: "s"
op: "Square"
input: "x"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "x_new"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 4.0
}
}
}
}
versions {
producer: 24
}
我尝试过使用tf.contrib.graph_editor.reroute_inputs
,但就我而言,我无法弄清楚如何处理它返回的子图。
我也尝试过使用tf.import_graph_def
在这个git问题(https://github.com/tensorflow/tensorflow/issues/1758)中模糊地描述,但无论我尝试多少种方式,我都没有{{1}将输入从s
更改为x
。
任何人都知道如何使用这些方法中的任何一种来完成这个简单的例子吗?
答案 0 :(得分:0)
您可以使用占位符代替常量。
例如:
import tensorflow as tf
x = tf.placeholder(tf.int32, shape=[])
s = tf.square(x)
with tf.Session() as sess:
print(s.eval({x: 5}))
print(s.eval({x: 4}))
答案 1 :(得分:0)
我没有完整的答案。我也尝试过使用图形编辑器,但无济于事。我试图手动更改以下操作的输入列表(在本例中为s):
x = tf.constant(42.0, name='x')
s = tf.square(x, name='s')
x_new = tf.constant(4.0, name='x_new')
s.inputs._inputs[0] = x_new
#if you want to get all operations consuming x...
outputConsumers = tf.contrib.graph_editor.get_consuming_ops([x])
但这不会改变执行方式,而且似乎还涉及其他簿记。
您走近了吗?
编辑
我不鼓励在生产代码中使用它,但是用于操作的python tensorflow包装器代码具有此内置函数,似乎可以完成工作。
def _update_input(self, index, tensor):
"""Update the input to this operation at the given index.
NOTE: This is for TF internal use only. Please don't use it.
答案 2 :(得分:0)
所以这是我在尝试使用预先训练的网络时经常遇到的问题。正如您在问题中提到的那样,一种方法是按照Github issue中所述的方法来“ import_graph_def”。 Google一直在指点我,这个问题没有一个清晰的例子,因此我将在此处发布一个最小的解决方案。
import tensorflow as tf
with tf.compat.v1.Session() as sess:
x = tf.constant(42.0, name="x")
s = tf.square(x, name="s")
print(sess.run(s))
scope = "test"
x_new = tf.constant(4.0, name="{}/x".format(scope))
tf.import_graph_def(tf.get_default_graph().as_graph_def(), name=scope, input_map={'x': x_new})
print(sess.run("{}/s:0".format(scope)))
请注意,如果您不提供范围,则根据docs,默认值为“导入”。
相反,如果您需要保留一个图形然后进行编辑(或从其他人加载保留的图形),则可以保存该图形并重新加载它(基于answer)
import tensorflow as tf
scope = "test"
graph_filename = "test.pb"
with tf.compat.v1.Session() as sess:
x = tf.constant(42.0, name="x")
s = tf.square(x, name="s")
print(sess.run(s))
with tf.gfile.GFile(graph_filename, 'wb') as outfile:
outfile.write(tf.get_default_graph().as_graph_def().SerializeToString())
with tf.compat.v1.Session() as sess:
x_new = tf.constant(4.0, name="{}/x".format(scope))
with tf.gfile.GFile(graph_filename, 'rb') as infile:
graph_def = tf.GraphDef()
bytes_read = graph_def.ParseFromString(infile.read())
tf.import_graph_def(graph_def, name=scope, input_map={'x': x_new})
print(sess.run("{}/s:0".format(scope)))