vllm-project/vllm

[Bug]: FlashInfer Incompatible with Sleep Mode

Open

#31016 opened on Dec 19, 2025

Labels: bug, help wanted

Description

🐛 Describe the bug

The following script reproduces the bug with vllm v0.10.1 and flashinfer-python v0.5.3:

from vllm import LLM, SamplingParams

if __name__ == "__main__":
    model_pth = "xxx/Qwen3-1.7B"  
    tp_size = 1
    llm = LLM(
        model=model_pth, 
        enable_sleep_mode=True,
        tensor_parallel_size=tp_size,
        gpu_memory_utilization=0.7, 
    )

    llm.sleep(level=1)
    llm.wake_up()

    prompts = [
        "What is AI?", 
        "Where is the Machu Picchu located?", 
        "What is the capital of France?",
        "Who painted the Mona Lisa?",
    ]

    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_tokens=64,
    )

    outputs = llm.generate(prompts, sampling_params)

    for i, out in enumerate(outputs):
        prompt = prompts[i]
        generated = out.outputs[0].text
        print(f"Prompt {i}: {prompt!r}")
        print(f"Generation: {generated}\n")

Root Cause

The bug occurs because the FlashInfer backend's metadata builder (FlashInferMetadataBuilder) is stateful. It holds a block_table_arange tensor that is initialized once and then reused across subsequent calls to build:

self.block_table_arange = torch.arange(
    max_num_pages_per_req,
    dtype=torch.int32,
    device=self.device,
)

This block_table_arange tensor is allocated in the memory pool under the "kv_cache" tag. Its contents are discarded when llm.sleep is called and are not recreated when the engine wakes up, so the tensor holds incorrect values afterwards and generation produces wrong outputs.
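The failure mode can be illustrated with a toy model that involves no vLLM, PyTorch, or GPU. ToyPool and ToyBuilder are hypothetical stand-ins invented for this sketch, and the random negative fill only imitates "contents lost during sleep"; the real allocator behaves differently in detail.

```python
# Toy model of the stale-cache failure. All names are hypothetical.
import random


class ToyPool:
    """Stands in for a tagged memory pool whose contents are dropped on sleep."""

    def __init__(self):
        self.buffers = []

    def allocate_arange(self, n):
        buf = list(range(n))
        self.buffers.append(buf)
        return buf

    def sleep(self):
        # The buffer objects survive, but their contents do not.
        for buf in self.buffers:
            for i in range(len(buf)):
                buf[i] = random.randint(-1000, -1)

    def wake_up(self):
        # Memory is usable again, but nothing rewrites the old contents.
        pass


class ToyBuilder:
    """Caches an arange buffer once and reuses it on every build() call."""

    def __init__(self, pool, max_pages):
        self.block_table_arange = pool.allocate_arange(max_pages)

    def build(self, num_pages):
        # Relies on the cached buffer still holding 0, 1, 2, ...
        return self.block_table_arange[:num_pages]


pool = ToyPool()
builder = ToyBuilder(pool, max_pages=8)
before = builder.build(4)  # [0, 1, 2, 3]

pool.sleep()
pool.wake_up()
after = builder.build(4)   # garbage: the cached buffer was never refilled

print(before, after)
```

The builder itself never changes between the two calls; only the memory behind its cached buffer does, which is why the wrong outputs appear silently rather than as an exception.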

In particular, this causes bad rollout outputs in VERL when using vLLM with FlashInfer.

Temporary Fix

The following monkey patch is a temporary workaround. It is not an ideal solution, because it refills block_table_arange on every call to build, but it works:

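import functools

import torch
from vllm.v1.attention.backends.flashinfer import FlashInferMetadataBuilder


def patch_flashinfer_build():
    """Refill block_table_arange before every FlashInferMetadataBuilder.build call."""
    old_build = FlashInferMetadataBuilder.build

    @functools.wraps(old_build)
    def new_build(self, *args, **kwargs):
        # The tensor object survives sleep/wake_up but its contents do not,
        # so rewrite 0..N-1 in place before delegating to the original build.
        max_num_pages_per_req = self.block_table_arange.numel()
        self.block_table_arange.copy_(
            torch.arange(
                max_num_pages_per_req,
                device=self.block_table_arange.device,
                dtype=self.block_table_arange.dtype,
            )
        )
        return old_build(self, *args, **kwargs)

    FlashInferMetadataBuilder.build = new_build


# Apply the patch before running generation.
patch_flashinfer_build()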
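The wrapping pattern used by the patch can be checked in isolation. The sketch below uses a hypothetical CachedArangeBuilder class with a plain Python list in place of the tensor, so it runs without vLLM or PyTorch; it only demonstrates the monkey-patching mechanics, not the real backend.

```python
# Sketch of the wrap-and-refill pattern on a hypothetical stand-in class.
import functools


class CachedArangeBuilder:
    """Hypothetical builder that caches an arange buffer at construction."""

    def __init__(self, max_pages):
        self.block_table_arange = list(range(max_pages))

    def build(self, num_pages):
        return self.block_table_arange[:num_pages]


def patch_build(cls):
    old_build = cls.build

    @functools.wraps(old_build)
    def new_build(self, *args, **kwargs):
        # Rewrite the cached buffer in place, then delegate.
        n = len(self.block_table_arange)
        self.block_table_arange[:] = range(n)
        return old_build(self, *args, **kwargs)

    cls.build = new_build


builder = CachedArangeBuilder(max_pages=8)
builder.block_table_arange[:] = [-1] * 8  # simulate contents lost during sleep
broken = builder.build(4)                 # [-1, -1, -1, -1]

patch_build(CachedArangeBuilder)
fixed = builder.build(4)                  # [0, 1, 2, 3]

print(broken, fixed)
```

Because the method is replaced on the class rather than on an instance, builders created before the patch also pick it up. The cost is one small in-place fill per build call.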
