Numba 常见问题与解决方案

最佳实践 发布于 Jun 9, 2024 更新于 Jun 10, 2024

Numba是一个复杂的高性能解决方案,由于我工作的特点,会经常(被迫)需要使用Python完成一些对性能要求极高的数据处理工具,导致我经常与Numba斗智斗勇。Numba能够提供极大的性能提升,在Python里快速开发具有C性能的数据处理逻辑,但同样的,Numba也是一个复杂且坑点重重的解决方案。

np.dtype的类型签名

Numba对Numpy ufunc和数据类型有良好的兼容,直接将ndarray对象传入jitted func,Numba可以自动识别类型签名。但是某些时候,我们需要显式指定这些ndarray的签名,比如:定义有ndarray成员的jitclass,定义empty typed list等需要立即而非Lazy确定数据类型签名的场景。

使用内置dtype的ndarray

import numpy as np
from numba import experimental as nbe

@nbe.jitclass([('arr', nb.types.float64[:])])
class JITClass:
    def __init__(self):
        self.arr = np.zeros(100, dtype=np.float64)

Numba的基本数据类型和Numpy存在映射关系,nb.float64和np.float64编译后是同一种数据类型,在装饰器的spec参数里填写参数列表的类型,np.float64数组就是nb.types.float64[:]。

使用自定义dtype的ndarray

这里需要用到Numba提供的工具函数nb.from_dtype,可以将定义好的dtype对象转换为Numba可识别的类型签名。

import numpy as np
import numba as nb

my_type = np.dtype([('a', np.int32), ('b', np.float64)])
print(nb.from_dtype(my_type))

执行这个脚本,可以打印出这个dtype的Numba类型签名Record(a[type=int32;offset=0],b[type=float64;offset=4];12;False)

如果要定义一个ndarray,需要使用nb.from_dtype(my_type)[:],因为返回的Record是“一行”(也就是标量)的类型。但是特殊地,Numba似乎也不支持将Numpy标量作为jitclass的成员,至少也得用ndarray包成向量Numba才能接受。

举个例子:

import numpy as np
from numba import experimental as nbe

# DTYPE_TRADE_EVENT 是我自定义的dtype

@nbe.jitclass([('_event', nb.from_dtype(DTYPE_TRADE_EVENT)[:])])
class EventWrapper:

    def __init__(self, event: np.ndarray):
        self._event = event

    @property
    def event(self) -> np.ndarray:
        return self._event[0]

我为了在一个JIT队列里面传递DTYPE_TRADE_EVENT类型的标量,设计了一个wrapper class,用一个一维向量保存这个标量(当然这个class还有其他的参数和功能,比如当作排序容器,没有展示在这里)。

提前编译

出于各种考量,可能是测量性能,也可能是定位性能问题,我们不希望需要JIT的函数在首次调用时才Lazy地进行编译,而是执行时函数已经就绪。一般来说,有三种方案可以解决这个问题:

  • AOT(Ahead of Time):提前编译,可以打成外部库,性能最强,限制最多
  • Dummy Call:import模块的时候用一组伪调用提前调用要JIT的函数,并且按未来实际使用的类型提供参数
  • Eager Compilation:在@jit修饰时就指定函数签名,定义时立刻编译,调用时不再支持其它类型的参数

何处发生了编译

用Numba进行加速时,总是需要反复进行测量,观察运行时间,然而Numba的编译并不算快,并不复杂的函数即时编译需要以秒计时,这可能导致对代码段运行效率的评估有错误的结论。这就导致我需要知道何处发生了预期外的编译。

Numba提供了编译函数的重载功能,我们可以把它拦截下来,打印栈信息,用这种方式可以快速定位到发生JIT的函数。可以理解为一个numba编译hook函数。

import traceback
import numba as nb

compile = nb.core.registry.CPUDispatcher.compile
def jit_probe(*args, **kwargs):
    print(f"------\ncompile {args[0]}:\n{''.join(traceback.format_stack())}------")
    return compile(*args, **kwargs)

nb.core.registry.CPUDispatcher.compile = jit_probe

效果:

$ python3.8 testnb.py 
------
compile CPUDispatcher(<function test at 0x7fa4b466a280>):
  File "testnb.py", line 22, in <module>
    test()
  File "/home/kino/.local/lib/python3.8/site-packages/numba/core/dispatcher.py", line 420, in _compile_for_args
    return_val = self.compile(tuple(argtypes))
  File "testnb.py", line 11, in jit_probe
    print(f"------\ncompile {args[0]}:\n{''.join(traceback.format_stack())}------")
------
$

下面一个问题强化了这个函数的功能,可以打印新增的函数签名,可以继续往下看。

是什么导致了重新编译

进行Numba+Numpy的高性能开发时,会遇到一些难以定位的问题,这些问题会导致已经被编译的函数再次被编译,比如传入函数的参数类型由np.int32变成了np.int64。

这里需要使用.signatures 函数,这个函数可以获取已编译函数的所有签名组合。

Note:Numba在编译函数对象后,会为新的函数对象提供一些函数,因为没有签名,VSCode给jitted func后面加个'.'dot意图访问是不会有语法高亮和联想的。

略微修改一下上面的hook,可以得到个加强版的probe:

import traceback
import numba as nb

compile = nb.core.registry.CPUDispatcher.compile

def jit_probe(*args, **kwargs):
    r = compile(*args, **kwargs)
    print("------")
    print(f"compile {args[0]}:\n{''.join(traceback.format_stack())}")
    print(f"signature: {args[0].signatures[-1]}")
    print("------")
    return r

nb.core.registry.CPUDispatcher.compile = jit_probe

@nb.njit
def test(a):
    pass

test(1)
test(1.0)

效果:

$ python3.8 testnb.py 
------
compile CPUDispatcher(<function test at 0x7f660f191280>):
  File "testnb.py", line 23, in <module>
    test(1)
  File "/home/nji/.local/lib/python3.8/site-packages/numba/core/dispatcher.py", line 420, in _compile_for_args
    return_val = self.compile(tuple(argtypes))
  File "testnb.py", line 9, in jit_probe
    print(f"compile {args[0]}:\n{''.join(traceback.format_stack())}")

added signature: (int64,)
------
------
compile CPUDispatcher(<function test at 0x7f660f191280>):
  File "testnb.py", line 24, in <module>
    test(1.0)
  File "/home/nji/.local/lib/python3.8/site-packages/numba/core/dispatcher.py", line 420, in _compile_for_args
    return_val = self.compile(tuple(argtypes))
  File "testnb.py", line 9, in jit_probe
    print(f"compile {args[0]}:\n{''.join(traceback.format_stack())}")

added signature: (float64,)
------
$ 

Numba并没有提供专门的工具函数完成这个需求。但是,hook它的编译函数,用这种稍有些tricky的方式,可以很简单地拿到编译发生的位置,以及造成编译的参数签名列表。

这篇文章会继续完善,如果有什么需要补充的地方,可以在评论区留言。

标签

Noam Chi

An Innovative Quant Developer. 2018 VEX World Final THINK Award🏆