在运行时从keras模型获取学习率

时间:2020-01-27 17:46:21

标签: python tensorflow keras deep-learning

我有一个非常简单的任务,我们在此编写这样的keras模型的体系结构:

    def build_model_GRU(self):
        model = Sequential()
        model.add(GRU(75, activation='relu', return_sequences=True, input_shape=(self.nb_timesteps, self.nb_features)))
        model.add(GRU(30, activation='relu', return_sequences=True))
        model.add(GRU(30, activation='relu'))
        model.add(Dense(1))
        model.compile(optimizer='adam', loss='mse')

        return model

然后,我将返回的模型传递给我在下面定义的函数:

 def get_model_data(self, model):
        ''' prints some details about the architecture of the used model '''
        print('Number of layers: %d' % len(model.layers)) # number of layers used

        # the name of the layers used:
        layers_str = ''
        for i in range(len(model.layers)):
            if i == len(model.layers) - 1:
                layers_str += '%s' % model.layers[i].name
            else:
                layers_str += '%s' % model.layers[i].name + '->'
        print('Layers: %s' % layers_str)
        print('Learning Rate: %.5f' % K.eval(model.optimizer.lr)) # the learning rate used
        print('Decay: %.5f' % K.eval(model.optimizer.decay))
        print('Optimizer: %.5f' % model.optimizer.get_config()['name'])

但是我仍然遇到以下错误

回溯(最近通话最近): _do_call中的文件“ C:\ Users \ 96171 \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ tensorflow \ python \ client \ session.py”,行1334 返回fn(* args) _run_fn中的文件“ C:\ Users \ 96171 \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ tensorflow \ python \ client \ session.py”,行1319 选项,feed_dict,fetch_list,target_list,run_metadata) 文件“ C:\ Users \ 96171 \ AppData \ Local \ Programs \ Python \ Python36 \ lib \ site-packages \ tensorflow \ python \ client \ session.py”,行1407,位于_call_tf_sessionrun中 run_metadata) tensorflow.python.framework.errors_impl.FailedPreconditionError:从容器:本地主机读取资源变量Adam / lr时出错。这可能意味着该变量未初始化。找不到:容器lo​​calhost不存在。 (找不到资源:localhost / Adam / lr) [[{{node Adam / lr / Read / ReadVariableOp}} = ReadVariableOpdtype = DT_FLOAT,_device =“ / job:localhost /副本:0 / task:0 / device:CPU:0”]]

我搜索了此错误,否则他们建议使用keras会话。我认为对于我来说,它更简单,因为我是在运行时执行此操作的,并且我没有保存模型以进行加载,如下所述:here

0 个答案:

没有答案