使用Tensorflow后端时,如何为Keras的SGD优化器的变体打印中间状态

时间:2019-05-21 16:07:04

标签: python tensorflow keras

我想写一个Keras SGD的变体,它允许在指定的迭代中步长大小的离散变化。我正在使用Tensorflow后端。

为了帮助调试,我正在尝试使优化器的get_updates方法向我显示消息,但是我似乎做不到。我已经尝试了标准打印语句和tf.Print,但是都没有用。来自Keras SGD优化器cldass的相关代码看起来像这样:

@interfaces.legacy_get_updates_support
def get_updates(self, loss, params):
    print (" -------------------------> Getting updates <------------------------------------------")
    grads = self.get_gradients(loss, params)
    self.updates = [K.update_add(self.iterations, 1)]
    tf.Print(self.iterations,
             [self.iterations],
             message="-------------------------------> GETTING UPDATES <----------------------------------------")

    lr = self.lr
    if self.initial_decay > 0:
        lr = lr * (1. / (1. + self.decay * K.cast(self.iterations,
                                                  K.dtype(self.decay))))

    # momentum
    shapes = [K.int_shape(p) for p in params]
    moments = [K.zeros(shape) for shape in shapes]
    self.weights = [self.iterations] + moments
    for p, g, m in zip(params, grads, moments):
        v = self.momentum * m - lr * g  # velocity
        self.updates.append(K.update(m, v))

        if self.nesterov:
            new_p = p + self.momentum * v - lr * g
        else:
            new_p = p + v

        # Apply constraints.
        if getattr(p, 'constraint', None) is not None:
            new_p = p.constraint(new_p)

        self.updates.append(K.update(p, new_p))
    return self.updates

尽管print语句确实能够中继一条消息,但仅此而已。我希望每次参数更新时都可以看到输出(即每次批处理之后)。相反,我只会在训练的第一个时期之前看到打印的输出。

我在做什么错?我是否还在为图计算工作困扰(对我来说)?另外,tf.Print是否应该产生一些文本输出?

2 个答案:

答案 0 :(得分:0)

我想我知道这里发生了什么.....

  1. 我的import timeit testdf = pd.concat([testdf for i in range(10000)], ignore_index=True) def fix_df(): global testdf testdf["Charge"] = testdf["Rev"].where(testdf.Pnum.isin(pnumlist), 0) for service in servicelist: testdf["{}count".format(service)] = ( testdf["Service"].str.contains(service).astype(int) ) return testdf def fix_df_orig(): global testdf def rowhandler(testdfrow: tuple) -> tuple: testdfrow["Charge"] = testdfrow["Rev"] if testdfrow["Pnum"] in pnumlist else 0 for service in servicelist: testdfrow["{}count".format(service)] = ( 1 if service in testdfrow["Service"] else 0 ) return testdfrow newcolslist = ["Charge"] newcolsdict = {col: 0 for col in newcolslist} testdf = testdf.assign(**newcolsdict) # pre-allocating memory speeds up program testdf = testdf.apply(rowhandler, axis=1) In [1]: timeit.timeit(fix_df, number=1) Out[1]: 0.06966943305451423 In [2]: timeit.timeit(fix_df_orig, number=1) Out[2]: 109.82892861706205 语句仅在调用print时产生输出。它仅被调用一次,并返回用于实际计算更新的图(子图?)。

  2. 我的get_updates不产生任何输出,因为我从未明确将其放在计算图中

答案 1 :(得分:0)

要在图形模式下使用List<D extends BasicDao>,您可以仅使用tf.print来替代tf.print,只需强制执行{{1 }}在tf.Print中执行张量之前的操作。您可以检查here了解详情。