是否可以在TensorFlow py_function中使用调用会话?

时间:2019-11-29 00:25:03

标签: python tensorflow gpflow

基本上,我想知道以下代码是否安全。 我想使用tf.py_function来调用一些scipy代码,在其中我要评估调用会话的操作(包括设置变量)。原因是,我想使用scipy代码来完成在tensorflow中难以编写的代码,但是我不想中断计算图。我至少有两个示例:一个在图形中使用LSODE隐式ODE求解器,另一个在图形中使用scipy的鲁棒最小化器(我正在考虑在py_function包装器中调用GPFlow优化器,该算法需要反复求解GP。优化问题)。

以下代码运行并返回我期望的值(值1)。但是我不知道它是否安全。 我猜想如果我还要在图表中的其他地方使用x的值,则会带来不确定的行为。

from scipy.optimize import minimize
import tensorflow as tf

def build_func(session, y_and_grad, var):
    pl = tf.placeholder(tf.float32)
    assign = tf.assign(var, pl)
    y, grad = y_and_grad
    def func(x):
        """
        This could be something that uses ops and vars in the same graph,
        but requires iterative access to session calls.
        E.g. using scipy.minimize with tensorflow to compute the gradients.
        """
        x = x.numpy()
        def fun_and_jac(x):
            session.run(assign,{pl:x})
            y, jac = session.run(y_and_grad)
            return y, jac
        res = minimize(fun_and_jac, x0=x, jac=True)
        return res.x
    return func

with tf.Session(graph=tf.Graph()) as session:
    x = tf.Variable([0.], dtype=tf.float32)
    session.run(x.initializer)
    y = (x - 1.)**2
    grad = tf.gradients(y, x)[0]
    func = build_func(session, [y, grad], x)
    opt_x = tf.py_function(func, [x], [tf.float32])
    print(session.run(opt_x))

0 个答案:

没有答案