sgl-project/sglang-jax
[Feature] Multi-Token Prediction (MTP) support for sgl-jax
Open
Opened on Sep 17, 2025
Labels: SpecDecode, enhancement, help wanted
Description
Motivation
Current autoregressive language models generate tokens sequentially, one forward pass per token, which creates an inherent bottleneck in inference throughput. Speculative decoding techniques such as EAGLE improve performance through a draft-verify mechanism, but they still rely on the base model's single-token (next-token) predictions. Multi-Token Prediction addresses this limitation by enabling the model to predict multiple tokens directly, reducing the number of forward passes required to generate a sequence.
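The draft-verify loop described above can be illustrated with greedy (temperature 0) verification of a single draft chain, i.e. the `topk = 1` case. This is a minimal pure-Python sketch, not sgl-jax code: it assumes the target model has already produced its greedy prediction for every drafted position in one forward pass, and `greedy_verify` and its argument layout are hypothetical names chosen for illustration.

```python
from typing import List, Tuple


# Hypothetical helper for illustration only; not part of the sgl-jax API.
def greedy_verify(draft: List[int], target_argmax: List[int]) -> Tuple[List[int], int]:
    """Accept the longest prefix of `draft` that matches the target's greedy picks.

    draft:         k tokens proposed by the draft model (a single chain, topk = 1).
    target_argmax: k + 1 tokens; target_argmax[i] is the target model's greedy next
                   token after the committed prefix extended by draft[:i], all taken
                   from ONE target forward pass over the drafted positions.
    Returns (tokens to append, number of accepted draft tokens).
    """
    assert len(target_argmax) == len(draft) + 1
    accepted = 0
    for d, t in zip(draft, target_argmax):
        if d != t:
            break
        accepted += 1
    # The target's own prediction at the first mismatch (or after the last accepted
    # draft token) is always valid, so every verify step emits at least one token.
    emitted = draft[:accepted] + [target_argmax[accepted]]
    return emitted, accepted


# The draft proposes 4 tokens; the target agrees with the first two.
out, n = greedy_verify([11, 22, 33, 44], [11, 22, 99, 44, 55])
print(out, n)  # [11, 22, 99] 2
```

Because each verify pass emits at least one token, the target model never needs more forward passes than plain autoregressive decoding, and every accepted draft token saves one pass.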
Road Map
- Eagle Worker main process @SiqiLi-Fighting
  - Adapt the tree mask to the RPA v3 attention kernel
  - Add a bigram-key prefix cache to the radix cache
- Performance optimization @SiqiLi-Fighting
  - EAGLE topk = 1: JIT functional optimization
  - EAGLE topk > 1: build the tree mask kernel at the draft decode stage
  - Compatibility with SchedulerOverlap
  - Non-greedy sampling kernel implementation @SiqiLi-Fighting
- Support for more speculative algorithms (call for contributions)
  - Adapt n-gram algorithms such as PLD, Suffix Decoding, and LookAhead
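For the EAGLE `topk > 1` items, the draft becomes a tree instead of a chain, so a plain causal mask is wrong: tokens on sibling branches must not attend to each other. The pure-Python sketch below shows what such a tree mask contains. It assumes the tree is given as a topologically ordered parent array; `build_tree_mask` and that encoding are hypothetical, and a real RPA v3 / TPU kernel would build the mask on device rather than as nested lists.

```python
from typing import List


# Hypothetical helper for illustration only; not part of the sgl-jax API.
def build_tree_mask(parents: List[int]) -> List[List[bool]]:
    """Attention mask among n draft-tree tokens.

    parents[i] is the index of node i's parent in the draft tree, or -1 if node i
    hangs directly off the last verified token. Nodes must be topologically ordered
    (parents[i] < i), so each row can start as a copy of its parent's row.
    mask[i][j] is True when draft token i may attend to draft token j, i.e. j is i
    itself or one of its ancestors. Attention to the committed prefix is not part of
    this matrix: every draft token sees the whole prefix.
    """
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i, p in enumerate(parents):
        assert p < i, "nodes must be topologically ordered"
        if p >= 0:
            mask[i] = list(mask[p])  # inherit the parent's ancestors
        mask[i][i] = True  # every token attends to itself
    return mask


# topk = 2, depth 2: nodes 0 and 1 are root children, 2 and 3 extend 0, 4 extends 1.
for row in build_tree_mask([-1, -1, 0, 0, 1]):
    print("".join("1" if x else "0" for x in row))
# 10000
# 01000
# 10100
# 10010
# 01001
```

For a chain such as `[-1, 0, 1]` the result is the ordinary lower-triangular causal mask, so the tree mask only changes anything once `topk > 1`.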
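One way to read the bigram-key radix cache item, offered as an interpretation rather than a description of the planned implementation: in EAGLE-style drafters the draft model's input at position i pairs the target model's hidden state at i with token i + 1, so a cached draft KV entry is only reusable when both tokens match. Keying the radix tree on adjacent token pairs captures that. `to_bigram_key` and `shared_prefix_len` are hypothetical helpers.

```python
from typing import List, Sequence, Tuple

Bigram = Tuple[int, int]


# Hypothetical helper for illustration only; not part of the sgl-jax API.
def to_bigram_key(tokens: Sequence[int]) -> List[Bigram]:
    """Map a token sequence to adjacent-pair keys: position i -> (t[i], t[i + 1]).

    Assumption: the draft KV entry at position i depends on tokens up to i + 1,
    so a plain per-token key would over-report reusable draft KV entries.
    """
    return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]


def shared_prefix_len(a: Sequence, b: Sequence) -> int:
    """Length of the common prefix of two key sequences (what a radix match returns)."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


cached = [5, 6, 7, 8]
request = [5, 6, 7, 9]
print(shared_prefix_len(cached, request))                                # 3 token positions match
print(shared_prefix_len(to_bigram_key(cached), to_bigram_key(request)))  # 2 draft KV entries reusable
```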
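For the non-greedy sampling item, the standard speculative sampling rule keeps the output distribution equal to the target's: accept a drafted token x ~ q with probability min(1, p(x)/q(x)), otherwise resample from the normalized residual max(0, p - q). The sketch below checks one drafted token in pure Python; whether sgl-jax uses exactly this rule, and the names `residual` and `verify_sampled_token`, are assumptions for illustration.

```python
import random
from typing import List, Sequence, Tuple


# Hypothetical helpers for illustration only; not part of the sgl-jax API.
def residual(p: Sequence[float], q: Sequence[float]) -> List[float]:
    """Normalized max(0, p - q): the distribution to resample from after a rejection."""
    r = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(r)  # > 0 whenever a rejection is possible, since p and q both sum to 1
    return [x / total for x in r]


def verify_sampled_token(x: int, p: Sequence[float], q: Sequence[float],
                         rng: random.Random) -> Tuple[bool, int]:
    """Speculative-sampling check for one drafted token x drawn from q.

    p: target distribution, q: draft distribution, both over the same vocabulary.
    Returns (accepted, emitted token); the emitted token is distributed as p.
    """
    if rng.random() < min(1.0, p[x] / q[x]):
        return True, x
    r = residual(p, q)
    return False, rng.choices(range(len(r)), weights=r, k=1)[0]


# Empirical check: emitted-token frequencies should be close to p = [0.6, 0.3, 0.1].
rng = random.Random(0)
p, q = [0.6, 0.3, 0.1], [0.2, 0.5, 0.3]
counts = [0, 0, 0]
for _ in range(20_000):
    x = rng.choices(range(3), weights=q, k=1)[0]
    counts[verify_sampled_token(x, p, q, rng)[1]] += 1
print([round(c / 20_000, 2) for c in counts])
```

Per drafted position this costs one uniform draw, plus one categorical draw on rejection, which is what a fused on-device sampling kernel would need to reproduce.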
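Among the n-gram algorithms in the call for contributions, Prompt Lookup Decoding (PLD) is the simplest drafter: it needs no draft model, only the context so far. A minimal sketch, assuming the proposed tokens are then checked by the same verification step used for EAGLE drafts; the function name and defaults (`ngram=3`, `k=4`) are illustrative choices, not sgl-jax settings.

```python
from typing import List, Sequence


# Hypothetical helper for illustration only; not part of the sgl-jax API.
def prompt_lookup_draft(tokens: Sequence[int], ngram: int = 3, k: int = 4) -> List[int]:
    """Propose up to k draft tokens by matching the trailing n-gram earlier in the context.

    Finds the most recent earlier occurrence of the last `ngram` tokens and returns
    the (up to k) tokens that followed it. Returns [] when there is no match, in which
    case decoding would proceed without a draft for that step.
    """
    if len(tokens) <= ngram:
        return []
    tail = list(tokens[-ngram:])
    # Candidate start positions, most recent first, excluding the tail itself.
    for start in range(len(tokens) - ngram - 1, -1, -1):
        if list(tokens[start:start + ngram]) == tail:
            return list(tokens[start + ngram:start + ngram + k])
    return []


context = [1, 2, 3, 4, 5, 9, 1, 2, 3]
print(prompt_lookup_draft(context))  # [4, 5, 9, 1]
```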