[Mlir-commits] [mlir] 5ebbc25 - [mlir][ArithToAMDGPU][NFC] Add PatternBenefit (#150091)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 23 07:56:08 PDT 2025


Author: Krzysztof Drewniak
Date: 2025-07-23T07:56:05-07:00
New Revision: 5ebbc258d4f410c45f247eb53bc722798b4d4f45

URL: https://github.com/llvm/llvm-project/commit/5ebbc258d4f410c45f247eb53bc722798b4d4f45
DIFF: https://github.com/llvm/llvm-project/commit/5ebbc258d4f410c45f247eb53bc722798b4d4f45.diff

LOG: [mlir][ArithToAMDGPU][NFC] Add PatternBenefit (#150091)

Since there may be cases where these patterns are run alongside the
generic patterns from ArithExpandOps, add a PatternBenefit argument so
that these architecture-specific patterns can be given priority.

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
    mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
index 28fdc234e5ef0..f4a9518839224 100644
--- a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
+++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -10,6 +10,7 @@
 #define MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
 
 #include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
+#include "mlir/IR/PatternMatch.h"
 #include <memory>
 #include <string>
 
@@ -31,7 +32,8 @@ void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
                                              bool convertFP8Arithmetic,
                                              bool saturateFP8Truncf,
                                              bool allowPackedF16Rtz,
-                                             amdgpu::Chipset chipset);
+                                             amdgpu::Chipset chipset,
+                                             PatternBenefit benefit = 1);
 } // namespace arith
 } // namespace mlir
 

diff  --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 156c679c5039e..5407dcdedbdff 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -49,8 +49,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;
 
   Chipset chipset;
-  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
-      : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset,
+                             PatternBenefit benefit)
+      : OpRewritePattern::OpRewritePattern(ctx, benefit), chipset(chipset) {}
 
   LogicalResult matchAndRewrite(arith::ExtFOp op,
                                 PatternRewriter &rewriter) const override;
@@ -59,9 +60,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
 struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
   bool saturateFP8 = false;
   TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
-                               Chipset chipset)
-      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
-        chipset(chipset) {}
+                               Chipset chipset, PatternBenefit benefit)
+      : OpRewritePattern::OpRewritePattern(ctx, benefit),
+        saturateFP8(saturateFP8), chipset(chipset) {}
   Chipset chipset;
 
   LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -81,9 +82,6 @@ struct ScalingExtFRewritePattern final
     : OpRewritePattern<arith::ScalingExtFOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  ScalingExtFRewritePattern(MLIRContext *ctx)
-      : OpRewritePattern::OpRewritePattern(ctx) {}
-
   LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
                                 PatternRewriter &rewriter) const override;
 };
@@ -92,9 +90,6 @@ struct ScalingTruncFRewritePattern final
     : OpRewritePattern<arith::ScalingTruncFOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  ScalingTruncFRewritePattern(MLIRContext *ctx)
-      : OpRewritePattern::OpRewritePattern(ctx) {}
-
   LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
                                 PatternRewriter &rewriter) const override;
 };
@@ -667,19 +662,21 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
     RewritePatternSet &patterns, bool convertFP8Arithmetic,
-    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
+    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
+    PatternBenefit benefit) {
 
   if (convertFP8Arithmetic) {
-    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
-    patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
-                                               saturateFP8Truncf, chipset);
+    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
+                                             benefit);
+    patterns.add<TruncFToFloat8RewritePattern>(
+        patterns.getContext(), saturateFP8Truncf, chipset, benefit);
   }
   if (allowPackedF16Rtz)
-    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);
 
   if (chipset >= kGfx950) {
-    patterns.add<ScalingExtFRewritePattern>(patterns.getContext());
-    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext());
+    patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
+    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
   }
 }
 


        


More information about the Mlir-commits mailing list