我是PyTorch的新手。 TensorFlow具有API tf.cast()和 tf.shape()。 tf.cast 在TensorFlow中具有特定用途,火炬中是否有等效的东西? 我有张量x =张量(shape(128,64,32,32)): tf.shape(x)创建尺寸为1的张量 x.shape 创建真实尺寸。我需要在手电筒中使用 tf.shape(x)。
tf.cast 的作用不同于在炬管中更改张量 dtype 的作用。
有人在割炬/ PyTorch中具有等效的API吗?
答案 0 :(得分:2)
他们提到:
print(x.dtype) # Prints "torch.int64", currently 64-bit integer type
x = x.type(torch.FloatTensor)
print(x.dtype) # Prints "torch.float32", now 32-bit float
print(x.float()) # Still "torch.float32"
print(x.type(torch.DoubleTensor)) # Prints "tensor([0., 1., 2., 3.], dtype=torch.float64)"
print(x.type(torch.LongTensor)) # Cast back to int-64, prints "tensor([0, 1, 2, 3])"