在Python函数中执行类型检查

时间:2018-08-30 14:41:08

标签: types typechecking

我正在尝试编写强制执行类型检查的Python函数。我尝试执行此操作的方法是在函数的第一行中使用assertisinstance(),如下所示:

import numpy as np
import pandas as pd 

array_like = Union[pd.core.series.Series, np.ndarray]
LOG_TRANSFORM_CONST = 1.01


def log_transform(feature: array_like) -> array_like:
    assert isinstance(feature, array_like)

    # First remove negative entries
    feature[feature < 0.0] = 0.0

    # Add a small constant to avoid NANs while applying logs
    feature = feature + LOG_TRANSFORM_CONST

    return np.log(feature)

此代码不起作用,因为您无法将Unionisinstance()一起使用。但是,以下代码可以正常工作:

def log_transform(feature: array_like) -> array_like:
    assert type(feature) in [pd.core.series.Series, np.ndarray]

    # First remove negative entries
    feature[feature < 0.0] = 0.0

    # Add a small constant to avoid NANs while applying logs
    feature = feature + LOG_TRANSFORM_CONST

    return np.log(feature)

if __name__ == '__main__':
    df = pd.DataFrame(columns=['A', 'B'])
    df['A'] = [1, 2, 3, 4]
    df['B'] = [10, 20, 30, 40]
    tr_arr = log_transform(df.A)
    print(tr_arr)
    y = log_transform(np.array([2, 4, 6, 8, 10]))
    print(y)

我的问题是这种做法是否可取。关于Python中类型检查的最佳实践是什么?我知道可以专门安装第三方库来进行类型检查,但是我正努力避免这种情况。

2 个答案:

答案 0 :(得分:0)

尝试使用断言检查类型有局限性。首先,在运行时检查类型,这样您就不会很快发现错误,即使没有执行带有断言的代码,也不会发现错误。其次,某些类型无法使用断言进行检查。例如,您不能断言变量的类型为“函数从数字到数字”。

使用类型检查工具是最佳选择。您可以尝试mypy;它的主要贡献者之一是Guido van Rossum,所以它是合法的:D

答案 1 :(得分:0)

pythonic方法是使用isinstance(),但是isinstance采用类型或类型的 tuple ,而不是并集。它不需要列表或其他任何东西,只需要一个元组即可。我个人觉得这有点烦人,因为我很少使用python元组,但这是重点。

如果将array_like联合更改为:

array_like = (pd.core.series.Series, np.ndarray)

然后您的代码将使用assert isinstance(feature,array_like)

完整代码:

import numpy as np
import pandas as pd
from typing import Union

array_like = (pd.core.series.Series, np.ndarray)
LOG_TRANSFORM_CONST = 1.01


def log_transform(feature: array_like) -> array_like:
    assert isinstance(feature, array_like)

    # First remove negative entries
    feature[feature < 0.0] = 0.0

    # Add a small constant to avoid NANs while applying logs
    feature = feature + LOG_TRANSFORM_CONST

    return np.log(feature)

if __name__ == '__main__':
    df = pd.DataFrame(columns=['A', 'B'])
    df['A'] = [1, 2, 3, 4]
    df['B'] = [10, 20, 30, 40]
    tr_arr = log_transform(df.A)
    print(tr_arr)
    y = log_transform(np.array([2, 4, 6, 8, 10]))
    print(y)

希望这会有所帮助!