别再「浪费」GPU了,FlashAttention重磅升级,实现长文本推理速度8倍提升
(来源:机器之心)
-
首先,将键 / 值分成更小的块; -
使用 FlashAttention 并行计算查询与每个这些分块的注意力,为每行和每个分块额外写入一个标量值:注意力值的 log-sum-exp -
最后,通过对所有分块进行归约来计算实际输出,使用 log-sum-exp 来调整每个分块的贡献。
-
Pytorch:使用纯粹的 PyTorch 基元来运行注意力计算(不使用 FlashAttention); -
FlashAttention v2; -
FasterTransformer:使用 FasterTransformer 的注意力内核; -
Flash-Decoding; -
以及一个上限值,该值计算了从内存中读取整个模型和 KV-cache 所需的时间
-
FlashAttention 包,从 v2.2 开始:https://github.com/Dao-AILab/flash-attention/tree/main -
xFormers 包(搜索 xformers.ops.memory_efficient_attention),从 0.0.22 开始:调度程序将根据问题的大小自动使用 Flash-Decoding 或 FlashAttention 方法。当这些方法不受支持时,它可以调度到一个高效的 triton 内核,该内核实现了 Flash-Decoding 算法。