[Mlir-commits] [mlir] [MLIR] Add trivial simplifications for affine mod, div, ceil (PR #182234)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 19 00:06:20 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core
Author: Uday Bondhugula (bondhugula)
<details>
<summary>Changes</summary>
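Add missing trivial folding rules for floordiv, ceildiv, and mod affine expressions when the LHS and RHS are the same expression (x floordiv x -> 1, x ceildiv x -> 1, x mod x -> 0, in the cases where the operation is defined).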
---
Full diff: https://github.com/llvm/llvm-project/pull/182234.diff
2 Files Affected:
- (modified) mlir/lib/IR/AffineExpr.cpp (+18-6)
- (modified) mlir/test/IR/affine-map.mlir (+18)
``````````diff
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 1cb3163c4ba94..da91066815ca4 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -912,13 +912,17 @@ AffineExpr AffineExpr::operator-(AffineExpr other) const {
}
static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
- auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+ // For the defined cases, simplify x floordiv x is 1.
+ if (lhs == rhs && (!rhsConst || rhsConst.getValue() >= 1))
+ return getAffineConstantExpr(1, lhs.getContext());
+
+ // All other simplifications further below are for the RHS constant case.
if (!rhsConst || rhsConst.getValue() == 0)
return nullptr;
- if (lhsConst) {
+ if (auto lhsConst = dyn_cast<AffineConstantExpr>(lhs)) {
if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue()))
return nullptr;
return getAffineConstantExpr(
@@ -971,13 +975,17 @@ AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
}
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
- auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+ // For the defined cases, simplify x ceildiv x is 1.
+ if (lhs == rhs && (!rhsConst || rhsConst.getValue() >= 1))
+ return getAffineConstantExpr(1, lhs.getContext());
+
+ // All other simplifications further below are for the RHS constant case.
if (!rhsConst || rhsConst.getValue() == 0)
return nullptr;
- if (lhsConst) {
+ if (auto lhsConst = dyn_cast<AffineConstantExpr>(lhs)) {
if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue()))
return nullptr;
return getAffineConstantExpr(
@@ -1018,14 +1026,18 @@ AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
}
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
- auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+ // For the defined cases, simplify x % x to 0.
+ if (lhs == rhs && (!rhsConst || rhsConst.getValue() >= 1))
+ return getAffineConstantExpr(0, lhs.getContext());
+
// mod w.r.t zero or negative numbers is undefined and preserved as is.
+ // All other simplifications further below are for the RHS constant case.
if (!rhsConst || rhsConst.getValue() < 1)
return nullptr;
- if (lhsConst) {
+ if (auto lhsConst = dyn_cast<AffineConstantExpr>(lhs)) {
// mod never overflows.
return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
lhs.getContext());
diff --git a/mlir/test/IR/affine-map.mlir b/mlir/test/IR/affine-map.mlir
index 129e83535b4a4..86bdaafd79f32 100644
--- a/mlir/test/IR/affine-map.mlir
+++ b/mlir/test/IR/affine-map.mlir
@@ -219,6 +219,15 @@
// CHECK: #map{{[0-9]*}} = affine_map<(d0) -> (d0 + d0 floordiv 4 - 4)>
#map66 = affine_map<(d0) -> (d0 + ((d0 floordiv 4) - 4))>
+// CHECK: #map{{[0-9]*}} = affine_map<()[s0, s1] -> (1)>
+#map67 = affine_map<()[s0, s1] -> ((s0 + s1) floordiv (s0 + s1))>
+
+// CHECK: #map{{[0-9]*}} = affine_map<()[s0, s1] -> (2)>
+#map68 = affine_map<()[s0, s1] -> ((s0 + s1) ceildiv (s0 + s1) * 2)>
+
+// CHECK: #map{{[0-9]*}} = affine_map<()[s0, s1] -> (0)>
+#map69 = affine_map<()[s0, s1] -> ((s0 + s1) mod (s0 + s1))>
+
// Single identity maps are removed.
// CHECK: @f0(memref<2x4xi8, 1>)
func.func private @f0(memref<2x4xi8, #map0, 1>)
@@ -430,3 +439,12 @@ func.func private @f56(memref<1x1xi8, #map56>)
// CHECK: "f66"() {map = #map{{[0-9]*}}} : () -> ()
"f66"() {map = #map66} : () -> ()
+
+// CHECK: "f67"() {map = #map{{[0-9]*}}} : () -> ()
+"f67"() {map = #map67} : () -> ()
+
+// CHECK: "f68"() {map = #map{{[0-9]*}}} : () -> ()
+"f68"() {map = #map68} : () -> ()
+
+// CHECK: "f69"() {map = #map{{[0-9]*}}} : () -> ()
+"f69"() {map = #map69} : () -> ()
``````````
</details>
https://github.com/llvm/llvm-project/pull/182234
More information about the Mlir-commits mailing list