How does the batch size affect the shape of tensors in a list?

Time: 2018-07-09 02:39:06

Tags: python tensorflow

I am trying to perform a matrix multiplication between two tensors. I have a list of tensors whose placeholders are initialized as:

self.sample_set = [tf.placeholder(tf.int32, [None,self.max_input_right])
      for _ in xrange(7)]
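The 7 in `xrange(7)` fixes how many placeholders the list holds; it is unrelated to the batch size. Each placeholder's leading `None` takes its size from whatever array is fed to it. A minimal sketch of that distinction, using NumPy arrays as stand-ins for the fed values (the row counts in `rows_per_feed` are hypothetical, and `max_input_right = 1490` is taken from the shapes quoted below):

```python
import numpy as np

max_input_right = 1490  # answer length quoted in the question
num_placeholders = 7    # list length from xrange(7), fixed when the graph is built

# Hypothetical feed: each [None, max_input_right] placeholder gets its own array,
# and that array's row count becomes the tensor's batch dimension.
rows_per_feed = [7, 7, 7, 7, 7, 7, 8]  # illustrative only
feeds = [np.zeros((rows, max_input_right), dtype=np.int32) for rows in rows_per_feed]

batch_dims = [f.shape[0] for f in feeds]
print(len(feeds))   # 7
print(batch_dims)   # [7, 7, 7, 7, 7, 7, 8]
```

Nothing in the placeholder declaration forces these row counts to agree with each other or with the question batch; that only gets checked when an op combines them.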

This list of tensors is then passed through a word embedding and a convolutional layer. After that, I apply cluster pooling as follows:

for acnn in self.a_cnn:
        self.q_clust_pooling,self.a_clust_pooling = self.attentive_pooling(ques_feature,acnn)
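As written, the loop rebinds `self.q_clust_pooling` and `self.a_clust_pooling` on every iteration, so after it finishes they refer only to the results for the last element of `self.a_cnn` (the graph ops for the earlier elements are still created). A plain-Python sketch of that behaviour, using a hypothetical stub `attentive_pooling_stub` in place of the real method:

```python
def attentive_pooling_stub(ques_feature, acnn):
    # Hypothetical stand-in for self.attentive_pooling: just tags its inputs
    return ("Q(%s,%s)" % (ques_feature, acnn), "A(%s,%s)" % (ques_feature, acnn))

a_cnn = ["acnn_%d" % i for i in range(7)]  # stands in for the 7 conv outputs

# Pattern from the question: each iteration overwrites the same two names.
for acnn in a_cnn:
    q_clust_pooling, a_clust_pooling = attentive_pooling_stub("ques", acnn)
print(a_clust_pooling)  # A(ques,acnn_6)

# Alternative: keep one (q, a) result per element of the list.
pooled = [attentive_pooling_stub("ques", acnn) for acnn in a_cnn]
print(len(pooled))  # 7
```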

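Here self.a_cnn is the list of convolved tensors. q_clust_pooling has shape [?, 57, 512] and a_clust_pooling has shape [?, 1490, 512], where 512 is the total number of filters, 57 is the question length, and 1490 is the length of acnn.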
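For the scoring step, the shapes combine as [B, 57, 512] × [B, 512, 1490] → [B, 57, 1490], and both operands must share the same leading B. A reduced-size sketch of that shape rule, with NumPy's `matmul` standing in for `tf.matmul` on 3-D tensors (the small sizes are illustrative, not the model's):

```python
import numpy as np

# Small stand-ins: batch B, question length, answer length, number of filters
B, q_len, a_len, n_filters = 7, 5, 6, 4

q = np.random.rand(B, q_len, n_filters)  # like q_clust_pooling [?, 57, 512]
a = np.random.rand(B, a_len, n_filters)  # like a_clust_pooling [?, 1490, 512]

# Transpose a to [B, n_filters, a_len], then multiply one batch entry at a time.
scores = np.matmul(q, np.transpose(a, (0, 2, 1)))
print(scores.shape)  # (7, 5, 6)
```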

Inside the attentive_pooling method, computing the matrix multiplication as follows raises an error:

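def attentive_pooling(self,input_left, input_right):
    # Reshape both inputs to [batch, length, total_filters]; -1 lets TF infer the batch size
    Q = tf.reshape(input_left,[-1,self.max_input_left,int(len(self.filter_sizes) * self.num_filters)],name = 'Q')
    A = tf.reshape(input_right,[-1,self.max_input_right,int(len(self.filter_sizes) * self.num_filters)],name = 'A')
    filter_size = int(len(self.filter_sizes) * self.num_filters)  # 512 in this model
    # Flatten Q to 2-D so it can be multiplied by the weight matrix self.U
    first = tf.matmul(tf.reshape(Q,[-1,filter_size]),self.U)
    # Back to 3-D: [batch, max_input_left, filter_size]
    second_step = tf.reshape(first,[-1,self.max_input_left,filter_size])
    # Batched matmul: [batch, 57, 512] x [batch, 512, 1490]; both batch dims must match
    result = tf.matmul(second_step,tf.transpose(A,perm = [0,2,1]))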
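Both `tf.reshape` calls use -1 for the leading dimension, so the batch size is inferred from the total element count rather than carried over from the input. If an input's middle dimension is not actually `max_input_left` or `max_input_right`, the reshape can still succeed while producing a different batch size. A NumPy sketch of that inference (`np.reshape` follows the same -1 rule; `assumed_len` and `actual_len` are made-up values, and this is not claimed to be the cause of the error here):

```python
import numpy as np

filter_size = 4   # stands in for len(filter_sizes) * num_filters
assumed_len = 3   # what max_input_right says the length is
actual_len = 6    # hypothetical real length of the conv output

x = np.zeros((4, actual_len, filter_size))         # a batch of 4 examples
A = np.reshape(x, (-1, assumed_len, filter_size))  # -1 inferred from total size

print(A.shape)  # (8, 3, 4): the leading "batch" dim went from 4 to 8
```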

At this point, with a batch size of 7, feeding data into the graph produces the following error:

Traceback (most recent call last):
  File "workspace-2/Reddit/ExisQACNN/train.py", line 232, in <module>
    test_pair_wise()
  File "workspace-2/Reddit/ExisQACNN/train.py", line 175, in test_pair_wise
    feed_dict)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 789, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 997, in _run
    feed_dict_string, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1132, in _do_run
    target_list, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1152, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: In[0].dim(0) and In[1].dim(0) must be the same: [8,57,512] vs [10,512,1490]
 [[Node: score/MatMul_21 = BatchMatMul[T=DT_FLOAT, adj_x=false, adj_y=false, _device="/job:localhost/replica:0/task:0/cpu:0"](score/Reshape_26, score/transpose_10)]]

Caused by op u'score/MatMul_21', defined at:
  File "workspace-2/Reddit/ExisQACNN/train.py", line 232, in <module>
    test_pair_wise()
  File "workspace-2/Reddit/ExisQACNN/train.py", line 113, in test_pair_wise
    cnn.build_graph()
  File "/home/purbasha/workspace-2/Reddit/ExisQACNN/QA_CNN_pairwise.py", line 409, in build_graph
    self.scoring()
  File "/home/purbasha/workspace-2/Reddit/ExisQACNN/QA_CNN_pairwise.py", line 147, in scoring
    self.score12 = self.pair_ans(self.q_pos_feature_map,self.a_pos_pooling)
  File "/home/purbasha/workspace-2/Reddit/ExisQACNN/QA_CNN_pairwise.py", line 112, in pair_ans
    self.q_clust_pooling,self.a_clust_pooling = self.attentive_pooling(ques_feature,acnn)
  File "/home/purbasha/workspace-2/Reddit/ExisQACNN/QA_CNN_pairwise.py", line 230, in attentive_pooling
    result = tf.matmul(second_step,tf.transpose(A,perm = [0,2,1]))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/math_ops.py", line 1786, in matmul
    a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_math_ops.py", line 290, in _batch_mat_mul
    adj_y=adj_y, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2506, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1269, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): In[0].dim(0) and In[1].dim(0) must be the same: [7,57,512] vs [8,512,1490]
 [[Node: score/MatMul_21 = BatchMatMul[T=DT_FLOAT, adj_x=false, adj_y=false, _device="/job:localhost/replica:0/task:0/cpu:0"](score/Reshape_26, score/transpose_10)]]

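I understand that 7 is the batch size, which gives ques_clust the shape [7, 57, 512]. However, I cannot work out how the transpose of acnn ends up with the shape [8, 512, 1490]. Can anyone offer some advice?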

Note: when I change batch_size to 8, the error persists: ques_clust becomes [8, 57, 512], while the transpose of acnn becomes [10, 512, 1490].
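The error message reports the leading dimension of each operand, so one way to find where the two sides diverge is to compare those dimensions right before the multiply. A NumPy sketch with a hypothetical helper `batched_scores` that fails with an explicit message (in a TF 1.x graph, `tf.assert_equal(tf.shape(a)[0], tf.shape(b)[0])` combined with `tf.control_dependencies` could serve as a similar run-time check):

```python
import numpy as np

def batched_scores(q, a_t):
    """Hypothetical helper: q is [B, Lq, F], a_t is [B, F, La]."""
    if q.shape[0] != a_t.shape[0]:
        raise ValueError("batch mismatch: q has %d, a_t has %d"
                         % (q.shape[0], a_t.shape[0]))
    return np.matmul(q, a_t)

q = np.zeros((7, 5, 4))    # 7 questions, like [7, 57, 512]
a_t = np.zeros((8, 4, 6))  # 8 answers, like [8, 512, 1490]

try:
    batched_scores(q, a_t)
except ValueError as err:
    print(err)  # batch mismatch: q has 7, a_t has 8
```

Printing `tf.shape` of `input_right`, `A`, and the fed answer arrays in the same way would show at which step the leading dimension stops being the batch size.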
