如何强制 Pandas read_csv 对所有浮点列使用 float32?
声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow
原文地址: http://stackoverflow.com/questions/30494569/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me):
StackOverFlow
How to force pandas read_csv to use float32 for all float columns?
提问by Fabian
Because
因为
- I don't need double precision
- My machine has limited memory and I want to process bigger datasets
- I need to pass the extracted data (as matrix) to BLAS libraries, and BLAS calls for single precision are 2x faster than for double precision equivalence.
- 我不需要双精度
- 我的机器内存有限,我想处理更大的数据集
- 我需要将提取的数据(作为矩阵)传递给 BLAS 库,并且 BLAS 对单精度的调用比双精度等价的调用快 2 倍。
Note that not all columns in the raw csv file have float types. I only need to set float32 as the default for float columns.
请注意,并非原始 csv 文件中的所有列都具有浮点类型。我只需要将 float32 设置为浮动列的默认值。
采纳答案by Alexander
Try:
尝试:
import numpy as np
import pandas as pd
# Sample 100 rows of data to determine dtypes.
df_test = pd.read_csv(filename, nrows=100)
float_cols = [c for c in df_test if df_test[c].dtype == "float64"]
float32_cols = {c: np.float32 for c in float_cols}
df = pd.read_csv(filename, engine='c', dtype=float32_cols)
This first reads a sample of 100 rows of data (modify as required) to determine the type of each column.
这首先读取 100 行数据的样本(根据需要修改)以确定每列的类型。
It the creates a list of those columns which are 'float64', and then uses dictionary comprehension to create a dictionary with these columns as the keys and 'np.float32' as the value for each key.
它创建一个包含“float64”的列的列表,然后使用字典理解来创建一个以这些列作为键和“np.float32”作为每个键的值的字典。
Finally, it reads the whole file using the 'c' engine (required for assigning dtypes to columns) and then passes the float32_cols dictionary as a parameter to dtype.
最后,它使用“c”引擎读取整个文件(需要为列分配 dtype),然后将 float32_cols 字典作为参数传递给 dtype。
df = pd.read_csv(filename, nrows=100)
>>> df
int_col float1 string_col float2
0 1 1.2 a 2.2
1 2 1.3 b 3.3
2 3 1.4 c 4.4
>>> df.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 3 entries, 0 to 2
Data columns (total 4 columns):
int_col 3 non-null int64
float1 3 non-null float64
string_col 3 non-null object
float2 3 non-null float64
dtypes: float64(2), int64(1), object(1)
df32 = pd.read_csv(filename, engine='c', dtype={c: np.float32 for c in float_cols})
>>> df32.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 3 entries, 0 to 2
Data columns (total 4 columns):
int_col 3 non-null int64
float1 3 non-null float32
string_col 3 non-null object
float2 3 non-null float32
dtypes: float32(2), int64(1), object(1)
回答by Bstampe
@Alexander's is a great answer. Some columns may need to be precise. If so, you may need to stick more conditionals into your list comprehension to exclude some columns the anyor allbuilt ins are handy:
@Alexander's 是一个很好的答案。某些列可能需要精确。如果是这样,您可能需要在列表理解中加入更多条件以排除某些列any或all内置函数很方便:
float_cols = [c for c in df_test if all([df_test[c].dtype == "float64",
not df_test[c].name == 'Latitude', not df_test[c].name =='Longitude'])]

