我在Keras中创建了一个自定义层(称为GraphGather),但输出张量打印为:
Tensor(“ graph_gather / Tanh:0”,shape =(?,?),dtype = float32)
由于某种原因,形状返回为(?,?),这将导致下一个致密层引发以下错误:
ValueError:
Dense
的输入的最后维度应定义。找到了None
。
GraphGather层代码如下:
class GraphGather(tf.keras.layers.Layer):
def __init__(self, batch_size, num_mols_in_batch, activation_fn=None, **kwargs):
self.batch_size = batch_size
self.num_mols_in_batch = num_mols_in_batch
self.activation_fn = activation_fn
super(GraphGather, self).__init__(**kwargs)
def build(self, input_shape):
super(GraphGather, self).build(input_shape)
def call(self, x, **kwargs):
# some operations (most of def call omitted)
out_tensor = result_of_operations() # this line is pseudo code
if self.activation_fn is not None:
out_tensor = self.activation_fn(out_tensor)
out_tensor = out_tensor
return out_tensor
def compute_output_shape(self, input_shape):
return (self.num_mols_in_batch, 2 * input_shape[0][-1])}
I have also tried hardcoding compute_output_shape to be:
python
def compute_output_shape(self,input_shape):
回报(64,150)
```
然而,在打印时输出张量仍然是
Tensor(“ graph_gather / Tanh:0”,shape =(?,?),dtype = float32)
这将导致上面编写的ValueError。
答案 0 :(得分:3)
我有同样的问题。我的解决方法是将以下行添加到call方法中:
input_shape = tf.shape(x)
然后:
return tf.reshape(out_tensor, self.compute_output_shape(input_shape))
我还没有遇到任何问题。