我尝试手动执行一些混合精度,因此尝试以下模型
import tensorflow as tf
from tensorflow.keras import layers
class C3BR(tf.keras.Model):
def __init__(self, filterNum, kSize, strSize, padMode, dFormat='channels_first', dataType='float32'):
super(C3BR, self).__init__()
if dFormat == 'channels_first':
self.conAx = 1
else:
self.conAx = -1
self.kSize = (kSize, kSize, kSize)
self.conv = layers.Conv3D(filters=filterNum, kernel_size=self.kSize, strides=strSize, padding=padMode, data_format=dFormat, dtype=dataType)
self.BN = layers.BatchNormalization(axis=self.conAx, dtype='float32')
self.Relu = layers.ReLU(dtype='float32')
self.dataType = dataType
def call(self, inputs, ifTrain=False):
inputs = tf.dtypes.cast(inputs, self.dataType)
x = self.conv(inputs)
x = layers.Activation('linear', dtype='float32')(x)
x= self.BN(x, training=ifTrain)
print(x.dtype)
outputs = self.Relu(x)
print(outputs.dtype)
return outputs
x=tf.random.uniform((16, 2, 12, 12, 12), dtype='float32')
model=C3BR(32, 3, 1, 'valid', dataType='float16')
y=model(x)
我将输入转换为float16,然后在该层的中间,将其转换为float32。但是我得到的打印输出是
<dtype: 'float16'>
<dtype: 'float16'>
那我在哪里做错了什么?为什么?