[Mlir-commits] [mlir] [mlir][arith] Add canonicalize pattern for max/min with constants (PR #161057)
    Ziliang Zhang 
    llvmlistbot at llvm.org
       
    Sun Sep 28 01:12:24 PDT 2025
    
    
  
https://github.com/ziliangzl created https://github.com/llvm/llvm-project/pull/161057
Add canonicalization patterns for nested min/max operations with constants, e.g.:
  max(max(x, c0), c1) -> max(x, max(c0, c1))
  min(min(x, c0), c1) -> min(x, min(c0, c1))
Patterns are added for arith.maxsi, arith.maxui, arith.minsi, and arith.minui.
From c4ae877f89a7d6ad8144f146815d4458ff31aac2 Mon Sep 17 00:00:00 2001
From: Ziliang Zhang <zzl.coding at gmail.com>
Date: Sun, 28 Sep 2025 12:31:43 +0800
Subject: [PATCH] [mlir][arith] Add canonicalization patterns for max/min with
 constants

Add canonicalization patterns for nested min/max operations with constants,
e.g.:
  max(max(x, c0), c1) -> max(x, max(c0, c1))
  min(min(x, c0), c1) -> min(x, min(c0, c1))
Patterns are added for arith.maxsi, arith.maxui, arith.minsi, and arith.minui.
---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td |  4 +
 .../Dialect/Arith/IR/ArithCanonicalization.td | 68 +++++++++++++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 40 ++++++++
 mlir/test/Dialect/Arith/canonicalize.mlir     | 97 ++++++++++++++++++-
 4 files changed, 208 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 20c9097b51e6d..739d0439c4bba 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1008,6 +1008,7 @@ def Arith_MaxNumFOp : Arith_FloatBinaryOp<"maxnumf", [Commutative]> {
 def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> {
   let summary = "signed integer maximum operation";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1017,6 +1018,7 @@ def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> {
 def Arith_MaxUIOp : Arith_TotalIntBinaryOp<"maxui", [Commutative]> {
   let summary = "unsigned integer maximum operation";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1067,6 +1069,7 @@ def Arith_MinNumFOp : Arith_FloatBinaryOp<"minnumf", [Commutative]> {
 def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> {
   let summary = "signed integer minimum operation";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1076,6 +1079,7 @@ def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> {
 def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
   let summary = "unsigned integer minimum operation";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index de3efc9fe3506..ef57af86f0540 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -24,6 +24,18 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
 // Multiply two integer attributes and create a new one with the result.
 def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
 
+// Compute the signed min of two integer attributes and create a new one.
+def SMinIntAttrs : NativeCodeCall<"sminIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Compute the unsigned min of two integer attributes and create a new one.
+def UMinIntAttrs : NativeCodeCall<"uminIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Compute the signed max of two integer attributes and create a new one.
+def SMaxIntAttrs : NativeCodeCall<"smaxIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Compute the unsigned max of two integer attributes and create a new one.
+def UMaxIntAttrs : NativeCodeCall<"umaxIntegerAttrs($_builder, $0, $1, $2)">;
+
 // Merge overflow flags from 2 ops, selecting the most conservative combination.
 def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">;
 
@@ -202,6 +214,62 @@ def MulUIExtendedToMulI :
         [(Arith_MulIOp $x, $y, DefOverflow), (replaceWithValue $x)],
         [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
+//===----------------------------------------------------------------------===//
+// MaxSIOp
+//===----------------------------------------------------------------------===//
+
+// maxsi is commutative and will be canonicalized to have its constants appear
+// as the second operand.
+
+// maxsi(maxsi(x, c0), c1) -> maxsi(x, maxsi(c0, c1))
+def MaxSIMaxSIConstant :
+    Pat<(Arith_MaxSIOp:$res
+          (Arith_MaxSIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+          (ConstantLikeMatcher APIntAttr:$c1)),
+        (Arith_MaxSIOp $x, (Arith_ConstantOp (SMaxIntAttrs $res, $c0, $c1)))>;
+
+//===----------------------------------------------------------------------===//
+// MaxUIOp
+//===----------------------------------------------------------------------===//
+
+// maxui is commutative and will be canonicalized to have its constants appear
+// as the second operand.
+
+// maxui(maxui(x, c0), c1) -> maxui(x, maxui(c0, c1))
+def MaxUIMaxUIConstant :
+    Pat<(Arith_MaxUIOp:$res
+          (Arith_MaxUIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+          (ConstantLikeMatcher APIntAttr:$c1)),
+        (Arith_MaxUIOp $x, (Arith_ConstantOp (UMaxIntAttrs $res, $c0, $c1)))>;
+
+//===----------------------------------------------------------------------===//
+// MinSIOp
+//===----------------------------------------------------------------------===//
+
+// minsi is commutative and will be canonicalized to have its constants appear
+// as the second operand.
+
+// minsi(minsi(x, c0), c1) -> minsi(x, minsi(c0, c1))
+def MinSIMinSIConstant :
+    Pat<(Arith_MinSIOp:$res
+          (Arith_MinSIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+          (ConstantLikeMatcher APIntAttr:$c1)),
+        (Arith_MinSIOp $x, (Arith_ConstantOp (SMinIntAttrs $res, $c0, $c1)))>;
+
+//===----------------------------------------------------------------------===//
+// MinUIOp
+//===----------------------------------------------------------------------===//
+
+// minui is commutative and will be canonicalized to have its constants appear
+// as the second operand.
+
+// minui(minui(x, c0), c1) -> minui(x, minui(c0, c1))
+def MinUIMinUIConstant :
+    Pat<(Arith_MinUIOp:$res
+          (Arith_MinUIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+          (ConstantLikeMatcher APIntAttr:$c1)),
+        (Arith_MinUIOp $x, (Arith_ConstantOp (UMinIntAttrs $res, $c0, $c1)))>;
+
 //===----------------------------------------------------------------------===//
 // XOrIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7cfd6d3a98df8..82270ab64f7ec 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -63,6 +63,26 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
   return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
 }
 
+static IntegerAttr sminIntegerAttrs(PatternRewriter &builder, Value res,
+                                    Attribute lhs, Attribute rhs) {
+  return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::smin);
+}
+
+static IntegerAttr uminIntegerAttrs(PatternRewriter &builder, Value res,
+                                    Attribute lhs, Attribute rhs) {
+  return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::umin);
+}
+
+static IntegerAttr smaxIntegerAttrs(PatternRewriter &builder, Value res,
+                                    Attribute lhs, Attribute rhs) {
+  return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::smax);
+}
+
+static IntegerAttr umaxIntegerAttrs(PatternRewriter &builder, Value res,
+                                    Attribute lhs, Attribute rhs) {
+  return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::umax);
+}
+
 // Merge overflow flags from 2 ops, selecting the most conservative combination.
 static IntegerOverflowFlagsAttr
 mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
@@ -1162,6 +1182,11 @@ OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
                                          });
 }
 
+void arith::MaxSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  patterns.add<MaxSIMaxSIConstant>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MaxUIOp
 //===----------------------------------------------------------------------===//
@@ -1187,6 +1212,11 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
                                          });
 }
 
+void arith::MaxUIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  patterns.add<MaxUIMaxUIConstant>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MinimumFOp
 //===----------------------------------------------------------------------===//
@@ -1248,6 +1278,11 @@ OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
                                          });
 }
 
+void arith::MinSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  patterns.add<MinSIMinSIConstant>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MinUIOp
 //===----------------------------------------------------------------------===//
@@ -1273,6 +1308,11 @@ OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
                                          });
 }
 
+void arith::MinUIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  patterns.add<MinUIMinUIConstant>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index ca3de3a2d7703..1848decc2eb7c 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1952,6 +1952,30 @@ func.func @bitcastChain(%arg: i16) -> f16 {
 
 // -----
 
+// CHECK-LABEL: @maxsiMaxsiConst1
+//       CHECK:   %[[C42:.+]] = arith.constant 42 : i32
+//       CHECK:   %[[RES:.+]] = arith.maxsi %arg0, %[[C42]] : i32
+//       CHECK:   return %[[RES]]
+func.func @maxsiMaxsiConst1(%arg0: i32) -> i32 {
+  %c17 = arith.constant 17 : i32
+  %c42 = arith.constant 42 : i32
+  %max1 = arith.maxsi %arg0, %c17 : i32
+  %max2 = arith.maxsi %max1, %c42 : i32
+  return %max2 : i32
+}
+
+// CHECK-LABEL: @maxsiMaxsiConst2
+//       CHECK:   %[[C21:.+]] = arith.constant 21 : i32
+//       CHECK:   %[[RES:.+]] = arith.maxsi %arg0, %[[C21]] : i32
+//       CHECK:   return %[[RES]]
+func.func @maxsiMaxsiConst2(%arg0: i32) -> i32 {
+  %c7 = arith.constant 7 : i32
+  %c21 = arith.constant 21 : i32
+  %max1 = arith.maxsi %arg0, %c7 : i32
+  %max2 = arith.maxsi %c21, %max1 : i32
+  return %max2 : i32
+}
+
 // CHECK-LABEL: test_maxsi
 // CHECK-DAG: %[[C0:.+]] = arith.constant 42
 // CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant 127
@@ -1986,6 +2010,30 @@ func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
 
 // -----
 
+// CHECK-LABEL: @maxuiMaxuiConst1
+//       CHECK:   %[[C42:.+]] = arith.constant 42 : index
+//       CHECK:   %[[RES:.+]] = arith.maxui %arg0, %[[C42]] : index
+//       CHECK:   return %[[RES]]
+func.func @maxuiMaxuiConst1(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %max1 = arith.maxui %arg0, %c17 : index
+  %max2 = arith.maxui %max1, %c42 : index
+  return %max2 : index
+}
+
+// CHECK-LABEL: @maxuiMaxuiConst2
+//       CHECK:   %[[C21:.+]] = arith.constant 21 : index
+//       CHECK:   %[[RES:.+]] = arith.maxui %arg0, %[[C21]] : index
+//       CHECK:   return %[[RES]]
+func.func @maxuiMaxuiConst2(%arg0: index) -> index {
+  %c7 = arith.constant 7 : index
+  %c21 = arith.constant 21 : index
+  %max1 = arith.maxui %arg0, %c7 : index
+  %max2 = arith.maxui %c21, %max1 : index
+  return %max2 : index
+}
+
 // CHECK-LABEL: test_maxui
 // CHECK-DAG: %[[C0:.+]] = arith.constant 42
 // CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant -1
@@ -2020,6 +2068,30 @@ func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) {
 
 // -----
 
+// CHECK-LABEL: @minsiMinsiConst1
+//       CHECK:   %[[C17:.+]] = arith.constant 17 : i32
+//       CHECK:   %[[RES:.+]] = arith.minsi %arg0, %[[C17]] : i32
+//       CHECK:   return %[[RES]]
+func.func @minsiMinsiConst1(%arg0: i32) -> i32 {
+  %c17 = arith.constant 17 : i32
+  %c42 = arith.constant 42 : i32
+  %min1 = arith.minsi %arg0, %c17 : i32
+  %min2 = arith.minsi %min1, %c42 : i32
+  return %min2 : i32
+}
+
+// CHECK-LABEL: @minsiMinsiConst2
+//       CHECK:   %[[C7:.+]] = arith.constant 7 : i32
+//       CHECK:   %[[RES:.+]] = arith.minsi %arg0, %[[C7]] : i32
+//       CHECK:   return %[[RES]]
+func.func @minsiMinsiConst2(%arg0: i32) -> i32 {
+  %c7 = arith.constant 7 : i32
+  %c21 = arith.constant 21 : i32
+  %min1 = arith.minsi %arg0, %c7 : i32
+  %min2 = arith.minsi %c21, %min1 : i32
+  return %min2 : i32
+}
+
 // CHECK-LABEL: test_minsi
 // CHECK-DAG: %[[C0:.+]] = arith.constant 42
 // CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant -128
@@ -2054,6 +2126,30 @@ func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
 
 // -----
 
+// CHECK-LABEL: @minuiMinuiConst1
+//       CHECK:   %[[C17:.+]] = arith.constant 17 : index
+//       CHECK:   %[[RES:.+]] = arith.minui %arg0, %[[C17]] : index
+//       CHECK:   return %[[RES]]
+func.func @minuiMinuiConst1(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %min1 = arith.minui %arg0, %c17 : index
+  %min2 = arith.minui %min1, %c42 : index
+  return %min2 : index
+}
+
+// CHECK-LABEL: @minuiMinuiConst2
+//       CHECK:   %[[C7:.+]] = arith.constant 7 : index
+//       CHECK:   %[[RES:.+]] = arith.minui %arg0, %[[C7]] : index
+//       CHECK:   return %[[RES]]
+func.func @minuiMinuiConst2(%arg0: index) -> index {
+  %c7 = arith.constant 7 : index
+  %c21 = arith.constant 21 : index
+  %min1 = arith.minui %arg0, %c7 : index
+  %min2 = arith.minui %c21, %min1 : index
+  return %min2 : index
+}
+
 // CHECK-LABEL: test_minui
 // CHECK-DAG: %[[C0:.+]] = arith.constant 42
 // CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant 0
@@ -3377,4 +3473,3 @@ func.func @unreachable() {
   %add = arith.addi %add, %c1_i64 : i64
   cf.br ^unreachable
 }
-
    
    
More information about the Mlir-commits
mailing list