tf.cast相当于在pytorch?

时间:2019-12-02 04:25:14

标签: python tensorflow pytorch

我是PyTorch的新手。 TensorFlow具有API tf.cast() tf.shape() tf.cast 在TensorFlow中具有特定用途,火炬中是否有等效的东西? 我有张量x =张量(shape(128,64,32,32)): tf.shape(x)创建尺寸为1的张量 x.shape 创建真实尺寸。我需要在手电筒中使用 tf.shape(x)

tf.cast 的作用不同于在炬管中更改张量 dtype 的作用。

有人在割炬/ PyTorch中具有等效的API吗?

1 个答案:

答案 0 :(得分:2)

签出PyTorch Documentation

他们提到:

print(x.dtype) # Prints "torch.int64", currently 64-bit integer type
x = x.type(torch.FloatTensor)
print(x.dtype) # Prints "torch.float32", now 32-bit float
print(x.float()) # Still "torch.float32"
print(x.type(torch.DoubleTensor)) # Prints "tensor([0., 1., 2., 3.], dtype=torch.float64)"
print(x.type(torch.LongTensor)) # Cast back to int-64, prints "tensor([0, 1, 2, 3])"