采样张量流

时间:2018-06-18 01:43:13

标签: python tensorflow

我使用以下功能在每个图像中采样点。如果batch_size为None,则tf.range给出错误。我如何在tensorflow中进行采样

def sampling(binary_selection,num_points, points):
  """
      binary_selection: tensor of size (batch_size, points) 
          with values 1.0 or 0.0. Indicating positive and negative points. 
          We want to sample num_points from positive points of each image
      points: tensor of size (batch_size, num_points_in_image)
      num_points: number of points to sample for each image
  """
  batch_size = points.get_shape()[0]
  indices = tf.multinomial((tf.log(binary_selection)), num_points)
  indices = tf.cast(tf.expand_dims(indices, axis=2), tf.int32)
  batch_seq = tf.expand_dims(tf.range(batch_size), axis=1) 
  im_indices = tf.expand_dims(tf.tile(batch_seq, [1, num_points]), axis=2) 
  indices = tf.concat([im_indices, indices], axis=2)
  return tf.gather_nd(points, indices)

我收到以下错误

_dimension_tensor_conversion_function raise ValueError("Cannot convert an unknown Dimension to a Tensor: %s" % d) ValueError: Cannot convert an unknown Dimension to a Tensor: ?

在测试和训练期间,我将使用batch_size一个整数,但是当我初始化时,我想将None作为输入,以便在测试和训练时间内改变批量大小。

2 个答案:

答案 0 :(得分:1)

您需要为batch_size提供一个值。

需要初始化。

目前,它没有任何价值。

答案 1 :(得分:0)

batch_size = points.get_shape()[0]更改为batch_size = tf.shape(points)[0]。要了解静态和动态形状,请检查:How to understand static shape and dynamic shape in TensorFlow?