[Mlir-commits] [mlir] [MLIR] Add trivial simplifications for affine mod, div, ceil (PR #182234)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 19 00:06:20 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Uday Bondhugula (bondhugula)

<details>
<summary>Changes</summary>

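Add missing trivial folding rules for floordiv, ceildiv, and mod affine expressions whose LHS and RHS are the same expression: `x floordiv x` and `x ceildiv x` fold to 1, and `x mod x` folds to 0. The fold is skipped when the RHS is a constant less than 1.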

---
Full diff: https://github.com/llvm/llvm-project/pull/182234.diff


2 Files Affected:

- (modified) mlir/lib/IR/AffineExpr.cpp (+18-6) 
- (modified) mlir/test/IR/affine-map.mlir (+18) 


``````````diff
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 1cb3163c4ba94..da91066815ca4 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -912,13 +912,17 @@ AffineExpr AffineExpr::operator-(AffineExpr other) const {
 }
 
 static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
-  auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
   auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
 
+  // For the defined cases, simplify x floordiv x to 1.
+  if (lhs == rhs && (!rhsConst || rhsConst.getValue() >= 1))
+    return getAffineConstantExpr(1, lhs.getContext());
+
+  // All other simplifications further below are for the RHS constant case.
   if (!rhsConst || rhsConst.getValue() == 0)
     return nullptr;
 
-  if (lhsConst) {
+  if (auto lhsConst = dyn_cast<AffineConstantExpr>(lhs)) {
     if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue()))
       return nullptr;
     return getAffineConstantExpr(
@@ -971,13 +975,17 @@ AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
 }
 
 static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
-  auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
   auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
 
+  // For the defined cases, simplify x ceildiv x to 1.
+  if (lhs == rhs && (!rhsConst || rhsConst.getValue() >= 1))
+    return getAffineConstantExpr(1, lhs.getContext());
+
+  // All other simplifications further below are for the RHS constant case.
   if (!rhsConst || rhsConst.getValue() == 0)
     return nullptr;
 
-  if (lhsConst) {
+  if (auto lhsConst = dyn_cast<AffineConstantExpr>(lhs)) {
     if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue()))
       return nullptr;
     return getAffineConstantExpr(
@@ -1018,14 +1026,18 @@ AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
 }
 
 static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
-  auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
   auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
 
+  // For the defined cases, simplify x % x to 0.
+  if (lhs == rhs && (!rhsConst || rhsConst.getValue() >= 1))
+    return getAffineConstantExpr(0, lhs.getContext());
+
   // mod w.r.t zero or negative numbers is undefined and preserved as is.
+  // All other simplifications further below are for the RHS constant case.
   if (!rhsConst || rhsConst.getValue() < 1)
     return nullptr;
 
-  if (lhsConst) {
+  if (auto lhsConst = dyn_cast<AffineConstantExpr>(lhs)) {
     // mod never overflows.
     return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
                                  lhs.getContext());
diff --git a/mlir/test/IR/affine-map.mlir b/mlir/test/IR/affine-map.mlir
index 129e83535b4a4..86bdaafd79f32 100644
--- a/mlir/test/IR/affine-map.mlir
+++ b/mlir/test/IR/affine-map.mlir
@@ -219,6 +219,15 @@
 // CHECK: #map{{[0-9]*}} = affine_map<(d0) -> (d0 + d0 floordiv 4 - 4)>
 #map66 = affine_map<(d0) -> (d0 + ((d0 floordiv 4) - 4))>
 
+// CHECK: #map{{[0-9]*}} = affine_map<()[s0, s1] -> (1)>
+#map67 = affine_map<()[s0, s1] -> ((s0 + s1) floordiv (s0 + s1))>
+
+// CHECK: #map{{[0-9]*}} = affine_map<()[s0, s1] -> (2)>
+#map68 = affine_map<()[s0, s1] -> ((s0 + s1) ceildiv (s0 + s1) * 2)>
+
+// CHECK: #map{{[0-9]*}} = affine_map<()[s0, s1] -> (0)>
+#map69 = affine_map<()[s0, s1] -> ((s0 + s1) mod (s0 + s1))>
+
 // Single identity maps are removed.
 // CHECK: @f0(memref<2x4xi8, 1>)
 func.func private @f0(memref<2x4xi8, #map0, 1>)
@@ -430,3 +439,12 @@ func.func private @f56(memref<1x1xi8, #map56>)
 
 // CHECK: "f66"() {map = #map{{[0-9]*}}} : () -> ()
 "f66"() {map = #map66} : () -> ()
+
+// CHECK: "f67"() {map = #map{{[0-9]*}}} : () -> ()
+"f67"() {map = #map67} : () -> ()
+
+// CHECK: "f68"() {map = #map{{[0-9]*}}} : () -> ()
+"f68"() {map = #map68} : () -> ()
+
+// CHECK: "f69"() {map = #map{{[0-9]*}}} : () -> ()
+"f69"() {map = #map69} : () -> ()

``````````

</details>


https://github.com/llvm/llvm-project/pull/182234


More information about the Mlir-commits mailing list