sgl-project/sglang-jax

[Feature] Multi-Token Prediction (MTP) support for sgl-jax

Open

Opened on Sep 17, 2025

3 comments, 1 reaction, 1 assignee. Python, 276 stars, 101 forks.
Labels: SpecDecode, enhancement, help wanted

Description

Motivation

Current autoregressive language models generate tokens sequentially, which creates inherent bottlenecks in inference throughput. While speculative decoding techniques like EAGLE improve performance through draft-verify mechanisms, they still rely on single-token predictions from the base model. Multi-Token Prediction addresses this limitation by enabling the model to directly predict multiple tokens, reducing the number of forward passes required for sequence generation.

Road Map

  • Eagle Worker main process @SiqiLi-Fighting

    • Adapt the tree mask to the rpa v3 attention kernel
    • Add a bigram key prefix cache to the radix cache
  • Performance Optimization @SiqiLi-Fighting

    • EAGLE topk = 1: JIT functional optimization
    • EAGLE topk > 1: build the tree mask kernel at the draft decode stage
    • Compatibility with SchedulerOverlap
    • Implement a non-greedy sampling kernel @SiqiLi-Fighting
  • Support for more speculative algorithms (Call for Contribution)

    • Adapt n-gram algorithms such as PLD, Suffix Decoding, and LookAhead
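
For contributors picking up the n-gram item, here is a minimal sketch of the idea behind prompt-lookup-style drafting: take the last few generated tokens, find the most recent earlier occurrence of that n-gram in the context, and propose the tokens that followed it as the draft. The function name `propose_ngram_draft` and its parameters are hypothetical and are not part of the sgl-jax API; a real integration would likely operate on batched arrays inside the draft worker.

```python
from typing import List, Sequence


def propose_ngram_draft(
    token_ids: Sequence[int],
    ngram_size: int = 2,
    num_draft: int = 3,
) -> List[int]:
    """Propose draft tokens by matching the trailing n-gram earlier in the context.

    Hypothetical helper for illustration only; not an sgl-jax function.
    Returns at most `num_draft` tokens, or an empty list if no match exists.
    """
    n = len(token_ids)
    if ngram_size <= 0 or num_draft <= 0 or n <= ngram_size:
        return []

    suffix = list(token_ids[n - ngram_size:])

    # Scan candidate start positions from most recent to oldest,
    # skipping the position of the trailing n-gram itself.
    for start in range(n - ngram_size - 1, -1, -1):
        if list(token_ids[start:start + ngram_size]) == suffix:
            follow = start + ngram_size
            return list(token_ids[follow:follow + num_draft])

    return []


if __name__ == "__main__":
    context = [1, 2, 3, 4, 1, 2]
    print(propose_ngram_draft(context, ngram_size=2, num_draft=3))
```

The proposed tokens are only guesses: as with any draft-verify scheme, the target model would still need to check them in a single forward pass and keep the accepted prefix.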
