如何在python中修改生成器的最后一个元素?

时间:2018-07-10 16:28:35

标签: python machine-learning generator conv-neural-network imperative-programming

我有一个生成器,我想修改生成器的最后一个元素。 我想用另一个元素替换最后一个元素。我知道如何检索最后一个元素,但不知道如何修改它。

解决这个问题的最佳方法是什么?

有关更多背景信息,这是我想要做的:

for child in alexnet.children():
    for children_of_child in child.children():
         print(children_of_child);

我的生成器对象是:children_of_child,第二个孩子的所有孩子都是:

Dropout(p=0.5)
Linear(in_features=9216, out_features=4096, bias=True)
ReLU(inplace)
Dropout(p=0.5)
Linear(in_features=4096, out_features=4096, bias=True)
ReLU(inplace)
Linear(in_features=4096, out_features=1000, bias=True)

我想用自己的回归网替换最后一层Linear(in_features=4096, out_features=1000, bias=True)。 `

2 个答案:

答案 0 :(得分:1)

由于您使用的列表较小(即使就RAM而言,即使ResNet-150也“相当小”),所以我将使其易于理解和维护。没有“明显”的方法可以检测到您距离耗尽发电机仅一步之遥。

  1. 耗尽当前生成器,并列出其输出。
  2. 根据需要替换最后一个元素。
  3. 在此更改后的列表周围包装新的生成器。

执行此操作的“好”(?)方法是在原始语言中编写一个具有一个元素前瞻性的包装器生成器:在每次调用N时,您已经拥有包装器中的元素N。您从“真实”生成器(您发布的代码)中获取元素N+1。如果该元素存在,则通常返回元素N。如果该生成器已用尽,则将最后一个元素替换为所需的元素,然后返回更改。

示例

为简单起见,我使用range代替了原始生成器。

def new_tail():
    my_list = list(range(6))
    my_list[-1] = "new last element"
    for elem in my_list:
        yield elem

for item in new_tail():
    print(item)

输出:

0
1
2
3
4
new last element

有帮助吗?

答案 1 :(得分:1)

执行此操作的方法是迭代一个步骤,并在运行时跟踪先前的值。对于每个值,请产生前一个值。当您结束时,不产生最后一个先前的值,而是产生替换值:

def new_tail(it, tail):
    sentinel = prev = object()
    for value in it:
        if prev is not sentinel:
            yield prev
        prev = value
    yield tail

或者您可以特别对待第一个元素,而不使用哨兵:

def new_tail(it, tail):
    it = iter(it)
    prev = next(it)
    for value in it:
        yield prev
        prev = value
    yield tail

您可能想考虑使用完全空的迭代器会发生什么。我不确定是要产生任何结果,产生替换值还是引发异常。第一个版本产生替换值。第二个……好吧,它应该引发一个异常,但是从3.7开始,它发出DeprecationWarning并且不产生任何结果,这可能不是您想要的行为。

无论如何,您可以将next与默认值sentinel一起使用,也可以将except StopIteration:next一起使用。然后,轻松执行所需的三个操作即可。


但是,如果您更抽象地考虑它,则可以使它更简单:如果您具有所有相邻的元素对,那么每个这样的对中的第一个元素将为您提供除最后一个元素之外的所有元素。因此,使用the pairwise recipe from the itertools docs

def new_tail(it, tail):
    for x, _ in pairwise(it):
        yield x
    yield tail

或者,如果您愿意,甚至可以使用itertools.chainoperator.itemgetter使其成为单个表达式,尽管这可能有点愚蠢:

def new_tail(it, tail):
    return chain(map(itemgetter(0), pairwise(it)), (tail,))