Trtllm-gen Fp8 MoE Autotunner #1494
base: main
Conversation
Summary of Changes
Hello @aleozlx, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces an autotuner specifically designed for FP8 Mixture-of-Experts (MoE) operations within the `trtllm-gen` framework. It extends existing autotuning capabilities by enabling automatic selection of optimal kernel configurations for various FP8 MoE variants, including both per-tensor and block-scale quantization. The changes involve significant refactoring of the autotuner's core API to support more complex dynamic tensor specifications and custom initialization during profiling, alongside modifications to the underlying C++ kernels to allow external configuration control. A new benchmark script has been added to validate the performance improvements across different quantization modes.
Highlights
- FP8 MoE Autotuning: Implemented a dedicated `MoERunner` for FP8 Mixture-of-Experts (MoE) operations, enabling automatic selection of optimal kernel configurations based on input shapes and data types.
- Flexible Autotuner API: Enhanced the core autotuner framework to support more complex tuning scenarios, including linking multiple dynamic tensor dimensions and providing custom tensor initialization during profiling (see the sketch below).
- C++ Kernel Integration: Modified C++ MoE kernels to accept a `config_index` parameter, allowing the Python autotuner to directly control kernel selection. New C++ functions were added to expose valid kernel configurations.
- Comprehensive Benchmarking and Validation: Introduced a new benchmark script to evaluate the performance gains of the autotuner across different FP8 and FP4 MoE quantization modes, demonstrating its effectiveness.
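To make the "Flexible Autotuner API" bullet concrete, here is a rough, hypothetical sketch of how a linked dynamic dimension might be declared, pieced together from the `DynamicTensorSpec` usages visible later in this diff. The import path and the exact constructor signature are assumptions, not confirmed API.

```python
# Hypothetical sketch only: the import path and constructor signature are assumed
# from the DynamicTensorSpec usages shown in this PR's diff, not confirmed API.
from flashinfer.autotuner import (  # assumed module path
    DynamicTensorSpec,
    get_last_power_of_2_num_tokens_buckets,
    last_positive_power_of_2,
)

tune_max_num_tokens = 8192
spec = DynamicTensorSpec(
    (0, 1, 2, 3, 4, 5),  # indices of the input tensors that share the dynamic dim
    (0, 0, 0, 0, 0, 0),  # the dimension (num_tokens) linked across those tensors
    get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 1),  # tuning buckets
    lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),  # runtime shape -> bucket
)
```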
Code Review
This pull request introduces an autotuner for FP8 MoE kernels and refactors the autotuning framework to be more general and extensible. The changes are well-structured, adding support for configurable kernel selection at runtime to optimize performance. My feedback focuses on improving maintainability by addressing code duplication and enhancing readability by replacing magic numbers with named constants.
```cpp
int64_t trtllm_get_default_moe_configs(int64_t const tile_tokens_dim, int64_t const dtype_act_,
                                       int64_t const dtype_weights_, bool const useDeepSeekFp8,
                                       int64_t const top_k, int64_t const hidden_size,
                                       int64_t const intermediate_size,
                                       int64_t const num_local_experts, int64_t const num_tokens) {
  auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
  auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
      dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim,
      tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true);
  return moe_runner.getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
                                               num_local_experts, num_tokens);
}

std::vector<int64_t> trtllm_get_valid_moe_configs(
    int64_t const tile_tokens_dim, int64_t const dtype_act_, int64_t const dtype_weights_,
    bool const useDeepSeekFp8, int64_t const top_k, int64_t const hidden_size,
    int64_t const intermediate_size, int64_t const num_local_experts, int64_t const num_tokens) {
  auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
  auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
      dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim,
      tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true);
  return moe_runner.getValidConfigIndices(top_k, hidden_size, intermediate_size, num_local_experts,
                                          num_tokens);
}
```
The functions `trtllm_get_default_moe_configs` and `trtllm_get_valid_moe_configs` contain duplicated code for creating the `tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner`. To improve maintainability and reduce redundancy, consider extracting the runner creation logic into a common helper function. This helper could then be called by both functions.
```python
def __hash__(self) -> int:
    # FIXME: currently not hashing tensor_initializers
    return hash(
        (
            self.input_idx,
            self.dim_idx,
            # For gen_tuning_buckets, only hash if it's a tuple, otherwise hash its id
            self.gen_tuning_buckets
            if isinstance(self.gen_tuning_buckets, tuple)
            else id(self.gen_tuning_buckets),
            id(self.map_to_tuning_buckets),
        )
    )
```
The `DynamicTensorSpec` class has been significantly improved to handle multiple tensors and custom initializers. However, the custom `__hash__` method does not include `tensor_initializers`. While the `FIXME` comment acknowledges this, it's worth highlighting that this could lead to incorrect caching behavior if different initializers are used with otherwise identical specs. If `tensor_initializers` can affect performance or correctness, they should be part of the hash key. If they are complex to hash, you could consider hashing a representation of them, like their `__name__` or `id()`.
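For illustration, here is a minimal sketch of one way the initializers could be folded into the hash, assuming `tensor_initializers` is an iterable of callables (its exact container type is not visible in this diff). This mirrors the existing field names but is not the PR's implementation.

```python
# Minimal sketch, not the PR's implementation. Assumes tensor_initializers is an
# iterable of callables; a stable name is preferred, with id() as a fallback so
# lambdas still hash.
def _initializer_key(fn):
    return getattr(fn, "__qualname__", None) or id(fn)

def __hash__(self) -> int:
    return hash(
        (
            self.input_idx,
            self.dim_idx,
            self.gen_tuning_buckets
            if isinstance(self.gen_tuning_buckets, tuple)
            else id(self.gen_tuning_buckets),
            id(self.map_to_tuning_buckets),
            tuple(_initializer_key(f) for f in self.tensor_initializers),
        )
    )
```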
```python
if dtype in [
    DtypeTrtllmGen.MxE4m3,
    DtypeTrtllmGen.E2m1,
    DtypeTrtllmGen.MxE2m1,
    DtypeTrtllmGen.MxE4m3,
]:
    return True
else:
    return False
```
The list of dtypes contains a duplicate `DtypeTrtllmGen.MxE4m3`. Using a `set` for the membership check is more idiomatic and efficient, and it also naturally handles the duplicate entry. The entire `if/else` block can also be simplified to a single return statement:

```python
return dtype in {
    DtypeTrtllmGen.MxE4m3,
    DtypeTrtllmGen.E2m1,
    DtypeTrtllmGen.MxE2m1,
}
```
```python
def get_tile_tokens_dim(num_tokens, num_experts, top_k):
    # Factor to account for the imbalance of the experts.
    # factor equals to the
    # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
    # - 1.0 means perfect expert distribution.
    # - > 1.0 means some experts have more
    #   tokens than the perfect distribution.
    # - < 1.0 does not make sense.
    imbalance_factor = 1.3
    # Calculate the number of tokens per expert
    # assuming perfect distribution.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # Apply the imbalance factor.
    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
    # And pad the number to the next power of 2.
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    # Cap to 8-64 tokens per CTA tile
    # as it's the range supported by the kernel.
    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
    return tile_tokens_dim
```
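As a quick sanity check of the logic above, here is a small, self-contained walk-through under assumed inputs; `next_positive_power_of_2` is re-implemented here purely for illustration and may differ from the real helper.

```python
# Illustrative stand-in for the real helper (assumed to round up to a power of 2).
def next_positive_power_of_2(x: int) -> int:
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

# Assumed inputs: 256 tokens routed to 8 experts with top_k = 2.
num_tokens, num_experts, top_k = 256, 8, 2
per_expert = (num_tokens * top_k) // num_experts   # 64 with perfect balance
per_expert = int(per_expert * 1.3)                 # 83 after the imbalance factor
tile = min(max(next_positive_power_of_2(per_expert), 8), 64)
print(tile)  # 64: padded up to 128, then capped to the kernel-supported 8..64 range
```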
```cpp
TypeCmp compactTmp;
memcpy(&compactTmp, &valueBits, sizeof(valueBits));
```
The `.to(torch.bfloat16)` cast breaks the DeepSeekV3 routing. After removing it, there is no need to limit the number of tokens to [8...1024].
```python
        else:
            # FP8 per tensor scale
            return moe_op.trtllm_fp8_per_tensor_scale_moe(
                routing_logits.to(torch.bfloat16),
```
```diff
-                routing_logits.to(torch.bfloat16),
+                routing_logits,
```
```python
        else:
            # FP4 operations
            return moe_op.trtllm_fp4_block_scale_moe(
                routing_logits.to(torch.bfloat16),
```
```diff
-                routing_logits.to(torch.bfloat16),
+                routing_logits,
```
```python
        DynamicTensorSpec(
            (0, 1, 2, 3, 4, 5),
            (0, 0, 0, 0, 0, 0),
            get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 8),
```
```diff
-            get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 8),
+            get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 1),
```
```python
        DynamicTensorSpec(
            (0, 1, 2, 3, 4),
            (0, 0, 0, 0, 0),
            get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 8),
```
```diff
-            get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 8),
+            get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 1),
```
```python
            get_last_power_of_2_num_tokens_buckets(1024, 8),
            lambda x: min(last_positive_power_of_2(x), 1024),
```
```diff
-            get_last_power_of_2_num_tokens_buckets(1024, 8),
-            lambda x: min(last_positive_power_of_2(x), 1024),
+            get_last_power_of_2_num_tokens_buckets(8192, 1),
+            lambda x: min(last_positive_power_of_2(x), 8192),
```
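Neither bucket helper is shown in this diff, so for context here is a hedged sketch of what they plausibly compute; the real implementations in flashinfer may differ. It illustrates why widening the range from (1024, 8) to (8192, 1) enlarges the set of token-count buckets the tuner profiles.

```python
# Illustrative stand-ins only; the real helpers live elsewhere in flashinfer and
# may differ. The assumption is that get_last_power_of_2_num_tokens_buckets
# yields descending powers of two within [min_num_tokens, max_num_tokens].
def last_positive_power_of_2(x: int) -> int:
    return 1 << (max(x, 1).bit_length() - 1)

def get_last_power_of_2_num_tokens_buckets(max_num_tokens: int, min_num_tokens: int = 1):
    buckets, bucket = [], last_positive_power_of_2(max_num_tokens)
    while bucket >= min_num_tokens:
        buckets.append(bucket)
        bucket //= 2
    return tuple(buckets)

print(get_last_power_of_2_num_tokens_buckets(1024, 8))  # (1024, 512, ..., 8)
print(get_last_power_of_2_num_tokens_buckets(8192, 1))  # (8192, 4096, ..., 1)
```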
```python
        enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
        tune_max_num_tokens: Maximum number of tokens for tuning. (default: 1024)
```
```diff
-        tune_max_num_tokens: Maximum number of tokens for tuning. (default: 1024)
+        tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192)
```
```diff
@@ -1367,6 +1935,7 @@ def trtllm_fp4_block_scale_moe(
     routing_method_type: int = 0,
     do_finalize: bool = True,
     enable_pdl: Optional[bool] = None,
+    tune_max_num_tokens: int = 1024,
```
```diff
-    tune_max_num_tokens: int = 1024,
+    tune_max_num_tokens: int = 8192,
```
```diff
@@ -1481,6 +2052,7 @@ def trtllm_fp4_block_scale_routed_moe(
     routing_method_type: int = 0,
     do_finalize: bool = True,
     enable_pdl: Optional[bool] = None,
+    tune_max_num_tokens: int = 1024,
```
```diff
-    tune_max_num_tokens: int = 1024,
+    tune_max_num_tokens: int = 8192,
```
```diff
@@ -1410,6 +1979,7 @@ def trtllm_fp4_block_scale_moe(
             - 3: Llama4 (Top1 -> Sigmoid)
             - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)
         do_finalize (bool): Whether to finalize the output (default: False)
+        tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 1024)
```
```diff
-        tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 1024)
+        tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192)
```
```diff
@@ -1526,6 +2098,7 @@ def trtllm_fp4_block_scale_routed_moe(
             - 3: Llama4 (Top1 -> Sigmoid)
             - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)
         do_finalize (bool): Whether to finalize the output (default: False)
+        tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 1024)
```
```diff
-        tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 1024)
+        tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192)
```
Thanks for the follow-up PR!
Do we have any updates on this PR? |
📌 Description
Followed #1475, but for FP8.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- All tests are passing (`unittest`, etc.).

Reviewer Notes