如何将int32张量转换为float32

时间:2019-07-29 00:54:54

标签: python tensorflow

如何在张量流中将int32张量转换为float32。我不了解tf.cast的作用。它似乎什么也没做。

import tensorflow as tf
import numpy as np

tf.enable_eager_execution()

a = tf.constant([[1, 2, 3, 4], [1, 2, 3, 4]])
b = tf.cast(a, dtype=tf.float32)

print(tf.shape(a))
print(tf.shape(b))

输出;

tf.Tensor([2 4], shape=(2,), dtype=int32) #a   
tf.Tensor([2 4], shape=(2,), dtype=int32) #b

1 个答案:

答案 0 :(得分:4)

如果您只是使用;

print(a)
print(b)

您将获得正确的结果;

tf.Tensor(
[[1 2 3 4]
 [1 2 3 4]], shape=(2, 4), dtype=int32) #a
tf.Tensor(
[[1. 2. 3. 4.]
 [1. 2. 3. 4.]], shape=(2, 4), dtype=float32) #b

因此tf.cast()可以正常工作!


使用tf.shape(),您将得到一个解释输入形状细节的结果。

  

返回:类型为out_type的张量。

     

out_type :(可选)操作的指定输出类型(int32   或int64)。默认为tf.int32

因此,dtype结果的tf.shape()是结果“ 形状细节张量”的dtype,而不是a的{​​{1}},或者b