Labels: bug, help wanted, keep-open, torch.compile
Description
Your current environment
main
🐛 Describe the bug
When running a model with quant_fp8 + flashinfer on B200s, the matmul kernel that gets used is flashinfer_scaled_fp8_mm, which shows up as a vllm.bmm_fp8 op in the graph. The AsyncTP pass in the torch.compile compilation does not handle this op, because no pattern/replacement exists for it.
Following the other patterns, I wrote one for bmm_fp8:
https://github.com/vllm-project/vllm/pull/26933/commits/b0ab87b121acd1d4c52f3fbee12c3a447ea8f6b4
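For readers unfamiliar with what "pattern/replacement" means here, below is a toy sketch of the idea in plain Python. It is not the vLLM or Inductor pattern matcher: the graph is just a flat list of op names, and `fuse_pairs`, `PATTERNS`, and the short op names are hypothetical stand-ins chosen for illustration. It only shows why a pass that knows a scaled_mm pattern leaves a bmm_fp8 graph untouched until a matching pattern is registered.

```python
# Toy illustration only: a "graph" is a flat list of op names, and a
# pattern maps an adjacent (producer, consumer) pair to one fused op.
# All names below are hypothetical stand-ins, not real vLLM/PyTorch APIs.

PATTERNS = {
    ("scaled_mm", "reduce_scatter"): "fused_scaled_matmul_reduce_scatter",
}


def fuse_pairs(ops, patterns):
    """Replace every adjacent pair found in `patterns` with its fused op."""
    out = []
    i = 0
    while i < len(ops):
        pair = tuple(ops[i:i + 2])
        if pair in patterns:
            out.append(patterns[pair])
            i += 2  # both ops of the pair are consumed by the fused op
        else:
            out.append(ops[i])
            i += 1
    return out


graph = ["quant_fp8", "bmm_fp8", "reduce_scatter", "rms_norm"]

# No pattern mentions bmm_fp8, so the pass leaves the graph unchanged.
unfused = fuse_pairs(graph, PATTERNS)

# Registering a bmm_fp8 pattern lets the same pass fuse the pair.
extended = dict(PATTERNS)
extended[("bmm_fp8", "reduce_scatter")] = "fused_scaled_matmul_reduce_scatter"
fused = fuse_pairs(graph, extended)

print(unfused)
print(fused)
```

The real pass matches on FX graph structure and argument shapes rather than op names in a list, so treat this purely as a mental model for the commit linked above.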
However, perf is a lot worse with this pattern enabled.
This might be because we're replacing the bmm_fp8 + reduce_scatter with torch.ops.symm_mem.patched_fused_scaled_matmul_reduce_scatter, and this op does not have a B200-specific implementation (it just calls into aten._scaled_mm).
cc @ProExpertProg @cascade812