使用TensorFlow

时间:2016-01-20 14:30:35

标签: python tensorflow

给出灰度图像 I 作为2D张量(尺寸W,H)和坐标 C 的张量(Dim.None,2)。我想将 C 的行解释为 I 中的坐标,使用某种插值在这些坐标处对 I 进行采样(双线性可能会很好)对于我的用例),并将结果值存储在新的Tensor P 中(维度为None,即1维,条目与 C 有行)。< / p>

这是否可以(有效)使用TensorFlow?我能找到的只是用于调整图像大小(等距重新采样,如果你喜欢)的函数。但我无法在坐标列表中找到任何开箱即用的样本。

即。我原本希望找到类似tf.interpolate()函数的东西:

I = tf.placeholder("float", shape=[128, 128])
C = tf.placeholder("float", shape=[None, 2])
P = tf.interpolate(I, C, axis=[0, 1], method="linear")

理想情况下,我会寻找一种解决方案,允许我使用带有形状的 C (无,M)在M维上沿N维张量 I 进行插值并产生一个N-M + 1维输出,如上面代码中的“axis”参数所示。

(我的应用程序中的“图像”不是图片顺便说一下,它是来自物理模型的采样数据(当用作占位符时)或替代学习模型(当用作变量时)。现在这个物理模型有2个自由度,因此在“图像”中插值现在已足够,但我可能会在未来研究更高维度的模型。)

如果使用现有的TensorFlow功能无法实现这样的功能:当我想实现类似这样的tf.interpolate()运算符时,我应该从哪里开始? (文档和/或简单示例代码)

2 个答案:

答案 0 :(得分:10)

没有内置的op执行这种插值,但你应该能够使用现有TensorFlow操作的组合来完成。我建议双线性案例采用以下策略:

  1. 从索引的张量C中,计算与四个角点对应的整数张量。例如(名称假设原点位于左上角):

    top_left = tf.cast(tf.floor(C), tf.int32)
    
    top_right = tf.cast(
        tf.concat(1, [tf.floor(C[:, 0:1]), tf.ceil(C[:, 1:2])]), tf.int32)
    
    bottom_left = tf.cast(
        tf.concat(1, [tf.ceil(C[:, 0:1]), tf.floor(C[:, 1:2])]), tf.int32)
    
    bottom_right = tf.cast(tf.ceil(C), tf.int32)
    
  2. 从表示特定角点的每个张量中,在这些点处从I中提取值向量。例如,对于以下函数,对于2-D情况执行此操作:

    def get_values_at_coordinates(input, coordinates):
      input_as_vector = tf.reshape(input, [-1])
      coordinates_as_indices = (coordinates[:, 0] * tf.shape(input)[1]) + coordinates[:, 1]
      return tf.gather(input_as_vector, coordinates_as_indices)
    
    values_at_top_left = get_values_at_coordinates(I, top_left)
    values_at_top_right = get_values_at_coordinates(I, top_right)
    values_at_bottom_left = get_values_at_coordinates(I, bottom_left)
    values_at_bottom_right = get_values_at_coordinates(I, bottom_right)
    
  3. 首先计算水平方向的插值:

    # Varies between 0.0 and 1.0.
    horizontal_offset = C[:, 0] - tf.cast(top_left[:, 0], tf.float32)
    
    horizontal_interpolated_top = (
        ((1.0 - horizontal_offset) * values_at_top_left)
        + (horizontal_offset * values_at_top_right))
    
    horizontal_interpolated_bottom = (
        ((1.0 - horizontal_offset) * values_at_bottom_left)
        + (horizontal_offset * values_at_bottom_right))
    
  4. 现在计算垂直方向的插值:

    vertical_offset = C[:, 1] - tf.cast(top_left[:, 1], tf.float32)
    
    interpolated_result = (
        ((1.0 - vertical_offset) * horizontal_interpolated_top)
        + (vertical_offset * horizontal_interpolated_bottom))
    

答案 1 :(得分:3)

对于最近的邻居而言,这已经变得棘手了,因为TF还没有Numpy切片的一般性(github issue #206),并且gather仅适用于第一维。但是这里有一种解决方法,使用gather-&gt; transpose-&gt; gather-&gt;提取对角线

def identity_matrix(n):
  """Returns nxn identity matrix."""
  # note, if n is a constant node, this assert node won't be executed,
  # this error will be caught during shape analysis 
  assert_op = tf.Assert(tf.greater(n, 0), ["Matrix size must be positive"])
  with tf.control_dependencies([assert_op]):
    ones = tf.fill(n, 1)
    diag = tf.diag(ones)
  return diag

def extract_diagonal(tensor):
  """Extract diagonal of a square matrix."""

  shape = tf.shape(tensor)
  n = shape[0]
  assert_op = tf.Assert(tf.equal(shape[0], shape[1]), ["Can't get diagonal of "
                                                       "a non-square matrix"])

  with tf.control_dependencies([assert_op]):
    return tf.reduce_sum(tf.mul(tensor, identity_matrix(n)), [0])


# create sample matrix
size=4
I0=np.zeros((size,size), dtype=np.int32)
for i in range(size):
  for j in range(size):
    I0[i, j] = 10*i+j

I = tf.placeholder(dtype=np.int32, shape=(size,size))
C = tf.placeholder(np.int32, shape=[None, 2])
C0 = np.array([[0, 1], [1, 2], [2, 3]])
row_indices = C[:, 0]
col_indices = C[:, 1]

# since gather only supports dim0, have to transpose
I1 = tf.gather(I, row_indices)
I2 = tf.gather(tf.transpose(I1), col_indices)
I3 = extract_diagonal(tf.transpose(I2))

sess = create_session()
print sess.run([I3], feed_dict={I:I0, C:C0})

首先从这样的矩阵开始:

array([[ 0,  1,  2,  3],
       [10, 11, 12, 13],
       [20, 21, 22, 23],
       [30, 31, 32, 33]], dtype=int32)

此代码提取主

上方的对角线
[array([ 1, 12, 23], dtype=int32)]

[]运营商变身SqueezeSlice

会有一些神奇的事情发生

enter image description here