当元组在python中只有一个条目时,逗号运算符的意义是什么?

时间:2019-08-05 11:46:21

标签: python python-3.x pytorch

这部分代码摘自Pytorch tutorials中的一个,我刚刚删除了不必要的部分,因此它不会出错,并添加了一些打印语句。我的问题是,为什么我提供的两个打印语句的结果略有不同?这是一个下半部什么都没有的元组吗?逗号让赋值运算符之前的逗号令我感到困惑。

import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        print("ctx ", ctx.saved_tensors)
        print("inputs ", input)
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

relu = MyReLU.apply
relu = MyReLU.apply
y_pred = relu(x.mm(w1)).mm(w2)
loss = (y_pred - y).pow(2).sum()
loss.backward()

输出

ctx  (tensor([[-34.2381,  18.6334,   8.8368,  ...,  13.7337, -31.5657, -11.8838],
        [-25.5597,  -6.2847,   9.9412,  ..., -75.0621,   5.0451, -32.9348],
        [-56.6591, -40.0830,   2.4311,  ...,  -2.8988, -18.9742, -74.0132],
        ...,
        [ -6.4023, -30.3526, -73.9649,  ...,   1.8587, -23.9617, -11.6951],
        [ -3.6425,  34.5828,  27.7200,  ..., -34.3878, -19.7250,  11.1960],
        [ 16.0137, -24.0628,  14.4008,  ...,  -5.4443,   9.9499, -18.1259]],
       grad_fn=<MmBackward>),)
inputs  tensor([[-34.2381,  18.6334,   8.8368,  ...,  13.7337, -31.5657, -11.8838],
        [-25.5597,  -6.2847,   9.9412,  ..., -75.0621,   5.0451, -32.9348],
        [-56.6591, -40.0830,   2.4311,  ...,  -2.8988, -18.9742, -74.0132],
        ...,
        [ -6.4023, -30.3526, -73.9649,  ...,   1.8587, -23.9617, -11.6951],
        [ -3.6425,  34.5828,  27.7200,  ..., -34.3878, -19.7250,  11.1960],
        [ 16.0137, -24.0628,  14.4008,  ...,  -5.4443,   9.9499, -18.1259]],
       grad_fn=<MmBackward>)

2 个答案:

答案 0 :(得分:2)

这只是解包单元素列表或元组的一种极端情况。

a, = [1]
print(type(a), a)
# <class 'int'> 1

如果没有逗号,a将被分配整个列表:

a = [1]
print(type(a), a)
# <class 'list'> [1]

元组也是如此:

a, = (1,)  # have to use , with literal single-tuples, because (1) is just 1
print(type(a), a)
# <class 'int'> 1

a = (1,)  # have to use , with literal single-tuples, because (1) is just 1
print(type(a), a)
# <class 'tuple'> (1,)

答案 1 :(得分:0)

(a,b)是一个二元组,(a,b,c)是一个三元组,(a,b,c,d)是一个四元组。

采用另一种方法(a)将是一个元组。但这与例如(1 + 2) / 3,因为您无法分割元组。由于单元组很少见,数学表达式中的括号很常见,因此( <expr> )不是元组。并且需要额外的结尾,例如(a, )

注意:(a,b,)和(a,b,c,)也可以工作。

打开元组的包装也是如此:

a,=元组

解压缩元组并将a设置为第一个(也是唯一一个)。