基本上,我想知道以下代码是否安全。
我想使用tf.py_function
来调用一些scipy代码,在其中我要评估调用会话的操作(包括设置变量)。原因是,我想使用scipy代码来完成在tensorflow中难以编写的代码,但是我不想中断计算图。我至少有两个示例:一个在图形中使用LSODE隐式ODE求解器,另一个在图形中使用scipy的鲁棒最小化器(我正在考虑在py_function包装器中调用GPFlow优化器,该算法需要反复求解GP。优化问题)。
以下代码运行并返回我期望的值(值1)。但是我不知道它是否安全。
我猜想如果我还要在图表中的其他地方使用x
的值,则会带来不确定的行为。
from scipy.optimize import minimize
import tensorflow as tf
def build_func(session, y_and_grad, var):
pl = tf.placeholder(tf.float32)
assign = tf.assign(var, pl)
y, grad = y_and_grad
def func(x):
"""
This could be something that uses ops and vars in the same graph,
but requires iterative access to session calls.
E.g. using scipy.minimize with tensorflow to compute the gradients.
"""
x = x.numpy()
def fun_and_jac(x):
session.run(assign,{pl:x})
y, jac = session.run(y_and_grad)
return y, jac
res = minimize(fun_and_jac, x0=x, jac=True)
return res.x
return func
with tf.Session(graph=tf.Graph()) as session:
x = tf.Variable([0.], dtype=tf.float32)
session.run(x.initializer)
y = (x - 1.)**2
grad = tf.gradients(y, x)[0]
func = build_func(session, [y, grad], x)
opt_x = tf.py_function(func, [x], [tf.float32])
print(session.run(opt_x))