跳转至

Python类型标注

机器学习越来越火,大量的机器学习包都支持Python,导致了Python近几年非常火爆,入手门槛低,编程简单,概念非常少。越来越多的新手小白加入到Python编程。

Python虽然简单,但也带来很多问题。尤其是弱类型一直被诟病,平时在写代码时,经常也会模糊参数的类型,导致debug难度增加。

自从Python3.5以来,发布了typing包,推荐标注类型,并且IDE会检查类型,让Python看起来有点静态语言的感觉了。本文主要参考Python3.7.5的官方文档 1

常用的类型

常用的几种类型,如int, float, str, List, Tuple, Dict。接下来看几个例子: 首先必须从typing中导入类型

from typing import List, Dict, Tuple
def greeting(name: str) -> str:
    return 'Hello ' + name

print(greeting(188))

如果你输入的参数不是str,IDE就会提示。

如果传入List, Tuple, Dict,需要用[]来指定内部基础类型.

def do_nothing(a: List[int], b: Tuple[int, str], c: Dict[str, int]) -> Dict[str, str]:
    return {'key': 'value'}

# 调用
do_nothing([19], (14, 'some'), {"ha": 10})

除此之外基础类型是可以相互嵌套的,比如字典的值是列表,列表中存放元组,于是可以这样标注:

def do_nothing(a: Dict[str, List[Tuple[int, int, int]]], ):
    pass

比较复杂的类型标注用起来并不是很方便,我们可以根据业务指定别名。比如计算一个点与多个点的距离,List中存放的Tuple是点Point,坐标就是三维的浮点数,于是可以定义别名

Point = Tuple[float, float, float]


def compute_distance(p1: Point, points: List[Point]):
    pass

多种类型

不是多个参数,而是多类型,是输入参数可能存在多种类型,这种情况在Java中多态来解决。而Python本身是弱类型,输入参数没有强制规定,这个时候该怎么办?比如传入参数可能为int, str, float.typing包给我们提供了办法,可以用Union来定义:

Union[int, str, float]
输入参数必须是必须是int, str, float.其中之一。如果不确定数据的类型,可以标示为Any类型,表示任意类型。如果输入参数可能是None值,也可以用Union定义:

Union[str, None]
# 或者
Optional[str]

函数作为输入参数

如果函数作为输入参数,如何标记类型呢?其实也不复杂,函数是callable的类型,同样指定传入和传出参数即可。我们来看一个求和的函数,第一个参数就是函数。add_all只是把所有的元素相加,至于对每个元素做什么操作,取决于传入的函数了。

def add_all(f: Callable[[int], int], params: List[int]):
    return sum(map(f, params))

print(add_all(lambda x: x**2, list(range(1, 10))))
这里要注意的是函数作为参数,有输入和输出值。定义较为麻烦,func: Callable[[int], int],输入参数内部嵌套了中括号,仔细想想也能明白,如果func: Callable[int, int]定义,那么输入参数和输出参数该怎么理解呢?想明白了,你就理解了。

返回生成器

生成器在Python是非常常用的,可以很大提高程序的运行效率。如果需要返回生成器对象该怎么做呢?从typing包中导入Generator.我们来看一下例子,输入列表list,需要每次返回num个数据块。

from typing import List, Generator
import math


def get_data(l: List[int], num: int) -> Generator:
    """
    输入list, 每次按照num个数 返回数据块
    :param l: list data
    :param num: batch size
    """
    epochs = math.ceil(len(l) / num)

    for epoch in range(epochs):
        yield l[epoch * num:(epoch + 1) * num]


for each in get_data(list(range(98)), 5):
    print(each)

小结

本文分别列举了常用参数的类型标注方法,同时也给出了多种参数类型,以及函数和生成器作为参数输入的类型标注方法。参数的类型标注是很重要的,一方面可以帮助你理解每个参数的类型,另一方面也增强了代码的可读性。尤其是别人读到你的代码,调用起来会清晰很多。更多详细的说明可以查看官方文档或者源码。

关注我

评论