How to improve the speed of TensorFlow sess.run()

Date: 2018-07-04 05:02:57

Tags: session tensorflow

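I am trying to implement a recursive least squares (RLS) filter in TensorFlow. Training on about 300,000 samples takes about 3 minutes. Is there any way to reduce the running time?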
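With one sess.run() call per sample, as in the loop in inter_band_predict, those two figures work out to roughly 0.6 ms per call:

```latex
\frac{3 \times 60\,\mathrm{s}}{3 \times 10^{5}\ \text{calls}} = 6 \times 10^{-4}\,\mathrm{s/call} = 0.6\,\mathrm{ms/call}
```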

import tensorflow as tf

# LAMDA: RLS forgetting factor, a global constant defined elsewhere in the original code


def rls_filter_graphic(x, d, length, Pint, Wint):
    # Projection matrix P
    P = tf.Variable(Pint, name='P')
    # Filter weights
    w = tf.Variable(Wint, name='W')
    # Filter output y(n) = x(n) w(n-1)
    y = tf.matmul(x, w)
    # Round the output to an integer (note: tf.round rounds to nearest, half to even)
    y = tf.round(y)
    # Prediction error: desired value minus filter output
    e = tf.subtract(d, y)
    # Gain: k(n) = P(n-1) x(n)^T / [lambda + x(n) P(n-1) x(n)^T]
    tx = tf.transpose(x)
    Px = tf.matmul(P, tx)
    xPx = tf.matmul(x, Px)
    xPx = tf.add(xPx, LAMDA)
    k = tf.divide(Px, xPx)
    # Weight update: w(n) = w(n-1) + k(n) e(n)
    wn = tf.add(w, tf.matmul(k, e))
    # P update: P(n) = [P(n-1) - k(n) x(n) P(n-1)] / lambda
    xP = tf.matmul(x, P)
    kxP = tf.matmul(k, xP)
    Pn = tf.subtract(P, kxP)
    Pn = tf.divide(Pn, LAMDA)
    update_P = tf.assign(P, Pn)
    update_W = tf.assign(w, wn)
    return y, e, w, k, P, update_P, update_W
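Written out, the graph built by rls_filter_graphic is one step of the exponentially weighted RLS recursion, with x(n) a 1×L row vector, w an L×1 column vector, P an L×L matrix and λ = LAMDA (the rounding of y(n) is specific to this code, not part of standard RLS):

```latex
\begin{aligned}
y(n) &= \operatorname{round}\big(x(n)\,w(n-1)\big) \\
e(n) &= d(n) - y(n) \\
k(n) &= \frac{P(n-1)\,x(n)^{\mathsf T}}{\lambda + x(n)\,P(n-1)\,x(n)^{\mathsf T}} \\
w(n) &= w(n-1) + k(n)\,e(n) \\
P(n) &= \frac{1}{\lambda}\Big[P(n-1) - k(n)\,x(n)\,P(n-1)\Big]
\end{aligned}
```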

It is called from this function:

import numpy as np
import tensorflow as tf


def inter_band_predict(dataset, length, width, height):
    img = np.zeros(width * height)
    # Use an iterator to get the (x, d) vectors
    itr = dataset.make_one_shot_iterator()
    (x, d) = itr.get_next()
    # Build the prediction graph (rls_filter_init is not shown in the question)
    ini_P, ini_W = rls_filter_init(0.001, length)
    y_p, err, weight, kn, Pn, update_P, update_W = rls_filter_graphic(x, d, length, ini_P, ini_W)
    # err=tf.matmul(x,ini_W)
    # Initialize variables
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        # One sess.run() call per sample
        for i in range(height * width):
            e, trainP, trainW = sess.run([err, update_P, update_W])
            img[i] = e.item()  # err is a 1x1 array; store it as a scalar
        return img

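I found that calling sess.run() on every loop iteration is very expensive. Is there any way to avoid this? By the way, the dataset has the following form: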

   (x,         d)
(
 [0.1, 1.0], [1.0]
 [0.0, 2.0], [2.0]
 [0.2, 3.0], [1.0]
  ...
)
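A minimal sketch of one workaround, assuming it is acceptable to run this particular recursion outside TensorFlow: the same updates as rls_filter_graphic written as a plain NumPy loop, so there is no Session call per sample. The function name rls_numpy and the flat input shapes (X as (N, L), D as (N,)) are illustrative assumptions. LAMDA and the output of rls_filter_init are not shown in the question, so lam=0.99 and P0 = I / 0.001 in the usage lines are placeholders only.

```python
import numpy as np


def rls_numpy(X, D, P0, w0, lam, round_output=True):
    """Run the RLS recursion from rls_filter_graphic over all samples in NumPy.

    X:   (N, L) array, one input row x(n) per sample
    D:   (N,) array of desired values d(n)
    P0:  (L, L) initial P matrix; w0: (L,) initial weights
    lam: forgetting factor (LAMDA in the question)
    Returns the a priori errors e(n) plus the final P and w.
    """
    P = np.array(P0, dtype=np.float64)
    w = np.array(w0, dtype=np.float64)
    X = np.asarray(X, dtype=np.float64)
    D = np.asarray(D, dtype=np.float64)
    errors = np.empty(len(X))
    for n in range(len(X)):
        x = X[n]
        y = x @ w                      # filter output x(n) w(n-1)
        if round_output:
            y = np.round(y)            # like tf.round: nearest integer, half to even
        e = D[n] - y                   # prediction error e(n)
        Px = P @ x                     # P(n-1) x(n)^T
        k = Px / (lam + x @ Px)        # gain k(n)
        w = w + k * e                  # w(n) = w(n-1) + k(n) e(n)
        P = (P - np.outer(k, x @ P)) / lam  # P(n) = [P(n-1) - k(n) x(n) P(n-1)] / lam
        errors[n] = e
    return errors, P, w


if __name__ == "__main__":
    # The three sample rows from the question; P0 and lam are placeholder values
    # (hypothetical, since rls_filter_init and LAMDA are not shown).
    X = np.array([[0.1, 1.0], [0.0, 2.0], [0.2, 3.0]])
    D = np.array([1.0, 2.0, 1.0])
    errors, P, w = rls_numpy(X, D, P0=np.eye(2) / 0.001, w0=np.zeros(2), lam=0.99)
    print(errors)
```

This trades the Session call per sample for a few small NumPy operations per sample; whether that is fast enough for about 300,000 samples would need to be measured. An alternative that stays inside TensorFlow 1.x would be to express the loop in the graph (for example with tf.scan) so that a single sess.run() processes all samples; that variant is not shown here.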

0 answers