如何强制pandas read_csv为所有浮点列使用float32?

时间:2015-05-27 23:02:40

标签: python numpy pandas

由于

  • 我不需要双精度
  • 我的机器内存有限,我想处理更大的数据集
  • 我需要将提取的数据(作为矩阵)传递给BLAS库,并且BLAS调用单精度比双精度等效快2倍。

请注意,原始csv文件中的所有列都不具有浮点类型。我只需要将float32设置为float列的默认值。

4 个答案:

答案 0 :(得分:9)

尝试:

import numpy as np
import pandas as pd

# Sample 100 rows of data to determine dtypes.
df_test = pd.read_csv(filename, nrows=100)

float_cols = [c for c in df_test if df_test[c].dtype == "float64"]
float32_cols = {c: np.float32 for c in float_cols}

df = pd.read_csv(filename, engine='c', dtype=float32_cols)

首先读取100行数据的样本(根据需要进行修改)以确定每列的类型。

它创建了一个列为'float64'的列,然后使用字典理解来创建一个字典,其中这些列作为键,'np.float32'作为每个键的值。

最后,它使用'c'引擎(将dtypes分配给列所需)读取整个文件,然后将float32_cols字典作为参数传递给dtype。

df = pd.read_csv(filename, nrows=100)
>>> df
   int_col  float1 string_col  float2
0        1     1.2          a     2.2
1        2     1.3          b     3.3
2        3     1.4          c     4.4

>>> df.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 3 entries, 0 to 2
Data columns (total 4 columns):
int_col       3 non-null int64
float1        3 non-null float64
string_col    3 non-null object
float2        3 non-null float64
dtypes: float64(2), int64(1), object(1)

df32 = pd.read_csv(filename, engine='c', dtype={c: np.float32 for c in float_cols})
>>> df32.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 3 entries, 0 to 2
Data columns (total 4 columns):
int_col       3 non-null int64
float1        3 non-null float32
string_col    3 non-null object
float2        3 non-null float32
dtypes: float32(2), int64(1), object(1)

答案 1 :(得分:0)

@Alexander的答案很好。有些列可能需要精确。如果是这样,您可能需要在列表理解中加入更多条件,以排除内置anyall方便的某些列:

float_cols = [c for c in df_test if all([df_test[c].dtype == "float64", 
             not df_test[c].name == 'Latitude', not df_test[c].name =='Longitude'])]           

答案 2 :(得分:0)

如果您不在乎列顺序,那么还有df.select_dtypes可以避免两次read_csv

import pandas as pd

df = pd.read_csv("file.csv")

df_float = df.select_dtypes(include=float).astype("float32")
df_not_float = df.select_dtypes(exclude=float)

df = df_float.join(df_not_float)

或者,如果您要将 all 个非字符串列(例如整数列)转换为float:

import pandas as pd

df = pd.read_csv("file.csv")

df_not_str = df.select_dtypes(exclude=object).astype("float32")
df_str = df.select_dtypes(include=object)

df = df_not_str.join(df_str)

答案 3 :(得分:0)

这是一种不依赖.join或不需要两次读取文件的解决方案:

float64_cols = df.select_dtypes(include='float64').columns
mapper = {col_name: np.float32 for col_name in float64_cols}
df = df.astype(mapper)

或者作为单线踢球:

df = df.astype({c: np.float32 for c in df.select_dtypes(include='float64').columns})
相关问题