TensorFlow: copying an existing graph into a new graph multiple times

Date: 2018-11-08 14:11:18

Tags: python tensorflow

I want to paste an existing TensorFlow graph into a new graph.

Say I build a graph that computes y = tanh(x @ w):

import tensorflow as tf
import numpy as np

def some_function(x):
    w = tf.Variable(initial_value=np.random.randn(4, 5), dtype=tf.float32)
    return tf.tanh(x @ w)

x = tf.placeholder(shape=(None, 4), dtype = tf.float32)
y = some_function(x)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
val_x = np.random.randn(3, 4)
val_y, = sess.run([y], feed_dict={x: val_x})

Great. Now suppose I have lost the code that generated this graph, but I still have access to the tensors (x, y). Now I want to take this graph (with the current value of w) and copy it into a new graph twice (the two paths should share the same w), so that I now compute d = tf.reduce_sum((tanh(x1 @ w) - tanh(x2 @ w))**2) by adding the lines:

# Starting with access to tensors: x, y
<SOMETHING HERE>
d = tf.reduce_sum((y1-y2)**2)
val_x1 = np.random.randn(3, 4)
val_x2 = np.random.randn(3, 4)
val_d, = sess.run([d], feed_dict={x1: val_x1, x2: val_x2})

What do I need to fill in for <SOMETHING HERE> to make this work? (Obviously, without re-creating the first graph.)

1 answer:

Answer 0 (score: 1):

The Graph Editor module can help you with this kind of operation. Its main drawback is that you cannot run the session while you are modifying the graph. However, you can checkpoint the session, modify the graph, and then restore it if you need to.
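
As a rough sketch of that checkpoint-edit-restore workflow, assuming the session sess from the question (the checkpoint path and the exact sequence of steps are only illustrative):

import tensorflow as tf

# Sketch: snapshot the variable values, edit the graph, then restore the values.
saver = tf.train.Saver()
saver.save(sess, '/tmp/model.ckpt')      # checkpoint the current variable values
sess.close()                             # do not run the session while editing
# ... modify the graph with the Graph Editor here ...
sess = tf.Session()
saver.restore(sess, '/tmp/model.ckpt')   # bring the saved values back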

The problem, as you pose it, is that you basically need to replicate a subgraph, except that you do not want to replicate the variables. So you can simply exclude the variable op types (mainly Variable, VariableV2, and maybe VarHandleOp, although I found a few more variable types in the TensorFlow code). You can do that with a function like this:

import tensorflow as tf

# Receives the outputs to recalculate and the input replacements
def replicate_subgraph(outputs, mappings):
    # Types of operation that should not be replicated
    # Taken from tensorflow/python/training/device_setter.py
    NON_REPLICABLE = {'Variable', 'VariableV2', 'AutoReloadVariable',
                      'MutableHashTable', 'MutableHashTableV2',
                      'MutableHashTableOfTensors', 'MutableHashTableOfTensorsV2',
                      'MutableDenseHashTable', 'MutableDenseHashTableV2',
                      'VarHandleOp', 'BoostedTreesEnsembleResourceHandleOp'}
    # Find subgraph ops
    ops = tf.contrib.graph_editor.get_backward_walk_ops(outputs, stop_at_ts=mappings.keys())
    # Exclude non-replicable operations
    ops_replicate = [op for op in ops if op.type not in NON_REPLICABLE]
    # Make a subgraph view of the replicable ops
    sgv = tf.contrib.graph_editor.make_view(*ops_replicate)
    # Make the copy
    _, info = tf.contrib.graph_editor.copy_with_input_replacements(sgv, mappings)
    # Return new outputs
    return info.transformed(outputs)

Here is an example similar to yours (I edited it a bit, so it is easy to see that the output is correct because the second value is ten times the first):

import tensorflow as tf

def some_function(x):
    w = tf.Variable(initial_value=tf.random_normal((5,)), dtype=tf.float32)
    return 2 * (x * w)

x1 = tf.placeholder(shape=(), dtype=tf.float32, name='X1')
x2 = tf.placeholder(shape=(), dtype=tf.float32, name='X2')
y1 = some_function(x1)
y2, = replicate_subgraph([y1], {x1: x2})
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(*sess.run([y1, y2], feed_dict={x1: 1, x2: 10}), sep='\n')

Output:

[ 2.3356955   2.277849    0.58513653  2.0919807  -0.15102367]
[23.356955  22.77849    5.851365  20.919807  -1.5102367]
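
Applied back to the question, the <SOMETHING HERE> gap could be filled roughly as follows. This is a sketch that assumes the replicate_subgraph function above and the x, y, sess, val_x1 and val_x2 names from the question; I have not run it against that exact graph:

# Starting with access to tensors: x, y
x1 = tf.placeholder(shape=(None, 4), dtype=tf.float32, name='X1')
x2 = tf.placeholder(shape=(None, 4), dtype=tf.float32, name='X2')
y1, = replicate_subgraph([y], {x: x1})   # first copy of the subgraph, fed by x1
y2, = replicate_subgraph([y], {x: x2})   # second copy, fed by x2; both share w
d = tf.reduce_sum((y1 - y2)**2)
val_d, = sess.run([d], feed_dict={x1: val_x1, x2: val_x2})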

EDIT:

Here is another solution, using tf.make_template. This requires you to actually have the code for the function, but it is a cleaner and more "official" way of supporting subgraph reuse.

import tensorflow as tf

def some_function(x):
    w = tf.get_variable('W', (5,), initializer=tf.random_normal_initializer())
    # Or, if the variable is only local and not meant to be trainable:
    # w = tf.Variable(initial_value=tf.random_normal((5,)), dtype=tf.float32, trainable=False)
    return 2 * (x * w)

x1 = tf.placeholder(shape=(), dtype=tf.float32, name='X1')
x2 = tf.placeholder(shape=(), dtype=tf.float32, name='X2')
some_function_tpl = tf.make_template('some_function', some_function)
y1 = some_function_tpl(x1)
y2 = some_function_tpl(x2)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(*sess.run([y1, y2], feed_dict={x1: 1, x2: 10}), sep='\n')
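
Mapped back onto the question's y = tanh(x @ w) function, the template approach might look like the sketch below; the variable name 'W' and the template name are arbitrary choices, not something from the original code:

import tensorflow as tf
import numpy as np

def some_function(x):
    # tf.get_variable so the template reuses the same W on every call
    w = tf.get_variable('W', (4, 5), dtype=tf.float32,
                        initializer=tf.random_normal_initializer())
    return tf.tanh(x @ w)

some_function_tpl = tf.make_template('some_function', some_function)
x1 = tf.placeholder(shape=(None, 4), dtype=tf.float32)
x2 = tf.placeholder(shape=(None, 4), dtype=tf.float32)
d = tf.reduce_sum((some_function_tpl(x1) - some_function_tpl(x2))**2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    val_d, = sess.run([d], feed_dict={x1: np.random.randn(3, 4),
                                      x2: np.random.randn(3, 4)})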