我有一个关于python中的可迭代行为的特定问题。我可迭代的是pytorch中的一个自定义构建的Dataset类:
import torch
from torch.utils.data import Dataset
class datasetTest(Dataset):
def __init__(self, X):
self.X = X
def __len__(self):
return len(self.X)
def __getitem__(self, x):
print('***********')
print('getitem x = ', x)
print('###########')
y = self.X[x]
print('getitem y = ', y)
return y
当我初始化那个datasetTest类的特定实例时,就会出现怪异的行为。根据我作为参数X传递的数据结构,当我调用list(datasetTestInstance)时,它的行为会有所不同。特别是,当传递torch.tensor作为参数时,没有问题,但是,当传递dict作为参数时,它将抛出KeyError。这样做的原因是list(iterable)不仅调用i = 0,...,len(iterable)-1,而且还调用i = 0,...,len(iterable)。也就是说,它将迭代直到(包括)索引等于迭代对象的长度。显然,此索引未在任何python数据结构中定义,因为最后一个元素始终具有索引len(datastructure)-1而不是len(datastructure)。如果X是torch.tensor或列表,即使我认为应该是一个错误,也不会出现任何错误。即使对于索引为len(datasetTestinstance)的(不存在)元素,它仍将调用getitem,但它不会计算y = self.X [len(datasetTestInstance]。 >
当将dict作为数据传递时,当x = len(datasetTestInstance)时,它将在最后一次迭代中引发错误。我猜这实际上是预期的行为。但是,为什么只对字典而不对列表或torch.tensor发生这种情况?
if __name__ == "__main__":
a = datasetTest(torch.randn(5,2))
print(len(a))
print('++++++++++++')
for i in range(len(a)):
print(i)
print(a[i])
print('++++++++++++')
print(list(a))
print('++++++++++++')
b = datasetTest({0: 12, 1:35, 2:99, 3:27, 4:33})
print(len(b))
print('++++++++++++')
for i in range(len(b)):
print(i)
print(b[i])
print('++++++++++++')
print(list(b))
如果您想更好地理解我所观察到的内容,可以尝试使用该代码段。
我的问题是:
1。)为什么列表(可迭代)迭代直到(包括)len(可迭代)? for循环不会那样做。
2。)如果有torch.tensor或作为数据X传递的列表:为什么即使为索引len(datasetTestInstance)调用getitem方法也不会抛出错误,因为它实际上应该超出范围没有定义为张量/列表中的索引?或者换句话说,当到达索引len(datasetTestInstance)然后进入 getitem 方法时,究竟发生了什么?它显然不再调用'y = self.X [x]'(否则会出现IndexError),但是它确实进入了getitem方法,我可以看到它从getitem方法中打印索引x。那么该方法会发生什么呢?为何它的行为取决于是否有torch.tensor / list或dict?
答案 0 :(得分:1)
这实际上不是pytorch的特定问题,这是一个通用的python问题。
您正在使用list(iterable)构建列表,其中iterable类是实现sequence semantics的类。
在这里查看__getitem__
对序列类型的预期行为(最相关的部分以粗体显示)
object.__getitem__(self, key)
呼吁实施评估
self[key]
。对于序列类型,可接受的密钥应为整数 和切片对象。注意负号的特殊解释 索引(如果类希望模拟序列类型)取决于__getitem__()
方法。如果密钥的类型不合适,则可能引发TypeError
; 如果该索引的索引集之外的值 顺序(在对负值进行任何特殊解释之后),IndexError
应该升高。对于映射类型,如果缺少键(不是 在容器中),应该引发KeyError。注意:
for
循环期望将IndexError
引发非法 索引以正确检测序列的结尾。
这里的问题是,对于使用无效索引调用IndexError
的情况,对于序列类型,python需要一个__getitem__
。看来list
构造函数依赖于此行为。在您的示例中,当X
是dict时,尝试访问无效的密钥会导致__getitem__
引发KeyError
,这不是我们所期望的,因此不会被捕获,从而导致列表的构造失败。
根据此信息,您可以执行以下操作
class datasetTest:
def __init__(self):
self.X = {0: 12, 1:35, 2:99, 3:27, 4:33}
def __len__(self):
return len(self.X)
def __getitem__(self, index):
if index < 0 or index >= len(self):
raise IndexError
return self.X[index]
d = datasetTest()
print(list(d))
我不建议您在实践中这样做,因为它依赖于您的字典X
仅包含整数键0
,1
,...,len(X)-1
意味着在大多数情况下,它最终的行为就像一个列表,因此,最好只使用一个列表。
答案 1 :(得分:1)
一堆有用的链接:
关键是 list 构造函数使用(可迭代)参数的 __ len __ ((如果提供)来计算新的容器长度), (通过迭代器协议)对其进行迭代。
由于可怕的巧合(请记住 dict 支持迭代器协议),因此您的示例以这种方式工作(迭代了所有键,但未能达到等于字典长度的键) ,并且发生在其键(这是一个序列)上:
更改以上两个项目符号所表示的任何条件,都会使实际错误更加雄辩。
两个对象(张量的 dict 和 list )都支持迭代器协议。为了使一切正常,您应该将其包装在您的 Dataset 类中,并稍微调整一种映射类型(使用值而不是键)。
代码(与 key_func 相关的部分)有点复杂,但是只是为了易于配置(如果要更改某些内容-用于 demo 的目的)。< / p>
code00.py :
#!/usr/bin/env python3
import sys
import torch
from torch.utils.data import Dataset
from random import randint
class SimpleDataset(Dataset):
def __init__(self, x):
self.__iter = None
self.x = x
def __len__(self):
print(" __len__()")
return len(self.x)
def __getitem__(self, key):
print(" __getitem__({0:}({1:s}))".format(key, key.__class__.__name__))
try:
val = self.x[key]
print(" {0:}".format(val))
return val
except:
print(" exc")
raise #IndexError
def __iter__(self):
print(" __iter__()")
self.__iter = iter(self.x)
return self
def __next__(self):
print(" __next__()")
if self.__iter is None:
raise StopIteration
val = next(self.__iter)
if isinstance(self.x, (dict,)): # Special handling for dictionaries
val = self.x[val]
return val
def key_transformer(int_key):
return str(int_key) # You could `return int_key` to see that it also works on your original example
def dataset_example(inner, key_func=None):
if key_func is None:
key_func = lambda x: x
print("\nInner object: {0:}".format(inner))
sd = SimpleDataset(inner)
print("Dataset length: {0:d}".format(len(sd)))
print("\nIterating (old fashion way):")
for i in range(len(sd)):
print(" {0:}: {1:}".format(key_func(i), sd[key_func(i)]))
print("\nIterating (Python (iterator protocol) way):")
for element in sd:
print(" {0:}".format(element))
print("\nTry building the list:")
l = list(sd)
print(" List: {0:}\n".format(l))
def main():
dict_size = 2
for inner, func in [
(torch.randn(2, 2), None),
({key_transformer(i): randint(0, 100) for i in reversed(range(dict_size))}, key_transformer), # Reversed the key order (since Python 3.7, dicts are ordered), to test int keys
]:
dataset_example(inner, key_func=func)
if __name__ == "__main__":
print("Python {0:s} {1:d}bit on {2:s}\n".format(" ".join(item.strip() for item in sys.version.split("\n")), 64 if sys.maxsize > 0x100000000 else 32, sys.platform))
main()
print("\nDone.")
输出:
[cfati@CFATI-5510-0:e:\Work\Dev\StackOverflow\q059091544]> "e:\Work\Dev\VEnvs\py_064_03.07.03_test0\Scripts\python.exe" code00.py Python 3.7.3 (v3.7.3:ef4ec6ed12, Mar 25 2019, 22:22:05) [MSC v.1916 64 bit (AMD64)] 64bit on win32 Inner object: tensor([[ 0.6626, 0.1107], [-0.1118, 0.6177]]) __len__() Dataset length: 2 Iterating (old fashion way): __len__() __getitem__(0(int)) tensor([0.6626, 0.1107]) 0: tensor([0.6626, 0.1107]) __getitem__(1(int)) tensor([-0.1118, 0.6177]) 1: tensor([-0.1118, 0.6177]) Iterating (Python (iterator protocol) way): __iter__() __next__() tensor([0.6626, 0.1107]) __next__() tensor([-0.1118, 0.6177]) __next__() Try building the list: __iter__() __len__() __next__() __next__() __next__() List: [tensor([0.6626, 0.1107]), tensor([-0.1118, 0.6177])] Inner object: {'1': 86, '0': 25} __len__() Dataset length: 2 Iterating (old fashion way): __len__() __getitem__(0(str)) 25 0: 25 __getitem__(1(str)) 86 1: 86 Iterating (Python (iterator protocol) way): __iter__() __next__() 86 __next__() 25 __next__() Try building the list: __iter__() __len__() __next__() __next__() __next__() List: [86, 25] Done.
您可能还需要检查[PyTorch]: SOURCE CODE FOR TORCH.UTILS.DATA.DATASET( IterableDataset )。