
时间:2018-07-04 05:02:57

标签: session tensorflow


def rls_filter_graphic(x, d, length, Pint, Wint):
    """Build the TF1 graph for one step of a recursive-least-squares (RLS) filter.

    Args:
        x: input tensor, used as ``x @ W`` and ``x @ P``. Presumably a row
            vector of shape (1, length); TODO confirm against the dataset.
        d: desired (expected) output. The error is ``d - y``.
        length: filter length. Not used here; kept for interface compatibility.
        Pint: initial value of the P matrix (called "projection matrix" in the
            original code).
        Wint: initial value of the filter weights W.

    Returns:
        Tuple ``(y, e, w, k, P, update_P, update_W)``:
        - ``y``: the rounded prediction.
        - ``e``: the error ``d - y``.
        - ``w``: the weight variable.
        - ``k``: the gain.
        - ``P``: the P variable.
        - ``update_P``, ``update_W``: assign ops that advance P and W by one
          step. They run only after all reads of the old P and W have been
          computed.

    Relies on a module-level constant ``LAMDA`` (presumably the RLS
    forgetting factor) that is defined elsewhere in the file.
    """
    P = tf.Variable(Pint, name='P')
    w = tf.Variable(Wint, name='W')

    # Prediction. tf.round rounds to the nearest integer with ties going to
    # even (2.5 -> 2, 3.5 -> 4); the result stays a float tensor.
    y = tf.round(tf.matmul(x, w))
    # Error = expected value minus predicted value.
    e = tf.subtract(d, y)

    # Gain: k(n) = P(n-1) x^T(n) / [LAMDA + x(n) P(n-1) x^T(n)]
    Px = tf.matmul(P, tf.transpose(x))
    xPx = tf.add(tf.matmul(x, Px), LAMDA)
    k = tf.divide(Px, xPx)

    # Weight update: w(n) = w(n-1) + k(n) e(n)
    wn = tf.add(w, tf.matmul(k, e))

    # P update: P(n) = [P(n-1) - k(n) x(n) P(n-1)] / LAMDA
    kxP = tf.matmul(k, tf.matmul(x, P))
    Pn = tf.divide(tf.subtract(P, kxP), LAMDA)

    # Without explicit ordering, TF1 may run an assign before another op in
    # the same session.run has read the old variable value. wn and Pn
    # transitively consume every read of P and w (via y, e, k), so depending
    # on them makes both assigns see a consistent "previous step" state.
    with tf.control_dependencies([wn, Pn]):
        update_P = tf.assign(P, Pn)
        update_W = tf.assign(w, wn)
    return y, e, w, k, P, update_P, update_W


def inter_band_predict(dataset, length, width, height):
    """Run the RLS predictor over ``width * height`` samples and collect the errors.

    Args:
        dataset: object with ``make_one_shot_iterator()``, presumably a
            ``tf.data.Dataset`` yielding ``(x, d)`` pairs. It should yield at
            least ``width * height`` elements; otherwise ``sess.run``
            presumably raises ``tf.errors.OutOfRangeError``.
        length: filter length, forwarded to ``rls_filter_init`` and
            ``rls_filter_graphic``.
        width: image width.
        height: image height. One prediction error is stored per sample.

    Returns:
        1-D float64 NumPy array of length ``width * height`` holding the
        error ``d - y`` for each sample, in dataset order.
    """
    img = np.zeros(width * height)
    # The one-shot iterator yields the next (x, d) pair on every session.run.
    x, d = dataset.make_one_shot_iterator().get_next()
    # Build the prediction graph once; each session.run advances it by one sample.
    ini_P, ini_W = rls_filter_init(0.001, length)
    _, err, _, _, _, update_P, update_W = rls_filter_graphic(
        x, d, length, ini_P, ini_W)
    # Fetching a tf.group yields None. Fetching update_P / update_W directly
    # would copy the whole P matrix and W vector back to Python on every
    # step, only to discard them.
    train_step = tf.group(update_P, update_W)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        # The initializer must actually be run; otherwise the first step
        # fails on uninitialized variables P and W.
        sess.run(init)
        # NOTE(review): the RLS recursion is sequential, so one session.run
        # per sample remains. Removing it entirely would mean moving the loop
        # into the graph (tf.while_loop / tf.scan) or running the recursion
        # in NumPy.
        for i in range(height * width):
            e, _ = sess.run([err, train_step])
            # err holds a single value; .item() makes the scalar explicit.
            img[i] = np.asarray(e).item()
    return img

我发现在每次循环中都调用 sess.run() 的开销非常大，有什么办法可以避免吗？顺便说一下，数据集的形式如下：

   (x,     d)

