根据我的测试,Keras中的to_categorical()
会返回ndarray
float64
。我想知道为什么默认情况下不是float32
,可以由GPU处理。据我所知,GPU无法处理float64
。 to_categorical()
的{{3}}并未说明返回类型应该是什么。所以,我想它可能是一个实现细节,而不是协议/接口的一部分。总结一下,有两个问题:
float64
而不是float32
?astype()
来电。答案 0 :(得分:2)
问题在于numpy.zeros
中的to_categorical
函数used。默认情况下,它会创建一个类型为float64
的数组。
不幸的是,我建议您测试解决方案以解决当前数据类型不保证的问题。通常情况下 - 大多数变换器都会以提供的格式返回数据,因此只要您的基础数据位于float32
- 它就会保留float32
。但是有一些边缘情况,如to_categorical
。
在我的项目中,我使用docker
来保持我用于培训/推理的所有机器的一致性。