python-pytorch-使用DataLoader

时间:2018-12-21 18:02:07

标签: python neural-network lstm pytorch

我想使用割炬内置的DataLoader加载我需要在GPU上并行训练Tree-LSTM网络的数据。我正在关注this tutorial,但是在尝试遍历DataLoader对象时遇到了KeyError。

我的数据结构,代表一棵树(无环图),是通过一组嵌套的字典和列表以该结构递归实现的

{node (torch.Tensor) : [children] (torch.Tensor) and/or dicts with same structure}

例如:

{tensor([ 1.77,  2.06,  0.00, -1.06,  1.45, -2.92]): [tensor([ 0.63,  1.06,  0.00, -1.21,  1.49,  0.30]),
                                                  {tensor([ 1.12,  1.31,  0.00,  0.59,  0.80, -0.96]): [tensor([ 0.07,  0.19,  0.00, -1.29,  0.38,  0.30]),
                                                                                                        tensor([-0.17,  0.25,  0.00,  0.37, -0.53,  0.00]),
                                                                                                        tensor([ 0.10,  0.17,  0.00, -1.44, -0.17, -1.30]),
                                                                                                        tensor([1.12, 1.00, 0.00, 0.55, 0.58, 0.18])]},
                                                  tensor([ 0.43,  0.69,  0.00, -0.75, -0.83,  0.43]),
                                                  {tensor([ 1.88,  2.19,  0.00, -1.26,  2.04,  0.14]): [tensor([ 0.11,  0.13,  0.00, -1.22,  1.19, -0.44]),
                                                                                                        tensor([ 0.95,  0.97,  0.00,  0.32,  0.26, -0.18])]},
                                                  {tensor([ 1.38,  1.50,  0.00, -0.00,  0.53, -0.28]): [tensor([ 0.44,  0.50,  0.00,  0.02,  0.48, -1.37]),
                                                                                                        tensor([-0.21, -0.33,  0.00, -0.36,  0.67, -1.76]),
                                                                                                        tensor([-0.11,  0.06,  0.00, -1.52,  1.06,  0.14])]},
                                                  tensor([ 0.24,  0.69,  0.00, -1.32,  0.41,  0.19]),
                                                  tensor([ 1.00,  1.10,  0.00, -1.44,  0.37,  0.21]),
                                                  tensor([-0.02, -0.11,  0.00, -1.36,  1.75, -0.18]),
                                                  tensor([-0.21, -0.19,  0.00,  0.29, -0.57, -0.42]),
                                                  tensor([ 1.06, -0.17,  0.00, -1.26,  0.06, -0.85]),
                                                  tensor([-0.27,  0.10,  0.00,  0.89, -0.35, -0.11]),
                                                  tensor([ 0.61,  0.75,  0.00, -1.24,  1.36, -0.88]),
                                                  tensor([2.78, 2.74, 0.00, 0.15, 1.81, 0.21]),
                                                  {tensor([1.74, 1.98, 0.00, 1.03, 0.47, 0.18]): [tensor([ 0.69,  0.72,  0.00, -1.14,  0.53, -0.44]),
                                                                                                  tensor([ 0.15,  0.40,  0.00,  0.32, -0.53, -0.76]),
                                                                                                  tensor([-1.34, -1.46,  0.00, -0.42, -1.00, -0.92]),
                                                                                                  tensor([-0.84, -1.12,  0.00,  1.50, -0.25, -0.02]),
                                                                                                  tensor([-0.34, -0.30,  0.00, -0.94,  0.10, -0.21])]},
                                                  tensor([1.73, 1.91, 0.00, 1.86, 0.45, 0.15])]}

0 个答案:

没有答案