
Conversation

@dcampora (Contributor) commented Dec 1, 2025

Motivation

This PR introduces support for the Mistral Large 3 model.

Modifications

To enable the model, several key modifications were made.

Accuracy Tests

Benchmarking and Profiling

Checklist

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @dcampora, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the system's model compatibility by integrating the Mistral Large 3 and PixtralForConditionalGeneration models. It includes crucial updates to the underlying infrastructure for handling Mixture-of-Experts (MoE) architectures, particularly focusing on advanced FP8 quantization techniques. The changes ensure that these new models are correctly configured, loaded, and can leverage optimized quantization paths, while also improving the system's ability to retrieve model-specific information.

Highlights

  • Mistral Large 3 Support: Added comprehensive support for the Mistral Large 3 model, including its specific architecture and weight remapping, enabling its use within the system.
  • PixtralForConditionalGeneration Integration: Introduced support for the PixtralForConditionalGeneration model, which includes a vision encoder, patch merger, and vision-language adapter, expanding multimodal capabilities.
  • FP8 Per-Tensor Quantization for MoE: Implemented FP8 per-tensor quantization for Mixture-of-Experts (MoE) layers, including weight alignment for FlashInfer kernels and precomputation of per-expert output scaling factors, optimizing performance for quantized MoE models (see the sketch after this list).
  • Model Configuration and Routing Enhancements: Updated model configuration handling across various components to properly recognize and configure new models like Mistral Large 3 and Pixtral, and refactored the router manager to intelligently route model information requests.
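As a rough illustration of the FP8 per-tensor quantization highlight above, here is a minimal sketch assuming nothing beyond PyTorch's float8_e4m3fn dtype: one expert weight is reduced to an FP8 tensor plus a single float32 scale. The names are hypothetical; the PR's actual path goes through the MoE quantization methods and the FlashInfer kernel layout described above.

```python
# Minimal sketch of FP8 per-tensor quantization for a single expert weight.
# Hypothetical, self-contained example; not the kernel path used in this PR.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


def quantize_per_tensor_fp8(w: torch.Tensor):
    """Return (w_fp8, scale) such that w ≈ w_fp8.float() * scale."""
    scale = w.abs().max().float() / FP8_MAX
    w_fp8 = (w.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return w_fp8, scale


# Per-expert output scaling factors can then be precomputed once, e.g. by
# combining each expert's weight scale with the layer's activation input scale.
w = torch.randn(1024, 4096, dtype=torch.bfloat16)
w_fp8, w_scale = quantize_per_tensor_fp8(w)
```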
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

|Feature|Command|Description|
|-------|-------|-----------|
|Code Review|/gemini review|Performs a code review for the current pull request in its current state.|
|Pull Request Summary|/gemini summary|Provides a summary of the current pull request in its current state.|
|Comment|@gemini-code-assist|Responds in comments when explicitly tagged, both in pull request comments and review comments.|
|Help|/gemini help|Displays a list of available commands.|

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot left a comment:

Code Review

This pull request adds support for Mistral Large 3 and related models like Pixtral and Eagle. This is a significant feature addition, touching many parts of the codebase, including model configuration, quantization kernels, and the core model implementations. The changes are well-structured, especially the use of inheritance and remapping to adapt the new models to the existing DeepseekV2 architecture. The addition of a mistral_utils.py to handle Mistral's non-standard params.json is a good solution to a tricky problem.

I have a few suggestions to improve maintainability and robustness. Specifically, I've pointed out a redundant check in the FP8 quantization logic, a potential issue with silent failures in weight remapping, and a resource leak in temporary file handling.

Overall, this is a great contribution that expands the model support of sglang.
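
To make the inheritance-and-remapping pattern mentioned above concrete, here is a minimal sketch. It assumes sglang's existing DeepseekV2ForCausalLM base class; the prefix map and class wiring are hypothetical and stand in for the PR's actual remapping and mistral_utils.py logic.

```python
# Sketch of the pattern only: rename Mistral-style checkpoint keys to the module
# names the inherited DeepseekV2 implementation expects, then delegate loading.
from typing import Iterable, Tuple

import torch

from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM  # existing base class

# Hypothetical prefix map, not the mapping used in the PR.
_KEY_REMAP = {
    "tok_embeddings.": "model.embed_tokens.",
    "layers.": "model.layers.",
    "norm.": "model.norm.",
}


def remap_key(name: str) -> str:
    """Rewrite one checkpoint key; unknown keys pass through unchanged."""
    for src, dst in _KEY_REMAP.items():
        if name.startswith(src):
            return dst + name[len(src):]
    return name


class MistralLarge3ForCausalLMSketch(DeepseekV2ForCausalLM):
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Remap names lazily, then reuse the parent loader unchanged.
        remapped = ((remap_key(name), tensor) for name, tensor in weights)
        return super().load_weights(remapped)
```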

zhyncs and others added 2 commits December 1, 2025 02:24
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@slin1237 (Collaborator) commented Dec 1, 2025

router code lgtm
make sure other pieces are reviewed

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
@yuan-luo yuan-luo added Multi-modal multi-modal language model vlm labels Dec 3, 2025
@ispobock (Collaborator) commented Dec 3, 2025

gsm8k benchmark results:

SGLANG_ENABLE_JIT_DEEPGEMM=0 \
python3 -m sglang.launch_server \
--model mistralai/Mistral-Large-3-675B-Instruct-2512 \
--kv-cache-dtype fp8_e4m3 \
--tensor-parallel-size 8 \
--attention-backend trtllm_mla \
--model-loader-extra-config '{"enable_multithread_load": true}' \
--chat-template mistral

lm_eval \
--model local-chat-completions \
--model_args model=mistralai/Mistral-Large-3-675B-Instruct-2512,base_url=http://127.0.0.1:30000/v1/chat/completions,num_concurrent=128,timeout=999999,max_gen_toks=8192 \
--tasks gsm8k \
--batch_size 128 \
--apply_chat_template \
--num_fewshot 8

w/ FP8 attention:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  |0.9181|±  |0.0076|
|     |       |strict-match    |     8|exact_match|↑  |0.7225|±  |0.0123|

w/o FP8 attention:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  |0.9242|±  |0.0073|
|     |       |strict-match    |     8|exact_match|↑  |0.7528|±  |0.0119|

Linda-Stadter and others added 2 commits December 3, 2025 04:59
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
return output


class MistralLarge3ForCausalLMEagle(MistralLarge3ForCausalLM):

Collaborator:

EAGLE seems not to be working.

if rope_scaling:
    rope_scaling["rope_type"] = "deepseek_yarn"

self.llama_4_scaling = (

Collaborator:

Maybe we can add a comment here for the Mistral model.

if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
    w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
    w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
    w13_input_scale = (

Collaborator:

Why do we need to update modelopt quant here?

self.process_weights_hip_scale_padding(layer)

# Align FP8 weights to FlashInfer per-tensor kernel layout if enabled
if get_moe_runner_backend().is_flashinfer_trtllm():

Collaborator:

For flashinfer_trtllm MoE, do we need a separate PR to support it?

    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if self.weight_block_size is not None:
        return self.w8a8_block_fp8_linear(

Collaborator:

TODO: add DeepGEMM support

@ispobock (Collaborator) commented Dec 3, 2025

/tag-and-rerun-ci

@github-actions github-actions bot added the blackwell SM100/SM120 label Dec 3, 2025
@dcampora (Contributor, Author) commented Dec 3, 2025

@ispobock let's tackle Eagle3, DeepGEMM, flashinfer_trtllm moe and FP4 support in follow-up MRs. We have removed the Eagle3 code from the PR.

@ispobock (Collaborator) commented Dec 4, 2025

/tag-and-rerun-ci

hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
llama_4_scaling: Optional[torch.Tensor],

@ispobock (Collaborator), Dec 4, 2025:

There are too many changes for introducing the llama_4_scaling parameter here. Is there a better way to handle it?

Collaborator:

@dcampora The CI failed: https://github.com/sgl-project/sglang/actions/runs/19900521834/job/57100360299?pr=14213#step:5:5669

File "/public_sglang_ci/runner-l3c-gpu-67/_work/sglang/sglang/python/sglang/srt/models/kimi_linear.py", line 437, in forward
    hidden_states = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: DeepseekV2AttentionMLA.forward() missing 1 required positional argument: 'llama_4_scaling'

k_rope: Optional[torch.Tensor] = None,
cos_sin_cache: Optional[torch.Tensor] = None,
is_neox: Optional[bool] = False,
llama_4_scaling: Optional[torch.Tensor] = None,

Collaborator:

If we add this logic here, we will have to do the same thing for a ton of attention backends and keep them in sync. One way might be to just put the logic in the models folder (deepseek_v2.py, the Mistral model file, etc.), since it looks like it is just scaling the q tensor before entering the core attention logic.

Contributor:

In Mistral's implementation in vLLM, the scaling is applied between RoPE and attention, so it is implemented inside the MLA layer: https://github.com/vllm-project/vllm/pull/29757/files#diff-6ffcb4f51daf85df32c7d35433c3393f1602663960f677ae61f55af1ed3ab524

In SGLang, for the trtllm MLA backend, RoPE is fused into the attention backend. This is quite different from vLLM. The scaling needs to be passed into the backend so it can be applied correctly.
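
To illustrate the point being discussed (a hedged sketch with hypothetical names, not the PR's backend code): the operation itself is just a per-token multiplier on the query tensor after RoPE, which is why it has to be threaded into any backend that fuses RoPE with attention.

```python
# Hypothetical helper showing what the llama_4_scaling tensor does conceptually:
# scale q per token after RoPE, before the core attention computation.
from typing import Optional

import torch


def apply_llama_4_scaling(
    q: torch.Tensor,  # [num_tokens, num_heads, head_dim], already rotated by RoPE
    llama_4_scaling: Optional[torch.Tensor],  # [num_tokens] or None
) -> torch.Tensor:
    if llama_4_scaling is None:
        return q
    return q * llama_4_scaling.view(-1, 1, 1).to(q.dtype)
```

When RoPE runs in model code, this multiplication can live in the model file as suggested above; when RoPE is fused into the attention kernel, the tensor must be passed down so the backend can apply it at the equivalent point.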

@ispobock (Collaborator) commented Dec 4, 2025

/tag-and-rerun-ci

@ispobock ispobock merged commit 8428078 into sgl-project:main Dec 4, 2025
282 of 297 checks passed
yingluosanqian pushed a commit to yingluosanqian/sglang that referenced this pull request Dec 4, 2025
Co-authored-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
tonyluj pushed a commit to openanolis/sglang that referenced this pull request Dec 5, 2025
Co-authored-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
tonyluj pushed a commit to openanolis/sglang that referenced this pull request Dec 5, 2025
Co-authored-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
yuchengz816-bot pushed a commit to yuchengz816-bot/sglang that referenced this pull request Dec 8, 2025
Co-authored-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Kevin-XiongC pushed a commit to novitalabs/sglang that referenced this pull request Dec 9, 2025
Co-authored-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
if not (per_tensor or per_channel):
    assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
    self.weight_block_size = self.weight_quant.block_structure
    assert self.weight_quant.dynamic is not None


self.weight_quant.dynamic is False, so can it be None?

Contributor:

This part is mostly from vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py#L667-L681
Given that compressed-tensors is also published by them, would you confirm with them?

@Wangzheee, Dec 18, 2025:

> This part is mostly from vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py#L667-L681 Given that compressed-tensors is also published by them, would you confirm with them?

Like in this recipe, https://github.com/vllm-project/llm-compressor/blob/aa504491afd28a0d5f66d3e38088352dcb4e63ff/src/llmcompressor/modifiers/quantization/gptq/base.py#L57, there is no dynamic attribute.
By the way, can you help review PR #15386?
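
As background for the dynamic discussion above (a generic Python illustration, not code from this PR or from compressed-tensors): an `is not None` assertion accepts False, so the two checks are not interchangeable.

```python
# `is not None` only rejects None; it passes for False, a present but falsy value.
dynamic = False
assert dynamic is not None  # passes: False is a value, just a falsy one

try:
    assert dynamic  # a truthiness check would reject False
except AssertionError:
    print("truthiness check fails for False")
```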

GuoYechang pushed a commit to GuoYechang/sglang that referenced this pull request Jan 13, 2026
Co-authored-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>


Development

Successfully merging this pull request may close these issues.

[Bug] bench_sglang fails due to get_model_info endpoint of SGLang PDRouter not being implemented