在python3中:list(iterables)的奇怪行为

时间:2019-11-28 14:56:37

标签: python list dictionary pytorch iterable

我有一个关于python中的可迭代行为的特定问题。我可迭代的是pytorch中的一个自定义构建的Dataset类:

import torch
from torch.utils.data import Dataset
class datasetTest(Dataset):
    def __init__(self, X):
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, x):
        print('***********')
        print('getitem x = ', x)
        print('###########')
        y = self.X[x]
        print('getitem y = ', y)
        return y

当我初始化那个datasetTest类的特定实例时,就会出现怪异的行为。根据我作为参数X传递的数据结构,当我调用list(datasetTestInstance)时,它的行为会有所不同。特别是,当传递torch.tensor作为参数时,没有问题,但是,当传递dict作为参数时,它将抛出KeyError。这样做的原因是list(iterable)不仅调用i = 0,...,len(iterable)-1,而且还调用i = 0,...,len(iterable)。也就是说,它将迭代直到(包括)索引等于迭代对象的长度。显然,此索引未在任何python数据结构中定义,因为最后一个元素始终具有索引len(datastructure)-1而不是len(datastructure)。如果X是torch.tensor或列表,即使我认为应该是一个错误,也不会出现任何错误。即使对于索引为len(datasetTestinstance)的(不存在)元素,它仍将调用getitem,但它不会计算y = self.X [len(datasetTestInstance]。 >

当将dict作为数据传递时,当x = len(datasetTestInstance)时,它将在最后一次迭代中引发错误。我猜这实际上是预期的行为。但是,为什么只对字典而不对列表或torch.tensor发生这种情况?

if __name__ == "__main__":
    a = datasetTest(torch.randn(5,2))
    print(len(a))
    print('++++++++++++')
    for i in range(len(a)):
        print(i)
        print(a[i])
    print('++++++++++++')
    print(list(a))

    print('++++++++++++')
    b = datasetTest({0: 12, 1:35, 2:99, 3:27, 4:33})
    print(len(b))
    print('++++++++++++')
    for i in range(len(b)):
        print(i)
        print(b[i])
    print('++++++++++++')
    print(list(b))

如果您想更好地理解我所观察到的内容,可以尝试使用该代码段。

我的问题是:

1。)为什么列表(可迭代)迭代直到(包括)len(可迭代)? for循环不会那样做。

2。)如果有torch.tensor或作为数据X传递的列表:为什么即使为索引len(datasetTestInstance)调用getitem方法也不会抛出错误,因为它实际上应该超出范围没有定义为张量/列表中的索引?或者换句话说,当到达索引len(datasetTestInstance)然后进入 getitem 方法时,究竟发生了什么?它显然不再调用'y = self.X [x]'(否则会出现IndexError),但是它确实进入了getitem方法,我可以看到它从getitem方法中打印索引x。那么该方法会发生什么呢?为何它的行为取决于是否有torch.tensor / list或dict?

2 个答案:

答案 0 :(得分:1)

这实际上不是pytorch的特定问题,这是一个通用的python问题。

您正在使用list(iterable)构建列表,其中iterable类是实现sequence semantics的类。

在这里查看__getitem__对序列类型的预期行为(最相关的部分以粗体显示)

  

object.__getitem__(self, key)

     

呼吁实施评估   self[key]。对于序列类型,可接受的密钥应为整数   和切片对象。注意负号的特殊解释   索引(如果类希望模拟序列类型)取决于   __getitem__()方法。如果密钥的类型不合适,则可能引发TypeError如果该索引的索引集之外的值   顺序(在对负值进行任何特殊解释之后),   IndexError应该升高。对于映射类型,如果缺少键(不是   在容器中),应该引发KeyError。

     

注意:for循环期望将IndexError引发非法   索引以正确检测序列的结尾。

这里的问题是,对于使用无效索引调用IndexError的情况,对于序列类型,python需要一个__getitem__。看来list构造函数依赖于此行为。在您的示例中,当X是dict时,尝试访问无效的密钥会导致__getitem__引发KeyError,这不是我们所期望的,因此不会被捕获,从而导致列表的构造失败。


根据此信息,您可以执行以下操作

class datasetTest:
    def __init__(self):
        self.X = {0: 12, 1:35, 2:99, 3:27, 4:33}

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if index < 0 or index >= len(self):
            raise IndexError
        return self.X[index]

d = datasetTest()
print(list(d))

我不建议您在实践中这样做,因为它依赖于您的字典X仅包含整数键01,...,len(X)-1意味着在大多数情况下,它最终的行为就像一个列表,因此,最好只使用一个列表。

答案 1 :(得分:1)

一堆有用的链接:

  1. [Python 3.Docs]: Data model - Emulating container types
  2. [Python 3.Docs]: Built-in Types - Iterator Types
  3. [Python 3.Docs]: Built-in Functions - iter(object[, sentinel])
  4. [SO]: Why does list ask about __len__?(所有答案)

关键是 list 构造函数使用(可迭代)参数的 __ len __ ((如果提供)来计算新的容器长度), (通过迭代器协议)对其进行迭代。

由于可怕的巧合(请记住 dict 支持迭代器协议),因此您的示例以这种方式工作(迭代了所有键,但未能达到等于字典长度的键) ,并且发生在其键(这是一个序列)上:

  • 您的词典仅具有 int (以及更多)
  • 它们的值与其索引相同(按顺序)

更改以上两个项目符号所表示的任何条件,都会使实际错误更加雄辩。

两个对象(张量 dict list )都支持迭代器协议。为了使一切正常,您应该将其包装在您的 Dataset 类中,并稍微调整一种映射类型(使用值而不是键)。
代码(与 key_func 相关的部分)有点复杂,但是只是为了易于配置(如果要更改某些内容-用于 demo 的目的)。< / p>

code00.py

#!/usr/bin/env python3

import sys
import torch
from torch.utils.data import Dataset
from random import randint


class SimpleDataset(Dataset):

    def __init__(self, x):
        self.__iter = None
        self.x = x

    def __len__(self):
        print("    __len__()")
        return len(self.x)

    def __getitem__(self, key):
        print("    __getitem__({0:}({1:s}))".format(key, key.__class__.__name__))
        try:
            val = self.x[key]
            print("    {0:}".format(val))
            return val
        except:
            print("    exc")
            raise #IndexError

    def __iter__(self):
        print("    __iter__()")
        self.__iter = iter(self.x)
        return self

    def __next__(self):
        print("    __next__()")
        if self.__iter is None:
            raise StopIteration
        val = next(self.__iter)
        if isinstance(self.x, (dict,)):  # Special handling for dictionaries
            val = self.x[val]
        return val


def key_transformer(int_key):
    return str(int_key)  # You could `return int_key` to see that it also works on your original example


def dataset_example(inner, key_func=None):
    if key_func is None:
        key_func = lambda x: x
    print("\nInner object: {0:}".format(inner))
    sd = SimpleDataset(inner)
    print("Dataset length: {0:d}".format(len(sd)))
    print("\nIterating (old fashion way):")
    for i in range(len(sd)):
        print("  {0:}: {1:}".format(key_func(i), sd[key_func(i)]))
    print("\nIterating (Python (iterator protocol) way):")
    for element in sd:
        print("  {0:}".format(element))
    print("\nTry building the list:")
    l = list(sd)
    print("  List: {0:}\n".format(l))


def main():
    dict_size = 2

    for inner, func in [
        (torch.randn(2, 2), None),
        ({key_transformer(i): randint(0, 100) for i in reversed(range(dict_size))}, key_transformer),  # Reversed the key order (since Python 3.7, dicts are ordered), to test int keys
    ]:
        dataset_example(inner, key_func=func)


if __name__ == "__main__":
    print("Python {0:s} {1:d}bit on {2:s}\n".format(" ".join(item.strip() for item in sys.version.split("\n")), 64 if sys.maxsize > 0x100000000 else 32, sys.platform))
    main()
    print("\nDone.")

输出

[cfati@CFATI-5510-0:e:\Work\Dev\StackOverflow\q059091544]> "e:\Work\Dev\VEnvs\py_064_03.07.03_test0\Scripts\python.exe" code00.py
Python 3.7.3 (v3.7.3:ef4ec6ed12, Mar 25 2019, 22:22:05) [MSC v.1916 64 bit (AMD64)] 64bit on win32


Inner object: tensor([[ 0.6626,  0.1107],
        [-0.1118,  0.6177]])
    __len__()
Dataset length: 2

Iterating (old fashion way):
    __len__()
    __getitem__(0(int))
    tensor([0.6626, 0.1107])
  0: tensor([0.6626, 0.1107])
    __getitem__(1(int))
    tensor([-0.1118,  0.6177])
  1: tensor([-0.1118,  0.6177])

Iterating (Python (iterator protocol) way):
    __iter__()
    __next__()
  tensor([0.6626, 0.1107])
    __next__()
  tensor([-0.1118,  0.6177])
    __next__()

Try building the list:
    __iter__()
    __len__()
    __next__()
    __next__()
    __next__()
  List: [tensor([0.6626, 0.1107]), tensor([-0.1118,  0.6177])]


Inner object: {'1': 86, '0': 25}
    __len__()
Dataset length: 2

Iterating (old fashion way):
    __len__()
    __getitem__(0(str))
    25
  0: 25
    __getitem__(1(str))
    86
  1: 86

Iterating (Python (iterator protocol) way):
    __iter__()
    __next__()
  86
    __next__()
  25
    __next__()

Try building the list:
    __iter__()
    __len__()
    __next__()
    __next__()
    __next__()
  List: [86, 25]


Done.

您可能还需要检查[PyTorch]: SOURCE CODE FOR TORCH.UTILS.DATA.DATASET IterableDataset )。