如何使用张量作为函数的输入?(tensorflow)

时间:2019-04-13 23:48:17

标签: python tensorflow

我尝试使用tensorflow构造我的模型以求解微分方程,例如

dX/dt=f(\mu,X,t)

这里,\ mu是一个依赖X的函数,它很复杂,因此我想使用神经网络预测\ mu(X)。

首先,我的输入X传递密集层N以获得\ mu〜N(X)。 然后,我使用Runge-Kutta方法解决上面的ODE,该方法由代码定义:

def RK4(self, mu, X, t, dt=0.2):
    kX1=dt*self.f(mu, X, t)
    kX2=dt*self.f(mu, X+kX1/2, t+dt/2)
    kX3=dt*self.f(mu, X+kX2/2, t+dt/2)
    kX4=dt*self.f(mu, X+kX3, t+dt)
    X_next=X+(kX1+2*kX2+2*kX3+kX4)/6

    return X_next

请注意,self来自类变量。 当我直接将N(X)放入RK4时,会发生错误。

 Tensor objects are only iterable when eager execution is enabled. To iterate 
 over this tensor use tf.map_fn.

我对这个map_fn不熟悉。我的函数很复杂,因为它同时具有张量(\ mu,X)和浮点数(t,dt)。但据我所知,map_fn仅处理张量输入。有没有一种聪明的方法来处理这些输入?谢谢!

1 个答案:

答案 0 :(得分:0)

X_next= tf.map_fn(lambda x : self.RK4(x[0],x[1],x[2]),(self.mu, self.X, self.t), dtype=tf.float32)

将解决我的问题。实际上,tf.map_fn可以接收张量类型输入或浮点类型输入。从this link

可以看到该功能的用法