Keras的expand_dims函数导致张量丢失元数据

时间:2019-01-25 13:30:37

标签: python tensorflow keras

我在使用keras的expand_dims函数时遇到问题。这是一个简单的示例:

此代码有效:

import tensorflow as tf
from tensorflow.python.keras.layers import Input, Lambda
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.backend import expand_dims

def add_fun(x):
  return tf.add(x[0], x[1])

in_1 = Input(shape=(None, None, 8))
in_2 = Input(shape=(None, 1, 1))

out = Lambda(add_fun)([in_1, in_2])

m = Model([in_1, in_2], out)

并且此代码不:

import tensorflow as tf
from tensorflow.python.keras.layers import Input, Lambda
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.backend import expand_dims

def add_fun(x):
  return tf.add(x[0], x[1])

in_1 = Input(shape=(None, None, 8))
in_2 = Input(shape=(None, 1))

problem_part = expand_dims(in_2, axis=1)

out = Lambda(add_fun)([in_1, problem_part])

m = Model([in_1, in_2], out)

如图here所示,我相信我使用的是expand_dims正确,而且我不知道为什么它会引起问题。

2 个答案:

答案 0 :(得分:1)

问题在于expand_dims不是Keras层。如果您改为将调用转到lambda层内的expand_dims,则它将正常工作。

答案 1 :(得分:1)

这可以通过将expand_dims函数调用包装在Lambda中来解决:

import tensorflow as tf
from tensorflow.python.keras.layers import Input, Lambda
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.backend import expand_dims

def add_fun(x):
  return tf.add(x[0], x[1])

in_1 = Input(shape=(None, None, 8))
in_2 = Input(shape=(None, 1))

problem_part = Lambda(lambda x: expand_dims(x, axis=1))(in_2)

out = Lambda(add_fun)([in_1, problem_part])

m = Model([in_1, in_2], out)