使用在张量流中具有变量依赖性的自定义op替换图中的节点

时间:2016-10-05 20:17:31

标签: python graph tensorflow custom-operator

我试图用自定义操作替换图中完成的计算。

假设图表有一个常量A和权重变量W,我创建自定义操作来获取这两个输入并执行整个计算(除了重量更新的最后一步):< / p>

custom_op_tensor = custom_module.custom_op([A,W])
g_def = tf.get_default_graph().as_graph_def()
input_map = { tensor.name : custom_op_tensor }
train_op, = tf.import_graph_def(g_def, input_map=input_map, return_elements=[train_op])

导入图形def后,有两个W,一个来自原始图形def,另一个来自导入图形。当我们运行列车操作时,自定义操作会最终读取旧W并更新新的W。结果,梯度下降最终无法做正确的事情。

问题是custom_op的实例化需要输入权重张量W。新的W仅在导入后才知道。而且,导入需要自定义操作。 如何解决这个问题?

1 个答案:

答案 0 :(得分:0)

您能否确定使用哪个版本的Tensorflow:r0.08,r0.09,r0.10,r0.11?

用另一个操作来改变图表中的操作是不可能的。 但是如果您可以访问W,您仍然可以在运行更新它的列车运行之前制作W的备份副本(使用deepcopy()from copy module)?

此致