[Mlir-commits] [mlir] Lower affine modulo by powers of two using bitwise AND (PR #146311)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jun 29 22:53:42 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-affine

Author: Yuxi Sun (sherylll)

<details>
<summary>Changes</summary>

This patch adds a special case to `AffineApplyExpander::visitModExpr` in the affine-to-standard lowering: when the right-hand side of an affine `mod` is a constant power of two `n`, the expression is lowered to a single bitwise AND, `x & (n - 1)`, instead of the general remainder, compare, add, and select sequence. This reduces the instruction count for common cases such as `x mod 2`.
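The arithmetic behind the rewrite can be checked outside MLIR. The following is a minimal sketch, not the patch itself: the helper names (`isPowerOfTwo`, `modGeneral`, `modPow2`, `agreesOnRange`) are hypothetical, and it models `index` values as two's-complement `int64_t`, which is an assumption. The general path here tests the sign of the remainder, so exact multiples such as `-8 mod 8` stay at 0.

```cpp
#include <cstdint>

// Illustrative helpers only; these names are hypothetical and are not MLIR APIs.

// True when n is a positive power of two (the same bit test the patch uses).
inline bool isPowerOfTwo(int64_t n) { return n > 0 && (n & (n - 1)) == 0; }

// General path: signed remainder, then move a negative remainder into [0, n).
inline int64_t modGeneral(int64_t x, int64_t n) {
  int64_t remainder = x % n; // C++ '%' truncates toward zero, like srem
  return remainder < 0 ? remainder + n : remainder;
}

// Fast path: only meaningful when isPowerOfTwo(n) holds.
// Assumes a two's-complement representation for negative x.
inline int64_t modPow2(int64_t x, int64_t n) { return x & (n - 1); }

// Returns true if both paths agree for every x in [lo, hi].
inline bool agreesOnRange(int64_t lo, int64_t hi, int64_t n) {
  for (int64_t x = lo; x <= hi; ++x)
    if (modGeneral(x, n) != modPow2(x, n))
      return false;
  return true;
}
```

For example, `modGeneral(-3, 8)` and `modPow2(-3, 8)` both evaluate to 5, which suggests that under the two's-complement assumption the mask also matches the non-negative remainder for negative inputs.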


---
Full diff: https://github.com/llvm/llvm-project/pull/146311.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+12) 
- (modified) mlir/test/Conversion/AffineToStandard/lower-affine.mlir (+9) 


``````````diff
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 66b3f2a4f93a5..de9c7874767e4 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -80,12 +80,24 @@ class AffineApplyExpander
   ///         let remainder = srem a, b;
   ///             negative = a < 0 in
   ///         select negative, remainder + b, remainder.
+  ///
+  /// Special case for power of 2: use bitwise AND (x & (n-1)) for non-negative x.
   Value visitModExpr(AffineBinaryOpExpr expr) {
     if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
       if (rhsConst.getValue() <= 0) {
         emitError(loc, "modulo by non-positive value is not supported");
         return nullptr;
       }
+      
+      // Special case: x mod n where n is a power of 2 can be optimized to x & (n-1)
+      int64_t rhsValue = rhsConst.getValue();
+      if (rhsValue > 0 && (rhsValue & (rhsValue - 1)) == 0) {
+        auto lhs = visit(expr.getLHS());
+        assert(lhs && "unexpected affine expr lowering failure");
+        
+        Value maskCst = builder.create<arith::ConstantIndexOp>(loc, rhsValue - 1);
+        return builder.create<arith::AndIOp>(loc, lhs, maskCst);
+      }
     }
 
     auto lhs = visit(expr.getLHS());
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 550ea71882e14..07f7c64fe6ea5 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -927,3 +927,12 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me
 // CHECK:      scf.reduce.return %[[RES]] : i64
 // CHECK:    }
 // CHECK:  }
+
+#map_mod_8 = affine_map<(i) -> (i mod 8)>
+// CHECK-LABEL: func @affine_apply_mod_8
+func.func @affine_apply_mod_8(%arg0 : index) -> (index) {
+  // CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
+  // CHECK-NEXT: %[[v0:.*]] = arith.andi %arg0, %[[c7]] : index
+  %0 = affine.apply #map_mod_8 (%arg0)
+  return %0 : index
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/146311

