[Mlir-commits] [mlir] 5ebbc25 - [mlir][ArithToAMDGPU][NFC] Add PatternBenefit (#150091)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 23 07:56:08 PDT 2025
Author: Krzysztof Drewniak
Date: 2025-07-23T07:56:05-07:00
New Revision: 5ebbc258d4f410c45f247eb53bc722798b4d4f45
URL: https://github.com/llvm/llvm-project/commit/5ebbc258d4f410c45f247eb53bc722798b4d4f45
DIFF: https://github.com/llvm/llvm-project/commit/5ebbc258d4f410c45f247eb53bc722798b4d4f45.diff
LOG: [mlir][ArithToAMDGPU][NFC] Add PatternBenefit (#150091)
Since there may be cases where these patterns are run alongside the
generic patterns from ArithExpandOps, add a PatternBenefit argument so
that these architecture-specific patterns can be prioritized.
Added:
Modified:
mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
index 28fdc234e5ef0..f4a9518839224 100644
--- a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
+++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -10,6 +10,7 @@
#define MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
+#include "mlir/IR/PatternMatch.h"
#include <memory>
#include <string>
@@ -31,7 +32,8 @@ void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
bool convertFP8Arithmetic,
bool saturateFP8Truncf,
bool allowPackedF16Rtz,
- amdgpu::Chipset chipset);
+ amdgpu::Chipset chipset,
+ PatternBenefit benefit = 1);
} // namespace arith
} // namespace mlir
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 156c679c5039e..5407dcdedbdff 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -49,8 +49,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
Chipset chipset;
- ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
- : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+ ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset,
+ PatternBenefit benefit)
+ : OpRewritePattern::OpRewritePattern(ctx, benefit), chipset(chipset) {}
LogicalResult matchAndRewrite(arith::ExtFOp op,
PatternRewriter &rewriter) const override;
@@ -59,9 +60,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
bool saturateFP8 = false;
TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
- Chipset chipset)
- : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
- chipset(chipset) {}
+ Chipset chipset, PatternBenefit benefit)
+ : OpRewritePattern::OpRewritePattern(ctx, benefit),
+ saturateFP8(saturateFP8), chipset(chipset) {}
Chipset chipset;
LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -81,9 +82,6 @@ struct ScalingExtFRewritePattern final
: OpRewritePattern<arith::ScalingExtFOp> {
using OpRewritePattern::OpRewritePattern;
- ScalingExtFRewritePattern(MLIRContext *ctx)
- : OpRewritePattern::OpRewritePattern(ctx) {}
-
LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
PatternRewriter &rewriter) const override;
};
@@ -92,9 +90,6 @@ struct ScalingTruncFRewritePattern final
: OpRewritePattern<arith::ScalingTruncFOp> {
using OpRewritePattern::OpRewritePattern;
- ScalingTruncFRewritePattern(MLIRContext *ctx)
- : OpRewritePattern::OpRewritePattern(ctx) {}
-
LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
PatternRewriter &rewriter) const override;
};
@@ -667,19 +662,21 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
void mlir::arith::populateArithToAMDGPUConversionPatterns(
RewritePatternSet &patterns, bool convertFP8Arithmetic,
- bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
+ bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
+ PatternBenefit benefit) {
if (convertFP8Arithmetic) {
- patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
- patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
- saturateFP8Truncf, chipset);
+ patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
+ benefit);
+ patterns.add<TruncFToFloat8RewritePattern>(
+ patterns.getContext(), saturateFP8Truncf, chipset, benefit);
}
if (allowPackedF16Rtz)
- patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+ patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);
if (chipset >= kGfx950) {
- patterns.add<ScalingExtFRewritePattern>(patterns.getContext());
- patterns.add<ScalingTruncFRewritePattern>(patterns.getContext());
+ patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
+ patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
}
}
More information about the Mlir-commits
mailing list