具有部分已知的segment_ids形状的Tensorflow unsorted_segment_mean

时间:2018-08-13 17:36:17

标签: python tensorflow

我不知道unsorted_segment_mean的细分ID形状,也找不到其他选择。 在documentation of unsorted_segment_sum (same behavior)中写为:

  

segment_ids:张量。必须是以下类型之一:int32,int64。张量,其形状是data.shape的前缀。

“形状是data.shape的前缀”是什么意思? 我试图动态设置形状,但是不起作用:

a, b = 2, 2 # Will be some input variables                                                                                                                                                                                                                                                                                                                                                                                        
x = tf.placeholder(tf.float32) 
# Add padding                                                                                                                                                        
size = tf.shape(x)                                                                                                                                                                                      
batchSize, inputSize = size[0], size[1]                                                                                                                                                                                           
paddingLength=tf.cast((inputSize/a)%b, dtype=tf.int32)                                                                                                                                   
T_paddingSize = tf.scatter_nd([[1,1]], [paddingLength], [2,2])
x = tf.pad(x, T_paddingSize, 'SYMMETRIC')
# Generate segment IDs                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             
size = tf.shape(x)                                                                                                                                                                                            
batchSize, inputSize = size[0], size[1]
T_segSize = tf.cast(tf.ceil([a/b]), dtype=tf.int32)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
T_segments = tf.tile(tf.range(0, a), T_segSize)
# first TRY: ValueError: Cannot convert a partially known TensorShape to a Tensor: (?,)                                                                                                                                                    
T_segments = tf.reshape(T_segments, T_segSize*a)
# second TRY: TypeError: Tensor objects are not iterable when eager execution is not enabled.                                                                                                                                                  
T_segments = T_segments.set_shape(T_segSize*a)                                                                                                                                     
# Shall compute the mean over unsorted segments 
# ValueError: Cannot convert a partially known TensorShape to a Tensor: (?,)                                                                                                                                                                                             
y = tf.matrix_transpose(tf.unsorted_segment_mean(tf.matrix_transpose(x), T_segments, T_segSize[0]))  

我该怎么办? 感谢您的帮助!

0 个答案:

没有答案