Conversation

@aleozlx commented Aug 15, 2025

📌 Description

Follows the approach of #1475, but for FP8.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

@gemini-code-assist bot left a comment

Summary of Changes

Hello @aleozlx, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces an autotuner specifically designed for FP8 Mixture-of-Experts (MoE) operations within the trtllm-gen framework. It extends existing autotuning capabilities by enabling automatic selection of optimal kernel configurations for various FP8 MoE variants, including both per-tensor and block-scale quantization. The changes involve significant refactoring of the autotuner's core API to support more complex dynamic tensor specifications and custom initialization during profiling, alongside modifications to the underlying C++ kernels to allow external configuration control. A new benchmark script has been added to validate the performance improvements across different quantization modes.

Highlights

  • FP8 MoE Autotuning: Implemented a dedicated MoERunner for FP8 Mixture-of-Experts (MoE) operations, enabling automatic selection of optimal kernel configurations based on input shapes and data types.
  • Flexible Autotuner API: Enhanced the core autotuner framework to support more complex tuning scenarios, including linking multiple dynamic tensor dimensions and providing custom tensor initialization during profiling.
  • C++ Kernel Integration: Modified C++ MoE kernels to accept a config_index parameter, allowing the Python autotuner to directly control kernel selection. New C++ functions were added to expose valid kernel configurations.
  • Comprehensive Benchmarking and Validation: Introduced a new benchmark script to evaluate the performance gains of the autotuner across different FP8 and FP4 MoE quantization modes, demonstrating its effectiveness.
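
As a rough illustration of the flow described in the highlights: the Python autotuner queries the valid kernel configurations exposed by the C++ side, profiles each candidate, and then passes the winning config_index to the kernel. The sketch below is illustrative only; pick_best_config, valid_configs, and run_with_config are placeholder names, not the actual FlashInfer API.

import torch


def pick_best_config(valid_configs, run_with_config, warmup=3, iters=10):
    """Time each candidate config index and return the fastest one.

    valid_configs would come from something like trtllm_get_valid_moe_configs;
    run_with_config(cfg) would launch the MoE kernel with config_index=cfg.
    """
    best_cfg, best_ms = valid_configs[0], float("inf")
    for cfg in valid_configs:
        for _ in range(warmup):  # warm-up runs to exclude one-time setup costs
            run_with_config(cfg)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            run_with_config(cfg)
        end.record()
        torch.cuda.synchronize()
        ms = start.elapsed_time(end) / iters
        if ms < best_ms:
            best_cfg, best_ms = cfg, ms
    return best_cfg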

@gemini-code-assist bot left a comment

Code Review

This pull request introduces an autotuner for FP8 MoE kernels and refactors the autotuning framework to be more general and extensible. The changes are well-structured, adding support for configurable kernel selection at runtime to optimize performance. My feedback focuses on improving maintainability by addressing code duplication and enhancing readability by replacing magic numbers with named constants.

Comment on lines +1136 to 1161
int64_t trtllm_get_default_moe_configs(int64_t const tile_tokens_dim, int64_t const dtype_act_,
                                       int64_t const dtype_weights_, bool const useDeepSeekFp8,
                                       int64_t const top_k, int64_t const hidden_size,
                                       int64_t const intermediate_size,
                                       int64_t const num_local_experts, int64_t const num_tokens) {
  auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
  auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
      dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim,
      tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true);
  return moe_runner.getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
                                               num_local_experts, num_tokens);
}

std::vector<int64_t> trtllm_get_valid_moe_configs(
    int64_t const tile_tokens_dim, int64_t const dtype_act_, int64_t const dtype_weights_,
    bool const useDeepSeekFp8, int64_t const top_k, int64_t const hidden_size,
    int64_t const intermediate_size, int64_t const num_local_experts, int64_t const num_tokens) {
  auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
  auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
      dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim,
      tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true);
  return moe_runner.getValidConfigIndices(top_k, hidden_size, intermediate_size, num_local_experts,
                                          num_tokens);
}

medium

The functions trtllm_get_default_moe_configs and trtllm_get_valid_moe_configs contain duplicated code for creating the tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner. To improve maintainability and reduce redundancy, consider extracting the runner creation logic into a common helper function. This helper could then be called by both functions.

Comment on lines +70 to +82
def __hash__(self) -> int:
    # FIXME: currently not hashing tensor_initializers
    return hash(
        (
            self.input_idx,
            self.dim_idx,
            # For gen_tuning_buckets, only hash if it's a tuple, otherwise hash its id
            self.gen_tuning_buckets
            if isinstance(self.gen_tuning_buckets, tuple)
            else id(self.gen_tuning_buckets),
            id(self.map_to_tuning_buckets),
        )
    )

medium

The DynamicTensorSpec class has been significantly improved to handle multiple tensors and custom initializers. However, the custom __hash__ method does not include tensor_initializers. While the FIXME comment acknowledges this, it's worth highlighting that this could lead to incorrect caching behavior if different initializers are used with otherwise identical specs. If tensor_initializers can affect performance or correctness, they should be part of the hash key. If they are complex to hash, you could consider hashing a representation of them, like their __name__ or id().
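
One way to fold the initializers into the hash, sketched under the assumption that tensor_initializers is an iterable of callables (attribute names mirror the snippet above; this is not the PR's actual implementation):

def __hash__(self) -> int:
    # Hash each initializer by __name__ when available, falling back to id(),
    # so that specs differing only in their initializers get distinct cache keys.
    init_key = tuple(
        getattr(fn, "__name__", None) or id(fn) for fn in self.tensor_initializers
    )
    return hash(
        (
            self.input_idx,
            self.dim_idx,
            self.gen_tuning_buckets
            if isinstance(self.gen_tuning_buckets, tuple)
            else id(self.gen_tuning_buckets),
            id(self.map_to_tuning_buckets),
            init_key,
        )
    )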

Comment on lines +107 to +115
if dtype in [
    DtypeTrtllmGen.MxE4m3,
    DtypeTrtllmGen.E2m1,
    DtypeTrtllmGen.MxE2m1,
    DtypeTrtllmGen.MxE4m3,
]:
    return True
else:
    return False

medium

The list of dtypes contains a duplicate DtypeTrtllmGen.MxE4m3. Using a set for the membership check is more idiomatic and efficient, and it also naturally handles the duplicate entry. The entire if/else block can also be simplified to a single return statement.

    return dtype in {
        DtypeTrtllmGen.MxE4m3,
        DtypeTrtllmGen.E2m1,
        DtypeTrtllmGen.MxE2m1,
    }

# - > 1.0 means some experts have more
# tokens than the perfect distribution.
# - < 1.0 does not make sense.
imbalance_factor = 1.3

medium

The imbalance_factor is a magic number. Consider defining it as a named constant at the class or module level (e.g., IMBALANCE_FACTOR = 1.3) to improve readability and make it clear what this value represents. The accompanying comments already provide a good explanation for it.

Comment on lines +23 to +42
def get_tile_tokens_dim(num_tokens, num_experts, top_k):
    # Factor to account for the imbalance of the experts.
    # factor equals to the
    # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
    # - 1.0 means perfect expert distribution.
    # - > 1.0 means some experts have more
    #   tokens than the perfect distribution.
    # - < 1.0 does not make sense.
    imbalance_factor = 1.3
    # Calculate the number of tokens per expert
    # assuming perfect distribution.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # Apply the imbalance factor.
    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
    # And pad the number to the next power of 2.
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    # Cap to 8-64 tokens per CTA tile
    # as it's the range supported by the kernel.
    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
    return tile_tokens_dim

medium

This get_tile_tokens_dim function is a duplicate of the one defined in flashinfer.fused_moe.core.MoERunner. To avoid code duplication and improve maintainability, consider moving this function to a shared utility module (e.g., flashinfer.fused_moe.utils) and importing it in both places.
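
One possible consolidation, which would also address the magic-number comment above, sketched here with a hypothetical module path and constant name (next_positive_power_of_2 is re-declared only to keep the sketch self-contained):

# flashinfer/fused_moe/utils.py (hypothetical shared location)

# Factor accounting for expert load imbalance:
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert.
IMBALANCE_FACTOR = 1.3


def next_positive_power_of_2(x: int) -> int:
    return 1 if x <= 0 else 1 << (x - 1).bit_length()


def get_tile_tokens_dim(num_tokens: int, num_experts: int, top_k: int) -> int:
    # Tokens per expert under a perfect distribution, scaled by the imbalance factor.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    num_tokens_per_expert = int(num_tokens_per_expert * IMBALANCE_FACTOR)
    # Pad to the next power of 2 and cap to the 8-64 range supported by the kernel.
    return min(max(next_positive_power_of_2(num_tokens_per_expert), 8), 64)

For example, get_tile_tokens_dim(256, 128, 8) gives 16 tokens per expert under a perfect distribution, 20 after the imbalance factor, and a tile size of 32 after padding to the next power of 2.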

Comment on lines +54 to +55
TypeCmp compactTmp;
memcpy(&compactTmp, &valueBits, sizeof(valueBits));

medium

This change from reinterpret_cast to memcpy is a good practice to avoid potential strict aliasing violations, improving code correctness and robustness.

@yzh119 mentioned this pull request Aug 19, 2025
@IwakuraRein left a comment

.to(torch.bfloat16) breaks the DeepSeekV3 routing. After removing it, there is no need to limit the number of tokens to [8...1024].

else:
    # FP8 per tensor scale
    return moe_op.trtllm_fp8_per_tensor_scale_moe(
        routing_logits.to(torch.bfloat16),

Suggested change
routing_logits.to(torch.bfloat16),
routing_logits,

else:
    # FP4 operations
    return moe_op.trtllm_fp4_block_scale_moe(
        routing_logits.to(torch.bfloat16),

Suggested change
routing_logits.to(torch.bfloat16),
routing_logits,

DynamicTensorSpec(
    (0, 1, 2, 3, 4, 5),
    (0, 0, 0, 0, 0, 0),
    get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 8),

Suggested change
get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 8),
get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 1),

DynamicTensorSpec(
    (0, 1, 2, 3, 4),
    (0, 0, 0, 0, 0),
    get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 8),

Suggested change
get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 8),
get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 1),

Comment on lines +976 to +977
get_last_power_of_2_num_tokens_buckets(1024, 8),
lambda x: min(last_positive_power_of_2(x), 1024),

Suggested change
get_last_power_of_2_num_tokens_buckets(1024, 8),
lambda x: min(last_positive_power_of_2(x), 1024),
get_last_power_of_2_num_tokens_buckets(8192, 1),
lambda x: min(last_positive_power_of_2(x), 8192),

enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
tune_max_num_tokens: Maximum number of tokens for tuning. (default: 1024)

Suggested change
tune_max_num_tokens: Maximum number of tokens for tuning. (default: 1024)
tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192)

@@ -1367,6 +1935,7 @@ def trtllm_fp4_block_scale_moe(
routing_method_type: int = 0,
do_finalize: bool = True,
enable_pdl: Optional[bool] = None,
tune_max_num_tokens: int = 1024,

Suggested change
tune_max_num_tokens: int = 1024,
tune_max_num_tokens: int = 8192,

@@ -1481,6 +2052,7 @@ def trtllm_fp4_block_scale_routed_moe(
routing_method_type: int = 0,
do_finalize: bool = True,
enable_pdl: Optional[bool] = None,
tune_max_num_tokens: int = 1024,

Suggested change
tune_max_num_tokens: int = 1024,
tune_max_num_tokens: int = 8192,

@@ -1410,6 +1979,7 @@ def trtllm_fp4_block_scale_moe(
- 3: Llama4 (Top1 -> Sigmoid)
- 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)
do_finalize (bool): Whether to finalize the output (default: False)
tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 1024)

Suggested change
tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 1024)
tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192)

@@ -1526,6 +2098,7 @@ def trtllm_fp4_block_scale_routed_moe(
- 3: Llama4 (Top1 -> Sigmoid)
- 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)
do_finalize (bool): Whether to finalize the output (default: False)
tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 1024)

Suggested change
tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 1024)
tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192)

@IwakuraRein

Thanks for the follow-up PR!

@yzh119 commented Aug 23, 2025

Do we have any updates on this PR?
