[Mlir-commits] [mlir] [mlir][arith] Add canonicalize pattern for max/min with constants (PR #161057)
Ziliang Zhang
llvmlistbot at llvm.org
Sun Sep 28 01:12:24 PDT 2025
https://github.com/ziliangzl created https://github.com/llvm/llvm-project/pull/161057
Add canonicalization patterns for nested min/max operations with constants, e.g.:
max(max(x, c0), c1) -> max(x, max(c0, c1))
min(min(x, c0), c1) -> min(x, min(c0, c1))
Patterns are added for arith.maxsi, arith.maxui, arith.minsi, and arith.minui.
>From c4ae877f89a7d6ad8144f146815d4458ff31aac2 Mon Sep 17 00:00:00 2001
From: Ziliang Zhang <zzl.coding at gmail.com>
Date: Sun, 28 Sep 2025 12:31:43 +0800
Subject: [PATCH] [mlir][arith] Add canonicalize pattern for max/min with constants

Add canonicalization patterns for nested min/max operations with
constants, e.g.:
max(max(x, c0), c1) -> max(x, max(c0, c1))
min(min(x, c0), c1) -> min(x, min(c0, c1))
Patterns are added for arith.maxsi, arith.maxui, arith.minsi, and arith.minui.
---
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 4 +
.../Dialect/Arith/IR/ArithCanonicalization.td | 68 +++++++++++++
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 40 ++++++++
mlir/test/Dialect/Arith/canonicalize.mlir | 97 ++++++++++++++++++-
4 files changed, 208 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 20c9097b51e6d..739d0439c4bba 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1008,6 +1008,7 @@ def Arith_MaxNumFOp : Arith_FloatBinaryOp<"maxnumf", [Commutative]> {
def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> {
let summary = "signed integer maximum operation";
let hasFolder = 1;
+ // Patterns come from MaxSIOp::getCanonicalizationPatterns in ArithOps.cpp
+ // (MaxSIMaxSIConstant: merges the constants of nested maxsi ops).
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@@ -1017,6 +1018,7 @@ def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> {
def Arith_MaxUIOp : Arith_TotalIntBinaryOp<"maxui", [Commutative]> {
let summary = "unsigned integer maximum operation";
let hasFolder = 1;
+ // Patterns come from MaxUIOp::getCanonicalizationPatterns in ArithOps.cpp
+ // (MaxUIMaxUIConstant: merges the constants of nested maxui ops).
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@@ -1067,6 +1069,7 @@ def Arith_MinNumFOp : Arith_FloatBinaryOp<"minnumf", [Commutative]> {
def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> {
let summary = "signed integer minimum operation";
let hasFolder = 1;
+ // Patterns come from MinSIOp::getCanonicalizationPatterns in ArithOps.cpp
+ // (MinSIMinSIConstant: merges the constants of nested minsi ops).
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@@ -1076,6 +1079,7 @@ def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> {
def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
let summary = "unsigned integer minimum operation";
let hasFolder = 1;
+ // Patterns come from MinUIOp::getCanonicalizationPatterns in ArithOps.cpp
+ // (MinUIMinUIConstant: merges the constants of nested minui ops).
+ let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index de3efc9fe3506..ef57af86f0540 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -24,6 +24,18 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
// Multiply two integer attributes and create a new one with the result.
def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
+// Take the signed maximum of two integer attributes and create a new one with
+// the result.
+def SMaxIntAttrs : NativeCodeCall<"smaxIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Take the unsigned maximum of two integer attributes and create a new one
+// with the result.
+def UMaxIntAttrs : NativeCodeCall<"umaxIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Take the signed minimum of two integer attributes and create a new one with
+// the result.
+def SMinIntAttrs : NativeCodeCall<"sminIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Take the unsigned minimum of two integer attributes and create a new one
+// with the result.
+def UMinIntAttrs : NativeCodeCall<"uminIntegerAttrs($_builder, $0, $1, $2)">;
+
// Merge overflow flags from 2 ops, selecting the most conservative combination.
def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">;
@@ -202,6 +214,62 @@ def MulUIExtendedToMulI :
[(Arith_MulIOp $x, $y, DefOverflow), (replaceWithValue $x)],
[(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
+//===----------------------------------------------------------------------===//
+// MaxSIOp
+//===----------------------------------------------------------------------===//
+
+// maxsi is commutative and will be canonicalized to have its constants appear
+// as the second operand, so only the `maxsi(x, c)` operand order is matched.
+
+// maxsi(maxsi(x, c0), c1) -> maxsi(x, maxsi(c0, c1))
+// The new constant is computed with SMaxIntAttrs (signed comparison). Only
+// the outer op is replaced; the inner maxsi is left as-is and becomes dead if
+// the outer op was its only user.
+def MaxSIMaxSIConstant :
+ Pat<(Arith_MaxSIOp:$res
+ (Arith_MaxSIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+ (ConstantLikeMatcher APIntAttr:$c1)),
+ (Arith_MaxSIOp $x, (Arith_ConstantOp (SMaxIntAttrs $res, $c0, $c1)))>;
+
+//===----------------------------------------------------------------------===//
+// MaxUIOp
+//===----------------------------------------------------------------------===//
+
+// maxui is commutative and will be canonicalized to have its constants appear
+// as the second operand, so only the `maxui(x, c)` operand order is matched.
+
+// maxui(maxui(x, c0), c1) -> maxui(x, maxui(c0, c1))
+// The new constant is computed with UMaxIntAttrs (unsigned comparison). Only
+// the outer op is replaced; the inner maxui is left as-is and becomes dead if
+// the outer op was its only user.
+def MaxUIMaxUIConstant :
+ Pat<(Arith_MaxUIOp:$res
+ (Arith_MaxUIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+ (ConstantLikeMatcher APIntAttr:$c1)),
+ (Arith_MaxUIOp $x, (Arith_ConstantOp (UMaxIntAttrs $res, $c0, $c1)))>;
+
+//===----------------------------------------------------------------------===//
+// MinSIOp
+//===----------------------------------------------------------------------===//
+
+// minsi is commutative and will be canonicalized to have its constants appear
+// as the second operand, so only the `minsi(x, c)` operand order is matched.
+
+// minsi(minsi(x, c0), c1) -> minsi(x, minsi(c0, c1))
+// The new constant is computed with SMinIntAttrs (signed comparison). Only
+// the outer op is replaced; the inner minsi is left as-is and becomes dead if
+// the outer op was its only user.
+def MinSIMinSIConstant :
+ Pat<(Arith_MinSIOp:$res
+ (Arith_MinSIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+ (ConstantLikeMatcher APIntAttr:$c1)),
+ (Arith_MinSIOp $x, (Arith_ConstantOp (SMinIntAttrs $res, $c0, $c1)))>;
+
+//===----------------------------------------------------------------------===//
+// MinUIOp
+//===----------------------------------------------------------------------===//
+
+// minui is commutative and will be canonicalized to have its constants appear
+// as the second operand, so only the `minui(x, c)` operand order is matched.
+
+// minui(minui(x, c0), c1) -> minui(x, minui(c0, c1))
+// The new constant is computed with UMinIntAttrs (unsigned comparison). Only
+// the outer op is replaced; the inner minui is left as-is and becomes dead if
+// the outer op was its only user.
+def MinUIMinUIConstant :
+ Pat<(Arith_MinUIOp:$res
+ (Arith_MinUIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+ (ConstantLikeMatcher APIntAttr:$c1)),
+ (Arith_MinUIOp $x, (Arith_ConstantOp (UMinIntAttrs $res, $c0, $c1)))>;
+
//===----------------------------------------------------------------------===//
// XOrIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7cfd6d3a98df8..82270ab64f7ec 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -63,6 +63,26 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
}
+/// Signed max of two integer attributes (see applyToIntegerAttrs).
+static IntegerAttr smaxIntegerAttrs(PatternRewriter &builder, Value res,
+                                    Attribute lhs, Attribute rhs) {
+  return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::smax);
+}
+
+/// Unsigned max of two integer attributes (see applyToIntegerAttrs).
+static IntegerAttr umaxIntegerAttrs(PatternRewriter &builder, Value res,
+                                    Attribute lhs, Attribute rhs) {
+  return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::umax);
+}
+
+/// Signed min of two integer attributes (see applyToIntegerAttrs).
+static IntegerAttr sminIntegerAttrs(PatternRewriter &builder, Value res,
+                                    Attribute lhs, Attribute rhs) {
+  return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::smin);
+}
+
+/// Unsigned min of two integer attributes (see applyToIntegerAttrs).
+static IntegerAttr uminIntegerAttrs(PatternRewriter &builder, Value res,
+                                    Attribute lhs, Attribute rhs) {
+  return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::umin);
+}
+}
+
// Merge overflow flags from 2 ops, selecting the most conservative combination.
static IntegerOverflowFlagsAttr
mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
@@ -1162,6 +1182,11 @@ OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
});
}
+/// Registers MaxSIMaxSIConstant, which rewrites
+/// maxsi(maxsi(x, c0), c1) into maxsi(x, smax(c0, c1)).
+void arith::MaxSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<MaxSIMaxSIConstant>(context);
+}
+
//===----------------------------------------------------------------------===//
// MaxUIOp
//===----------------------------------------------------------------------===//
@@ -1187,6 +1212,11 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
});
}
+/// Registers MaxUIMaxUIConstant, which rewrites
+/// maxui(maxui(x, c0), c1) into maxui(x, umax(c0, c1)).
+void arith::MaxUIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<MaxUIMaxUIConstant>(context);
+}
+
//===----------------------------------------------------------------------===//
// MinimumFOp
//===----------------------------------------------------------------------===//
@@ -1248,6 +1278,11 @@ OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
});
}
+/// Registers MinSIMinSIConstant, which rewrites
+/// minsi(minsi(x, c0), c1) into minsi(x, smin(c0, c1)).
+void arith::MinSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<MinSIMinSIConstant>(context);
+}
+
//===----------------------------------------------------------------------===//
// MinUIOp
//===----------------------------------------------------------------------===//
@@ -1273,6 +1308,11 @@ OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
});
}
+/// Registers MinUIMinUIConstant, which rewrites
+/// minui(minui(x, c0), c1) into minui(x, umin(c0, c1)).
+void arith::MinUIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<MinUIMinUIConstant>(context);
+}
+
//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index ca3de3a2d7703..1848decc2eb7c 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1952,6 +1952,30 @@ func.func @bitcastChain(%arg: i16) -> f16 {
// -----
+// CHECK-LABEL: @maxsiMaxsiConst1
+// CHECK: %[[C42:.+]] = arith.constant 42 : i32
+// CHECK: %[[RES:.+]] = arith.maxsi %arg0, %[[C42]] : i32
+// CHECK: return %[[RES]]
+func.func @maxsiMaxsiConst1(%arg0: i32) -> i32 {
+  %c17 = arith.constant 17 : i32
+  %c42 = arith.constant 42 : i32
+  %max1 = arith.maxsi %arg0, %c17 : i32
+  %max2 = arith.maxsi %max1, %c42 : i32
+  return %max2 : i32
+}
+
+// A negative constant makes signed and unsigned max disagree: smax(-7, 21) is
+// 21, whereas an unsigned combine would wrongly produce -7.
+// CHECK-LABEL: @maxsiMaxsiConst2
+// CHECK: %[[C21:.+]] = arith.constant 21 : i32
+// CHECK: %[[RES:.+]] = arith.maxsi %arg0, %[[C21]] : i32
+// CHECK: return %[[RES]]
+func.func @maxsiMaxsiConst2(%arg0: i32) -> i32 {
+  %cn7 = arith.constant -7 : i32
+  %c21 = arith.constant 21 : i32
+  %max1 = arith.maxsi %arg0, %cn7 : i32
+  %max2 = arith.maxsi %c21, %max1 : i32
+  return %max2 : i32
+}
+
// CHECK-LABEL: test_maxsi
// CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant 127
@@ -1986,6 +2010,30 @@ func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
// -----
+// CHECK-LABEL: @maxuiMaxuiConst1
+// CHECK: %[[C42:.+]] = arith.constant 42 : index
+// CHECK: %[[RES:.+]] = arith.maxui %arg0, %[[C42]] : index
+// CHECK: return %[[RES]]
+func.func @maxuiMaxuiConst1(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %max1 = arith.maxui %arg0, %c17 : index
+  %max2 = arith.maxui %max1, %c42 : index
+  return %max2 : index
+}
+
+// Compared as unsigned values -7 is the larger constant: umax(-7, 21) is -7,
+// whereas a signed combine would wrongly produce 21.
+// CHECK-LABEL: @maxuiMaxuiConst2
+// CHECK: %[[CN7:.+]] = arith.constant -7 : index
+// CHECK: %[[RES:.+]] = arith.maxui %arg0, %[[CN7]] : index
+// CHECK: return %[[RES]]
+func.func @maxuiMaxuiConst2(%arg0: index) -> index {
+  %cn7 = arith.constant -7 : index
+  %c21 = arith.constant 21 : index
+  %max1 = arith.maxui %arg0, %cn7 : index
+  %max2 = arith.maxui %c21, %max1 : index
+  return %max2 : index
+}
+
// CHECK-LABEL: test_maxui
// CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant -1
@@ -2020,6 +2068,30 @@ func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) {
// -----
+// CHECK-LABEL: @minsiMinsiConst1
+// CHECK: %[[C17:.+]] = arith.constant 17 : i32
+// CHECK: %[[RES:.+]] = arith.minsi %arg0, %[[C17]] : i32
+// CHECK: return %[[RES]]
+func.func @minsiMinsiConst1(%arg0: i32) -> i32 {
+  %c17 = arith.constant 17 : i32
+  %c42 = arith.constant 42 : i32
+  %min1 = arith.minsi %arg0, %c17 : i32
+  %min2 = arith.minsi %min1, %c42 : i32
+  return %min2 : i32
+}
+
+// A negative constant makes signed and unsigned min disagree: smin(-7, 21) is
+// -7, whereas an unsigned combine would wrongly produce 21.
+// CHECK-LABEL: @minsiMinsiConst2
+// CHECK: %[[CN7:.+]] = arith.constant -7 : i32
+// CHECK: %[[RES:.+]] = arith.minsi %arg0, %[[CN7]] : i32
+// CHECK: return %[[RES]]
+func.func @minsiMinsiConst2(%arg0: i32) -> i32 {
+  %cn7 = arith.constant -7 : i32
+  %c21 = arith.constant 21 : i32
+  %min1 = arith.minsi %arg0, %cn7 : i32
+  %min2 = arith.minsi %c21, %min1 : i32
+  return %min2 : i32
+}
+
// CHECK-LABEL: test_minsi
// CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant -128
@@ -2054,6 +2126,30 @@ func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
// -----
+// CHECK-LABEL: @minuiMinuiConst1
+// CHECK: %[[C17:.+]] = arith.constant 17 : index
+// CHECK: %[[RES:.+]] = arith.minui %arg0, %[[C17]] : index
+// CHECK: return %[[RES]]
+func.func @minuiMinuiConst1(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %min1 = arith.minui %arg0, %c17 : index
+  %min2 = arith.minui %min1, %c42 : index
+  return %min2 : index
+}
+
+// Compared as unsigned values -7 is the larger constant: umin(-7, 21) is 21,
+// whereas a signed combine would wrongly produce -7.
+// CHECK-LABEL: @minuiMinuiConst2
+// CHECK: %[[C21:.+]] = arith.constant 21 : index
+// CHECK: %[[RES:.+]] = arith.minui %arg0, %[[C21]] : index
+// CHECK: return %[[RES]]
+func.func @minuiMinuiConst2(%arg0: index) -> index {
+  %cn7 = arith.constant -7 : index
+  %c21 = arith.constant 21 : index
+  %min1 = arith.minui %arg0, %cn7 : index
+  %min2 = arith.minui %c21, %min1 : index
+  return %min2 : index
+}
+
// CHECK-LABEL: test_minui
// CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant 0
@@ -3377,4 +3473,3 @@ func.func @unreachable() {
%add = arith.addi %add, %c1_i64 : i64
cf.br ^unreachable
}
-
More information about the Mlir-commits
mailing list