问题是一个Tensorflow While循环( tf.while_loop )随着时间的流逝而变慢。该循环应返回一些矩阵。我通过字典提供所有输入。
我知道问题很可能是由于一遍又一遍地添加操作而污染了图形。我是TF初学者,对我来说,导致图形污染的原因并不明显。我们非常感谢您的帮助。
def predict(self, actions, ...):
feed_dict = {
self.agent.actions: actions.reshape(-1, self.kwargs["dim_actions"]),
...
}
states_mu, states_var = self.session.run(self.agent.predict_states(), feed_dict=feed_dict)
return states_mu, states_var
def predict_states(self):
...
def loop_cond(i, state_mus, state_vars, state_mus_tf, state_vars_tf, inp_tf_cov):
return i < self.episode_length
def loop_body(i, state_mus, state_vars, state_mus_tf, state_vars_tf, inp_tf_cov):
state_mu_i = state_mus[-1][None, :]
...
state_var_tf = state_vars_tf[-1][None, :, :]
#Some math operations
...
new_state_mu = state_mu_i + delta_mu
new_state_var = state_var_i + delta_var + inp_out_cov
new_mu_tf, new_var_tf, inp_tf_cov = some_transform(
new_state_mu, ....)
state_mus = tf.concat([state_mus, new_state_mu], 0)
...
state_vars_tf = tf.concat([state_vars_tf, new_var_tf], 0)
i += 1
return i, state_mus, state_vars, state_mus_tf, state_vars_tf, inp_tf_cov
loop_step = tf.constant(0, tf.int32)
init_mus_tf, init_vars_tf, inp_tf_cov = some_transform(
self.state_mu, self.state_var, self.dim_angles)
loop_vars = [
loop_step,
self.state_mu,
self.state_var,
init_mus_tf,
init_vars_tf,
inp_tf_cov]
shapes = [loop_step.get_shape(),
tf.TensorShape([None, self.dim_states]),
tf.TensorShape([None, self.dim_states, self.dim_states]),
tf.TensorShape([None, self.dim_states_tf]),
tf.TensorShape([None, self.dim_states_tf, self.dim_states_tf]),
inp_tf_cov.get_shape()]
_, state_mus, state_vars, state_mus_tf, state_vars_tf, inp_tf_cov = tf.while_loop(
loop_cond,
loop_body,
loop_vars=loop_vars,
shape_invariants=shapes)
return state_mus_tf[1:], state_vars_tf[1:]
该循环被多次调用。它在运行中会变慢,即在每次迭代后,甚至在重复调用后甚至会变慢。每次运行的迭代速度从上次运行结束的地方开始。 例如,在第一次运行的开始,每个迭代花费1秒,在第一次运行的结束,每个迭代花费3秒。在第二次运行开始时,每次迭代需要3秒,...直到使其无法运行(例如,每次迭代100秒)。
答案 0 :(得分:0)
该代码似乎大部分都很好,但是在创建类的实例时(或在其他一些初始化步骤中),并且应该将返回值存储在类属性中,您仅应调用predict_states
一次。例如:
def __init__(self, ...):
# ...
self.states_mu_tf, self.states_var_tf = self.agent.predict_states()
然后,您在predict
中使用这些属性:
states_mu, states_var = self.session.run((self.states_mu_tf, self.states_var_tf),
feed_dict=feed_dict)
那样,您将不会在图形中重新创建操作。