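A First Look at Numba

Still working through the book Python High Performance: I reached the chapter on Numba, skimmed the official documentation as well, and it strikes me as a very practical tool.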

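What is JIT

JIT stands for just-in-time compilation: code is compiled to machine code while the program runs, rather than ahead of time.

Related links: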
https://stackoverflow.com/questions/95635/what-does-a-just-in-time-jit-compiler-do

http://blog.reverberate.org/2012/12/hello-jit-world-joy-of-simple-jits.html

What is Numba

Numba is an open source JIT compiler that translates a subset of Python and NumPy code into fast machine code.

Numba translates Python functions to optimized machine code at runtime using the industry-standard LLVM compiler library. Numba-compiled numerical algorithms in Python can approach the speeds of C or FORTRAN.

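A word of caution: each new Numba release may bring major improvements and may occasionally break backward compatibility, so be sure to read the release notes for every version.

Installation
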
pip install numba

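Getting started

Take computing the sum of squares of an array's elements as an example:

def sum_sq(a):
    result = 0
    n = len(a)
    for i in range(n):
        result += a[i] ** 2  # square each element before accumulating
    return result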

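To compile the function with Numba, apply the jit decorator. The first time the function is called, Numba inspects the argument types and compiles a fast version specialized for them:

from numba import jit

@jit
def sum_sq(a):
    result = 0
    n = len(a)
    for i in range(n):
        result += a[i] ** 2  # square each element before accumulating
    return result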
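
To picture the "compile on first call, once per argument type" behaviour, here is a toy sketch in plain Python. It is only a conceptual illustration and not how Numba is implemented: `toy_jit`, `dispatcher`, and `specializations` are hypothetical names, and nothing is actually compiled.

```python
# Conceptual sketch only: this is NOT how Numba works internally.
def toy_jit(func):
    """Hypothetical decorator that keeps one 'compiled' version per argument-type tuple."""
    specializations = {}

    def dispatcher(*args):
        signature = tuple(type(arg) for arg in args)
        if signature not in specializations:
            # A real JIT would generate machine code for this signature here;
            # the toy just records that a specialization now exists.
            specializations[signature] = func
        return specializations[signature](*args)

    dispatcher.py_func = func  # mirrors the py_func attribute mentioned in this post
    dispatcher.specializations = specializations
    return dispatcher


@toy_jit
def add(a, b):
    return a + b


add(1, 2)      # first call with (int, int): "compiles"
add(3, 4)      # same types: reuses the cached version
add(1.0, 2.0)  # new types (float, float): "compiles" again
print(len(add.specializations))
```

The point of the sketch is that the cost of compilation is paid per distinct set of argument types, not per call.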

Use timeit to compare the execution time of the two versions. The original, undecorated function is available through the py_func attribute.

import numpy as np
x = np.random.rand(10000)

%timeit sum_sq.py_func(x)
1.44 ms ± 93.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit sum_sq(x)
10.6 µs ± 55.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

Numba clearly gives a large speedup here: 1.44 ms versus 10.6 µs, more than 100 times faster.

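nopython and object modes

How far Numba can optimize a function depends on two factors:

  • whether it can infer the types of the variables accurately;
  • whether it can translate standard Python operations into faster, type-specific versions.

If Numba cannot infer the types of the variables, it still compiles the code, but falls back to the interpreter wherever a type cannot be determined or an operation is not supported. Numba calls this object mode, as opposed to nopython mode. (Recent releases no longer fall back automatically: since Numba 0.59, a bare @jit compiles in nopython mode by default.)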

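Numba provides the inspect_types method, which shows the inferred types and which operations were optimized.

For example, inspect the type inference for the sum_sq function from the previous section. The output lists each variable together with its type:

sum_sq.inspect_types()

The output is long, so here is only part of it.
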
# --- LINE 4 --- 

def sum_sq(a):

# --- LINE 5 ---
# a = arg(0, name=a) :: array(float64, 1d, C)
# $const0.1 = const(int, 0) :: Literal[int](0)
# result = $const0.1 :: float64
# del $const0.1

result = 0

# --- LINE 6 ---
# $0.2 = global(len: <built-in function len>) :: Function(<built-in function len>)
# $0.4 = call $0.2(a, func=$0.2, args=[Var(a, <ipython-input-2-3afdff37c581> (5))], kws=(), vararg=None) :: (array(float64, 1d, C),) -> int64
# del $0.2
# n = $0.4 :: int64
# del $0.4
# jump 12
# label 12

n = len(a)

Every variable has a definite type.

Next, a slightly modified version of the example from the official documentation shows how nopython mode is used:

import numpy as np
x = np.arange(10000).reshape(100, 100)

from numba import jit

@jit(nopython=True)
def go_fast(a):
    trace = 0
    # Accumulate tanh of the diagonal elements
    for i in range(a.shape[0]):
        trace += np.tanh(a[i, i])
    # Broadcast the scalar over the whole array
    return a + trace

%timeit go_fast.py_func(x)
131 µs ± 652 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit go_fast(x)
The slowest run took 6.11 times longer than the fastest. This could mean that an intermediate result is being cached.
6.51 µs ± 6.33 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

A counterexample

Numba does not always make code faster; in some cases it adds overhead.
String concatenation is one example:

from numba import jit

@jit(nopython=True)
def concatenate(strings):
    result = ''
    for s in strings:
        result += s
    return result

x = ['hello'] * 1000

%timeit concatenate.py_func(x)
129 µs ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit concatenate(x)
1.58 ms ± 46.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Here the Numba-compiled function is clearly much slower than the original: 1.58 ms versus 129 µs, roughly 12 times slower.
