我正在尝试编写强制执行类型检查的Python函数。我尝试执行此操作的方法是在函数的第一行中使用assert
和isinstance()
,如下所示:
import numpy as np
import pandas as pd
array_like = Union[pd.core.series.Series, np.ndarray]
LOG_TRANSFORM_CONST = 1.01
def log_transform(feature: array_like) -> array_like:
assert isinstance(feature, array_like)
# First remove negative entries
feature[feature < 0.0] = 0.0
# Add a small constant to avoid NANs while applying logs
feature = feature + LOG_TRANSFORM_CONST
return np.log(feature)
此代码不起作用,因为您无法将Union
与isinstance()
一起使用。但是,以下代码可以正常工作:
def log_transform(feature: array_like) -> array_like:
assert type(feature) in [pd.core.series.Series, np.ndarray]
# First remove negative entries
feature[feature < 0.0] = 0.0
# Add a small constant to avoid NANs while applying logs
feature = feature + LOG_TRANSFORM_CONST
return np.log(feature)
if __name__ == '__main__':
df = pd.DataFrame(columns=['A', 'B'])
df['A'] = [1, 2, 3, 4]
df['B'] = [10, 20, 30, 40]
tr_arr = log_transform(df.A)
print(tr_arr)
y = log_transform(np.array([2, 4, 6, 8, 10]))
print(y)
我的问题是这种做法是否可取。关于Python中类型检查的最佳实践是什么?我知道可以专门安装第三方库来进行类型检查,但是我正努力避免这种情况。
答案 0 :(得分:0)
尝试使用断言检查类型有局限性。首先,在运行时检查类型,这样您就不会很快发现错误,即使没有执行带有断言的代码,也不会发现错误。其次,某些类型无法使用断言进行检查。例如,您不能断言变量的类型为“函数从数字到数字”。
使用类型检查工具是最佳选择。您可以尝试mypy;它的主要贡献者之一是Guido van Rossum,所以它是合法的:D
答案 1 :(得分:0)
pythonic方法是使用isinstance(),但是isinstance采用类型或类型的 tuple ,而不是并集。它不需要列表或其他任何东西,只需要一个元组即可。我个人觉得这有点烦人,因为我很少使用python元组,但这是重点。
如果将array_like联合更改为:
array_like = (pd.core.series.Series, np.ndarray)
然后您的代码将使用assert isinstance(feature,array_like)
完整代码:
import numpy as np
import pandas as pd
from typing import Union
array_like = (pd.core.series.Series, np.ndarray)
LOG_TRANSFORM_CONST = 1.01
def log_transform(feature: array_like) -> array_like:
assert isinstance(feature, array_like)
# First remove negative entries
feature[feature < 0.0] = 0.0
# Add a small constant to avoid NANs while applying logs
feature = feature + LOG_TRANSFORM_CONST
return np.log(feature)
if __name__ == '__main__':
df = pd.DataFrame(columns=['A', 'B'])
df['A'] = [1, 2, 3, 4]
df['B'] = [10, 20, 30, 40]
tr_arr = log_transform(df.A)
print(tr_arr)
y = log_transform(np.array([2, 4, 6, 8, 10]))
print(y)
希望这会有所帮助!