我想使用割炬内置的DataLoader加载我需要在GPU上并行训练Tree-LSTM网络的数据。我正在关注this tutorial,但是在尝试遍历DataLoader对象时遇到了KeyError。
我的数据结构,代表一棵树(无环图),是通过一组嵌套的字典和列表以该结构递归实现的
{node (torch.Tensor) : [children] (torch.Tensor) and/or dicts with same structure}
例如:
{tensor([ 1.77, 2.06, 0.00, -1.06, 1.45, -2.92]): [tensor([ 0.63, 1.06, 0.00, -1.21, 1.49, 0.30]),
{tensor([ 1.12, 1.31, 0.00, 0.59, 0.80, -0.96]): [tensor([ 0.07, 0.19, 0.00, -1.29, 0.38, 0.30]),
tensor([-0.17, 0.25, 0.00, 0.37, -0.53, 0.00]),
tensor([ 0.10, 0.17, 0.00, -1.44, -0.17, -1.30]),
tensor([1.12, 1.00, 0.00, 0.55, 0.58, 0.18])]},
tensor([ 0.43, 0.69, 0.00, -0.75, -0.83, 0.43]),
{tensor([ 1.88, 2.19, 0.00, -1.26, 2.04, 0.14]): [tensor([ 0.11, 0.13, 0.00, -1.22, 1.19, -0.44]),
tensor([ 0.95, 0.97, 0.00, 0.32, 0.26, -0.18])]},
{tensor([ 1.38, 1.50, 0.00, -0.00, 0.53, -0.28]): [tensor([ 0.44, 0.50, 0.00, 0.02, 0.48, -1.37]),
tensor([-0.21, -0.33, 0.00, -0.36, 0.67, -1.76]),
tensor([-0.11, 0.06, 0.00, -1.52, 1.06, 0.14])]},
tensor([ 0.24, 0.69, 0.00, -1.32, 0.41, 0.19]),
tensor([ 1.00, 1.10, 0.00, -1.44, 0.37, 0.21]),
tensor([-0.02, -0.11, 0.00, -1.36, 1.75, -0.18]),
tensor([-0.21, -0.19, 0.00, 0.29, -0.57, -0.42]),
tensor([ 1.06, -0.17, 0.00, -1.26, 0.06, -0.85]),
tensor([-0.27, 0.10, 0.00, 0.89, -0.35, -0.11]),
tensor([ 0.61, 0.75, 0.00, -1.24, 1.36, -0.88]),
tensor([2.78, 2.74, 0.00, 0.15, 1.81, 0.21]),
{tensor([1.74, 1.98, 0.00, 1.03, 0.47, 0.18]): [tensor([ 0.69, 0.72, 0.00, -1.14, 0.53, -0.44]),
tensor([ 0.15, 0.40, 0.00, 0.32, -0.53, -0.76]),
tensor([-1.34, -1.46, 0.00, -0.42, -1.00, -0.92]),
tensor([-0.84, -1.12, 0.00, 1.50, -0.25, -0.02]),
tensor([-0.34, -0.30, 0.00, -0.94, 0.10, -0.21])]},
tensor([1.73, 1.91, 0.00, 1.86, 0.45, 0.15])]}