sgl-project/sglang-jax

[Feature] Multi-Token Prediction (MTP) support for sgl-jax

Open

#192 opened on Sep 17, 2025

Labels: SpecDecode, enhancement, help wanted

Description

Motivation

Current autoregressive language models generate one token per forward pass, which creates an inherent bottleneck in inference throughput. Speculative decoding techniques like EAGLE improve performance through a draft-verify mechanism, but they still rely on single-token predictions from the base model. Multi-Token Prediction addresses this limitation by enabling the model to predict multiple tokens directly, reducing the number of forward passes required to generate a sequence.

Road Map

  • Eagle Worker main process @SiqiLi-Fighting

    • adapt tree mask to the rpa v3 attention kernel
    • add bigram key prefix cache to radix cache
  • Performance Optimization @SiqiLi-Fighting

    • eagle topk = 1, JIT functional optimization
    • eagle topk > 1, build tree mask kernel at draft decode stage
    • Compatibility with SchedulerOverlap
    • implement non-greedy sampling kernel @SiqiLi-Fighting
  • Support for more speculative algorithms (Call for Contribution)

    • adapt n-gram algorithms such as PLD, Suffix Decoding, and LookAhead
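
For contributors looking at the n-gram item above, the core idea of prompt lookup decoding (PLD) can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the function name `propose_draft_tokens` and its parameters `max_ngram` and `num_draft` are hypothetical and are not part of sgl-jax. The sketch covers only the draft step; the proposed tokens would still need to be verified by the target model.

```python
from typing import List, Sequence


def propose_draft_tokens(
    token_ids: Sequence[int],
    max_ngram: int = 3,
    num_draft: int = 4,
) -> List[int]:
    """Hypothetical helper: propose draft tokens by prompt lookup.

    Finds the most recent earlier occurrence of the sequence's trailing
    n-gram and returns up to `num_draft` tokens that followed it.
    """
    tokens = list(token_ids)
    # Try longer n-grams first; they give more specific matches.
    for n in range(min(max_ngram, len(tokens) - 1), 0, -1):
        suffix = tokens[-n:]
        # Scan right to left, skipping the trailing n-gram itself.
        for start in range(len(tokens) - n - 1, -1, -1):
            if tokens[start:start + n] == suffix:
                continuation = tokens[start + n:start + n + num_draft]
                if continuation:
                    return continuation
    # No match: the caller would fall back to ordinary decoding.
    return []


if __name__ == "__main__":
    context = [1, 2, 3, 4, 5, 1, 2, 3]
    print(propose_draft_tokens(context))  # [4, 5, 1, 2]
```

In the example, the trailing 3-gram `[1, 2, 3]` also appears at the start of the context, so the tokens that followed it there become the draft. A real integration would likely need a batched, JIT-friendly lookup rather than Python loops; that design is left open here.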
