我在使用keras的expand_dims函数时遇到问题。这是一个简单的示例:
此代码有效:
import tensorflow as tf
from tensorflow.python.keras.layers import Input, Lambda
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.backend import expand_dims
def add_fun(x):
return tf.add(x[0], x[1])
in_1 = Input(shape=(None, None, 8))
in_2 = Input(shape=(None, 1, 1))
out = Lambda(add_fun)([in_1, in_2])
m = Model([in_1, in_2], out)
并且此代码不:
import tensorflow as tf
from tensorflow.python.keras.layers import Input, Lambda
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.backend import expand_dims
def add_fun(x):
return tf.add(x[0], x[1])
in_1 = Input(shape=(None, None, 8))
in_2 = Input(shape=(None, 1))
problem_part = expand_dims(in_2, axis=1)
out = Lambda(add_fun)([in_1, problem_part])
m = Model([in_1, in_2], out)
如图here所示,我相信我使用的是expand_dims正确,而且我不知道为什么它会引起问题。
答案 0 :(得分:1)
问题在于expand_dims
不是Keras层。如果您改为将调用转到lambda层内的expand_dims
,则它将正常工作。
答案 1 :(得分:1)
这可以通过将expand_dims函数调用包装在Lambda中来解决:
import tensorflow as tf
from tensorflow.python.keras.layers import Input, Lambda
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.backend import expand_dims
def add_fun(x):
return tf.add(x[0], x[1])
in_1 = Input(shape=(None, None, 8))
in_2 = Input(shape=(None, 1))
problem_part = Lambda(lambda x: expand_dims(x, axis=1))(in_2)
out = Lambda(add_fun)([in_1, problem_part])
m = Model([in_1, in_2], out)