Tensorflowcompute_output_shape()不适用于自定义层

时间:2018-06-25 17:24:32

标签: tensorflow

我在Keras中创建了一个自定义层(称为GraphGather),但输出张量打印为:

  

Tensor(“ graph_gather / Tanh:0”,shape =(?,?),dtype = float32)

由于某种原因,形状返回为(?,?),这将导致下一个致密层引发以下错误:

  

ValueError:Dense的输入的最后维度应定义。找到了None

GraphGather层代码如下:

class GraphGather(tf.keras.layers.Layer):

  def __init__(self, batch_size, num_mols_in_batch, activation_fn=None, **kwargs):
    self.batch_size = batch_size
    self.num_mols_in_batch = num_mols_in_batch
    self.activation_fn = activation_fn
    super(GraphGather, self).__init__(**kwargs)

  def build(self, input_shape):
    super(GraphGather, self).build(input_shape)

 def call(self, x, **kwargs):
    # some operations (most of def call omitted)
    out_tensor = result_of_operations() # this line is pseudo code
    if self.activation_fn is not None:
      out_tensor = self.activation_fn(out_tensor)
    out_tensor = out_tensor
    return out_tensor

  def compute_output_shape(self, input_shape):
    return (self.num_mols_in_batch, 2 * input_shape[0][-1])}

I have also tried hardcoding compute_output_shape to be: python def compute_output_shape(self,input_shape):     回报(64,150) ``` 然而,在打印时输出张量仍然是

  

Tensor(“ graph_gather / Tanh:0”,shape =(?,?),dtype = float32)

这将导致上面编写的ValueError。


系统信息

  • 已编写了自定义代码
  • ** OS平台和发行版*:Linux Ubuntu 16.04
  • TensorFlow版本(使用下面的命令):1.5.0
  • Python版本:3.5.5

1 个答案:

答案 0 :(得分:3)

我有同样的问题。我的解决方法是将以下行添加到call方法中:

input_shape = tf.shape(x)

然后:

return tf.reshape(out_tensor, self.compute_output_shape(input_shape))

我还没有遇到任何问题。