Tensorflow获得输入大小的百分比

时间:2018-01-12 11:14:50

标签: python tensorflow

我有一个解决这个问题的工作方案,但看起来非常冗长。我想知道"适当的"实现这一目标的方法是......

假设我有一个尺寸为[?,5,...]的张量a(此处的尾随尺寸不重要)。在运行时,第一维可以是任何东西,100,200等,具体取决于feed dict传递的数据集的大小。

在内部,我需要将数据集大小的百分比(第一维)作为int。在普通的numpy代码中,它看起来像:

def percentOfInputSize( data, percent = .5 ):
    return int( round( data.shape[0] * percent ) )

在tensorflow中,我实现这一目标的唯一方法是

def percentOfInputSize( data, percent = .5 ):
   rows = tf.shape(data)[0]
   return tf.cast( tf.round( tf.cast(rows, tf.float32) * percent ), tf.int32)

但这看起来很可怕。有没有更好的方法来计算这个值?

0 个答案:

没有答案