Keras自定义层错误(操作IsVariableInitialized已被标记为不可获取)

时间:2018-02-03 01:09:54

标签: python-3.x tensorflow keras

我试图在玩具数据集上创建自定义Keras图层,但我遇到了问题。在高层次上,我想创建一个"输入门"图层,它具有可训练的权重,可以打开或关闭每列输入。因此,我首先尝试将输入乘以学习权重的sigmoid< d'd版本。我的代码如下:

### This is my custom layer
class InputGate(Layer):
    def __init__(self, **kwargs):
        super(InputGate, self).__init__(**kwargs)

    def build(self, input_shape):
        self.kernel = self.add_weight(name='input_gate',
                                      shape=input_shape[1:],
                                      initializer='random_uniform',
                                      trainable=True)

        super(InputGate, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, inputs):
        gate_amount = K.sigmoid(self.kernel)
        return inputs * gate_amount

    def get_config(self):
        config = {}
        base_config = super(InputGate, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

def create_linear_model(x, y, num_noise_vars = 0, reg_strength=0):
    new_x = get_x_with_noise(x, num_noise_vars=num_noise_vars)
    model = Sequential([
        InputGate(input_shape=(1+num_noise_vars,)),
        Dense(1, kernel_regularizer=l2(reg_strength))
    ])
    model.compile(optimizer="rmsprop", loss="mse")
    model.optimizer.lr = 0.001
    return {"model": model, "new_x": new_x}

def get_x_with_noise(x, num_noise_vars):
    noise_vars = []
    for noise_var in range(num_noise_vars):
        noise_vars.append(np.random.random(len(x)))
    noise_vars.append(x)
    x_with_noise = noise_vars
    new_x = np.array(list(zip(*x_with_noise)))
    return new_x

x = np.random.random(500)
y = (x * 3) + 10
num_noise_vars = 5
info = create_linear_model(x, y, num_noise_vars=num_noise_vars)
model = info["model"]
new_x = info["new_x"]
results = model.fit(new_x, y, epochs=num_epochs, verbose=0)

然后我收到以下错误: ValueError: Operation 'input_gate_14/IsVariableInitialized' has been marked as not fetchable.

此图层主要取自文档(https://keras.io/layers/writing-your-own-keras-layers/)。我使用Keras 2.0.9,在CPU上使用Tensorflow后端(Macbook Air)。

这一层看起来很简单,而且谷歌搜索错误会让我进行一些看似不相关的讨论。任何人都有关于导致这种情况的想法吗?

非常感谢任何帮助!谢谢!

0 个答案:

没有答案