我正在TensorFlow中进行基本的轨道力学仿真。当“行星”太靠近“太阳”时(当x,y接近(0,0)时),TensorFlow在除法过程中会获得异常(这可能是有道理的)。它以某种方式在异常期间返回异常,导致其完全失败。
我尝试使用tf.where
有条件地将这些除数用NaN
替换为零,但是实际上它遇到了相同的错误。我还尝试过使用tf.div_no_nan而不是NaN
来获得零,但这会得到完全相同的错误。
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
def gravity(state, t):
print(len(tf.unstack(state)))
x, y, vx, vy = tf.unstack(state)
# Error is related to next two lines
fx = -x/tf.pow(tf.reduce_sum(tf.square([x,y]),axis=0),3/2)
fy = -y/tf.pow(tf.reduce_sum(tf.square([x,y]),axis=0),3/2)
dvx = fx
dvy = fy
return tf.stack([vx, vy, dvx, dvy])
# Num simulations
size = 100
# Initialize at same position with varying y-velocity
init_state = tf.stack([tf.constant(-1.0,shape=(size,)),tf.zeros((size)),tf.zeros((size)),tf.range(0,10,.1)])
t = np.linspace(0, 10, num=5000)
tensor_state, tensor_info = tf.contrib.integrate.odeint(
gravity, init_state, t, full_output=True)
init = tf.global_variables_initializer()
with tf.Session() as sess:
state, info = sess.run([tensor_state, tensor_info])
state = tf.transpose(state, perm=[1,2,0]).eval()
x, y, vx, vy = state
for i in range(10):
plt.figure()
plt.plot(x[i], y[i])
plt.scatter([0],[0])
我实际上得到了
...
InvalidArgumentError: assertion failed: [underflow in dt] [9.0294095248318226e-17]
...
During handling of the above exception, another exception occurred:
...
InvalidArgumentError: assertion failed: [underflow in dt] [9.0294095248318226e-17]
...
我希望除法结果为NaN
或无穷大,然后像通常期望的那样对数值积分进行传播。
答案 0 :(得分:0)
您可以尝试
with tf.Session() as sess:
sess.run(init)
try:
state, info = sess.run([tensor_state, tensor_info])
except tf.errors.InvalidArgumentError:
state = #Whatever values/shape you need
我不知道这是否适合您的情况,但也许您可以添加一些小常数以避免被零除。