Binary search and interpolation in TensorFlow

Date: 2017-08-25 15:09:53

Tags: python tensorflow

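I'm trying to interpolate a 1-D tensor in TensorFlow (I essentially want the equivalent of np.interp). Since I couldn't find a similar TensorFlow op, I have to implement the interpolation myself.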
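For reference, the behavior being targeted for a single scalar query can be sketched in plain Python with the standard-library bisect module. This is only an illustration, not NumPy's implementation: interp_scalar is a hypothetical helper name, and the sketch assumes xs is sorted in increasing order and, like np.interp with its default arguments, clamps out-of-range queries to the first or last y value.

```python
from bisect import bisect_right


def interp_scalar(q, xs, ys):
    """Hypothetical helper: linear interpolation of one query point.

    Assumes xs is sorted in increasing order; queries outside
    [xs[0], xs[-1]] return ys[0] / ys[-1], as np.interp does by default.
    """
    if q <= xs[0]:
        return ys[0]
    if q >= xs[-1]:
        return ys[-1]
    # Binary search: i is the index of the last x value <= q,
    # so xs[i] <= q < xs[i + 1].
    i = bisect_right(xs, q) - 1
    alpha = (q - xs[i]) / (xs[i + 1] - xs[i])
    return (1 - alpha) * ys[i] + alpha * ys[i + 1]


xs = list(range(100))
ys = [2 * x for x in xs]
print(interp_scalar(5.4, xs, ys))
```

The binary search (bisect_right) is exactly the step the question is trying to reproduce inside the TensorFlow graph.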

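The first step is to search the sorted list of x values for the index that brackets the query value, i.e. to perform a binary search. I tried using a while loop, but I get a cryptic runtime error. Here is some code: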

xaxis = tf.placeholder(tf.float32, shape=100, name='xaxis')
query = tf.placeholder(tf.float32, name='query')

with tf.name_scope("binsearch"):
    up   = tf.Variable(0, dtype=tf.int32, name='up')
    mid  = tf.Variable(0, dtype=tf.int32, name='mid')
    down = tf.Variable(0, dtype=tf.int32, name='down')
    done = tf.Variable(-1, dtype=tf.int32, name='done')           

    def cond(up, down, mid, done):
        return tf.logical_and(done<0,up-down>1)

    def body(up, down, mid, done):
        val  = tf.gather(xaxis, mid)
        done = tf.cond(val>query, 
                       tf.cond(tf.gather(xaxis, mid-1)<query, lambda:mid-1, lambda: -1), 
                       tf.cond(tf.gather(xaxis, mid+1)>query, lambda:mid, lambda: -1) )
        up = tf.cond(val>query, lambda: mid, lambda: up )
        down = tf.cond(val<query, lambda: mid, lambda: down )

        with tf.control_dependencies([done, up, down]):
            return up, down, (up+down)//2, done

    up, down, mid, done = tf.while_loop(cond, body, (xaxis.shape[0]-1, 0, (xaxis.shape[0]-1)//2, -1))

This results in:

AttributeError: 'int' object has no attribute 'name'

I'm using Python 3.6 on Windows 7 with TensorFlow 1.1 (GPU support). Any idea what's wrong? Thanks.

Here is the full stack trace:

AttributeError                            Traceback (most recent call last)
<ipython-input-185-693d3873919c> in <module>()
     19             return up, down, (up+down)//2, done
     20 
---> 21     up, down, mid, done = tf.while_loop(cond, body, (xaxis.shape[0]-1, 0, (xaxis.shape[0]-1)//2, -1))

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name)
   2621     context = WhileContext(parallel_iterations, back_prop, swap_memory, name)
   2622     ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, context)
-> 2623     result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
   2624     return result
   2625 

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in BuildLoop(self, pred, body, loop_vars, shape_invariants)
   2454       self.Enter()
   2455       original_body_result, exit_vars = self._BuildLoop(
-> 2456           pred, body, original_loop_vars, loop_vars, shape_invariants)
   2457     finally:
   2458       self.Exit()

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in _BuildLoop(self, pred, body, original_loop_vars, loop_vars, shape_invariants)
   2404         structure=original_loop_vars,
   2405         flat_sequence=vars_for_body_with_tensor_arrays)
-> 2406     body_result = body(*packed_vars_for_body)
   2407     if not nest.is_sequence(body_result):
   2408       body_result = [body_result]

<ipython-input-185-693d3873919c> in body(up, down, mid, done)
     11         val  = tf.gather(xaxis, mid)
     12         done = tf.cond(val>query, 
---> 13                        tf.cond(tf.gather(xaxis, mid-1)<query, lambda:mid-1, lambda: -1),
     14                        tf.cond(tf.gather(xaxis, mid+1)>query, lambda:mid, lambda: -1) )
     15         up = tf.cond(val>query, lambda: mid, lambda: up )

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in cond(pred, fn1, fn2, name)
   1746     context_f = CondContext(pred, pivot_2, branch=0)
   1747     context_f.Enter()
-> 1748     _, res_f = context_f.BuildCondBranch(fn2)
   1749     context_f.ExitResult(res_f)
   1750     context_f.Exit()

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in BuildCondBranch(self, fn)
   1666               real_v = sparse_tensor.SparseTensor(indices, values, dense_shape)
   1667           else:
-> 1668             real_v = self._ProcessOutputTensor(v)
   1669         result.append(real_v)
   1670     return original_r, result

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in _ProcessOutputTensor(self, val)
   1624     """Process an output tensor of a conditional branch."""
   1625     real_val = val
-> 1626     if val.name not in self._values:
   1627       # Handle the special case of lambda: x
   1628       self._values.add(val.name)

AttributeError: 'int' object has no attribute 'name'

2 answers:

Answer 0 (score: 2)

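I don't know the source of your error, but I can tell you that tf.while_loop is likely to be very slow. You can implement linear interpolation without a loop, as follows: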

import numpy as np
import tensorflow as tf

xaxis = tf.placeholder(tf.float32, shape=100, name='xaxis')
yaxis = tf.placeholder(tf.float32, shape=100, name='yaxis')
query = tf.placeholder(tf.float32, name='query')

# Add additional elements at the beginning and end for extrapolation
xaxis_pad = tf.concat([[tf.minimum(query - 1, xaxis[0])], xaxis, [tf.maximum(query + 1, xaxis[-1])]], axis=0)
yaxis_pad = tf.concat([yaxis[:1], yaxis, yaxis[-1:]], axis=0)

# Find the index of the interval containing query
cmp = tf.cast(query >= xaxis_pad, dtype=tf.int32)
diff = cmp[1:] - cmp[:-1]
idx = tf.argmin(diff)

# Interpolate
alpha = (query - xaxis_pad[idx]) / (xaxis_pad[idx + 1] - xaxis_pad[idx])
res = alpha * yaxis_pad[idx + 1] + (1 - alpha) * yaxis_pad[idx]

# Test with f(x) = 2 * x
q = 5.4
x = np.arange(100)
y = 2 * x
with tf.Session() as sess:
    q_interp = sess.run(res, feed_dict={xaxis: x, yaxis: y, query: q})
print(q_interp)
# Output: 10.8

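The padding is only there to avoid problems when query lies outside the range of xaxis (the duplicated end values of yaxis give constant extrapolation). Apart from that, all that is needed is an element-wise comparison and finding the position where the values of xaxis_pad start to exceed query.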
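To see why the comparison trick lands on the right interval, here is the same sequence of steps (pad, compare, difference, argmin, interpolate) on plain Python lists. It is a sketch for illustration only: interp_no_loop is a hypothetical name, the list comprehensions stand in for the element-wise TensorFlow ops, and xs is assumed to be sorted in increasing order.

```python
def interp_no_loop(q, xs, ys):
    """Hypothetical list-based mirror of the loop-free TensorFlow graph."""
    # Pad both ends so that xs_pad[0] < q < xs_pad[-1] always holds; the
    # duplicated end values of ys give constant extrapolation.
    xs_pad = [min(q - 1, xs[0])] + list(xs) + [max(q + 1, xs[-1])]
    ys_pad = [ys[0]] + list(ys) + [ys[-1]]
    # Comparison mask: 1 while xs_pad[i] <= q, then 0 (xs_pad is sorted).
    cmp = [int(q >= x) for x in xs_pad]
    # Differences are 0 except a single -1 where the mask drops from 1 to 0.
    diff = [b - a for a, b in zip(cmp[:-1], cmp[1:])]
    # argmin: position of that -1, i.e. xs_pad[idx] <= q < xs_pad[idx + 1].
    idx = min(range(len(diff)), key=diff.__getitem__)
    alpha = (q - xs_pad[idx]) / (xs_pad[idx + 1] - xs_pad[idx])
    return alpha * ys_pad[idx + 1] + (1 - alpha) * ys_pad[idx]


xs = list(range(100))
ys = [2 * x for x in xs]
print(interp_no_loop(5.4, xs, ys))
```

The trade-off is that every element of the axis is compared with the query, so the work grows linearly with the axis length rather than logarithmically as in a binary search; in exchange the graph needs only a few element-wise ops and no loop.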

Answer 1 (score: 0)

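Found the problem: in TensorFlow 1.1 the branch functions passed to tf.cond must return tensors, not Python integers (here lambda: -1), so the -1 has to be wrapped in a constant first. This code works: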

import tensorflow as tf  # TensorFlow 1.x graph-mode API

xaxis = tf.placeholder(tf.float32, shape=100, name='xaxis')
query = tf.placeholder(tf.float32, name='query')

with tf.name_scope("binsearch"):
    # tf.cond branches must return tensors, so -1 is wrapped in a constant.
    m_one = tf.constant(-1, dtype=tf.int32, name='minus_one')

    # Loop while no bracketing index was found and the window is wider than 1.
    def cond(up, down, mid, done):
        return tf.logical_and(done<0,up-down>1)

    def body(up, down, mid, done):

        def fn1():
            # xaxis[mid] > query: search [down, mid]; done if xaxis[mid-1] < query.
            return mid, down, (mid+down)//2, tf.cond(tf.gather(xaxis, mid-1)<query, lambda:mid-1, lambda: m_one)

        def fn2():
            # xaxis[mid] <= query: search [mid, up]; done if xaxis[mid+1] > query.
            return up, mid, (up+mid)//2, tf.cond(tf.gather(xaxis, mid+1)>query, lambda:mid, lambda: m_one)

        return tf.cond(tf.gather(xaxis, mid)>query, fn1, fn2 )

    up, down, mid, done = tf.while_loop(cond, body, (xaxis.shape[0]-1, 0, (xaxis.shape[0]-1)//2, -1))
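To check what this loop computes without building a graph, here is a plain-Python mirror of the same state machine. It is a sketch: binsearch is a hypothetical helper, and the Python while loop stands in for tf.while_loop, which likewise re-evaluates cond before every call to body. It assumes xs is sorted in increasing order.

```python
def binsearch(xs, q):
    """Hypothetical plain-Python mirror of the tf.while_loop above.

    State is (up, down, mid, done); done ends up as an index i with
    xs[i] <= q < xs[i + 1], or stays -1 if no bracket was found.
    """
    def cond(up, down, mid, done):
        # Same test as tf.logical_and(done < 0, up - down > 1).
        return done < 0 and up - down > 1

    def body(up, down, mid, done):
        if xs[mid] > q:
            # fn1: query is left of mid, shrink the window to [down, mid].
            done = mid - 1 if xs[mid - 1] < q else -1
            return mid, down, (mid + down) // 2, done
        # fn2: query is at or right of mid, shrink the window to [mid, up].
        done = mid if xs[mid + 1] > q else -1
        return up, mid, (up + mid) // 2, done

    n = len(xs)
    state = (n - 1, 0, (n - 1) // 2, -1)
    while cond(*state):
        state = body(*state)
    return state


xs = [float(i) for i in range(100)]
print(binsearch(xs, 5.4))
```

The last element of the returned state corresponds to done in the graph version; for a query outside the axis (for example -1.0 or 200.0 here) the window collapses first and done stays -1.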