[Mlir-commits] [mlir] [mlir][SCF] Do not peel already peeled loops (PR #71900)

Matthias Springer llvmlistbot at llvm.org
Wed Nov 15 18:42:50 PST 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/71900

>From 0ec5b87a0bf4c1d020bcca18aca741b52450a59a Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 16 Nov 2023 11:33:30 +0900
Subject: [PATCH] [mlir][SCF] Do not peel already peeled loops

Loop peeling is not beneficial if the step size already divides "ub - lb". There are currently some simple checks to prevent peeling in such cases when lb, ub, step are constants. This commit adds support for IR that is the result of loop peeling in the general case; i.e., lb, ub, step do not necessarily have to be constants.

This change adds a new affine_map simplification rule for semi-affine maps that appear during loop peeling and are guaranteed to evaluate to a constant zero, for example affine maps such as `(ub - ub mod step) mod step`.
---
 .../SCF/Transforms/LoopSpecialization.cpp     |  20 +++-
 mlir/lib/IR/AffineExpr.cpp                    | 100 ++++++++++++++++--
 mlir/test/Dialect/Affine/canonicalize.mlir    |  13 +++
 .../Dialect/SCF/transform-ops-invalid.mlir    |  50 ++++++++-
 4 files changed, 168 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 23646f42eb5fe3d..93b81794ec2652e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -123,19 +123,29 @@ static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
   auto ubInt = getConstantIntValue(forOp.getUpperBound());
   auto stepInt = getConstantIntValue(forOp.getStep());
 
-  // No specialization necessary if step already divides upper bound evenly.
-  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
-    return failure();
   // No specialization necessary if step size is 1.
-  if (stepInt == static_cast<int64_t>(1))
+  if (getConstantIntValue(forOp.getStep()) == static_cast<int64_t>(1))
     return failure();
 
-  auto loc = forOp.getLoc();
+  // No specialization necessary if step already divides upper bound evenly.
+  // Fast path: lb, ub and step are constants.
+  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
+    return failure();
+  // Slow path: Examine the ops that define lb, ub and step.
   AffineExpr sym0, sym1, sym2;
   bindSymbols(b.getContext(), sym0, sym1, sym2);
+  SmallVector<Value> operands{forOp.getLowerBound(), forOp.getUpperBound(),
+                              forOp.getStep()};
+  AffineMap map = AffineMap::get(0, 3, {(sym1 - sym0) % sym2});
+  affine::fullyComposeAffineMapAndOperands(&map, &operands);
+  if (auto constExpr = dyn_cast<AffineConstantExpr>(map.getResult(0)))
+    if (constExpr.getValue() == 0)
+      return failure();
+
   // New upper bound: %ub - (%ub - %lb) mod %step
   auto modMap = AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
   b.setInsertionPoint(forOp);
+  auto loc = forOp.getLoc();
   splitBound = b.createOrFold<AffineApplyOp>(loc, modMap,
                                              ValueRange{forOp.getLowerBound(),
                                                         forOp.getUpperBound(),
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index c9be1af034f8e3f..e550290282b2588 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -444,12 +444,89 @@ static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
   llvm_unreachable("Unknown AffineExpr");
 }
 
+/// Populate `result` with all summand operands of given (potentially nested)
+/// addition. If the given expression is not an addition, just populate the
+/// expression itself.
+/// Example: Add(Add(7, 8), Mul(9, 10)) will return [7, 8, Mul(9, 10)].
+static void getSummandExprs(AffineExpr expr, SmallVector<AffineExpr> &result) {
+  auto addExpr = expr.dyn_cast<AffineBinaryOpExpr>();
+  if (!addExpr || addExpr.getKind() != AffineExprKind::Add) {
+    result.push_back(expr);
+    return;
+  }
+  getSummandExprs(addExpr.getLHS(), result);
+  getSummandExprs(addExpr.getRHS(), result);
+}
+
+/// Return "true" if `candidate` is a negated expression, i.e., Mul(-1, expr).
+/// If so, also return the non-negated expression via `expr`.
+static bool isNegatedAffineExpr(AffineExpr candidate, AffineExpr &expr) {
+  auto mulExpr = candidate.dyn_cast<AffineBinaryOpExpr>();
+  if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
+    return false;
+  if (auto lhs = mulExpr.getLHS().dyn_cast<AffineConstantExpr>()) {
+    if (lhs.getValue() == -1) {
+      expr = mulExpr.getRHS();
+      return true;
+    }
+  }
+  if (auto rhs = mulExpr.getRHS().dyn_cast<AffineConstantExpr>()) {
+    if (rhs.getValue() == -1) {
+      expr = mulExpr.getLHS();
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Return "true" if `lhs` % `rhs` is guaranteed to evaluate to zero based on
+/// the fact that `lhs` contains another modulo expression that ensures that
+/// `lhs` is divisible by `rhs`. This is a common pattern in the resulting IR
+/// after loop peeling.
+///
+/// Example: lhs = ub - ub % step
+///          rhs = step
+///       => (ub - ub % step) % step is guaranteed to evaluate to 0.
+static bool isModOfModSubtraction(AffineExpr lhs, AffineExpr rhs,
+                                  unsigned numDims, unsigned numSymbols) {
+  // TODO: Try to unify this function with `getBoundForAffineExpr`.
+  // Collect all summands in lhs.
+  SmallVector<AffineExpr> summands;
+  getSummandExprs(lhs, summands);
+  // Look for Mul(-1, Mod(x, rhs)) among the summands. If x matches the
+  // remaining summands, then lhs % rhs is guaranteed to evaluate to 0.
+  for (int64_t i = 0, e = summands.size(); i < e; ++i) {
+    AffineExpr current = summands[i];
+    AffineExpr beforeNegation;
+    if (!isNegatedAffineExpr(current, beforeNegation))
+      continue;
+    AffineBinaryOpExpr innerMod = beforeNegation.dyn_cast<AffineBinaryOpExpr>();
+    if (!innerMod || innerMod.getKind() != AffineExprKind::Mod)
+      continue;
+    if (innerMod.getRHS() != rhs)
+      continue;
+    // Sum all remaining summands and subtract x. If that expression can be
+    // simplified to zero, then the remaining summands and x are equal.
+    AffineExpr diff = getAffineConstantExpr(0, lhs.getContext());
+    for (int64_t j = 0; j < e; ++j)
+      if (i != j)
+        diff = diff + summands[j];
+    diff = diff - innerMod.getLHS();
+    diff = simplifyAffineExpr(diff, numDims, numSymbols);
+    auto constExpr = dyn_cast<AffineConstantExpr>(diff);
+    if (constExpr && constExpr.getValue() == 0)
+      return true;
+  }
+  return false;
+}
+
 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
 /// operations when the second operand simplifies to a symbol and the first
 /// operand is divisible by that symbol. It can be applied to any semi-affine
 /// expression. Returned expression can either be a semi-affine or pure affine
 /// expression.
-static AffineExpr simplifySemiAffine(AffineExpr expr) {
+static AffineExpr simplifySemiAffine(AffineExpr expr, unsigned numDims,
+                                     unsigned numSymbols) {
   switch (expr.getKind()) {
   case AffineExprKind::Constant:
   case AffineExprKind::DimId:
@@ -458,9 +535,10 @@ static AffineExpr simplifySemiAffine(AffineExpr expr) {
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return getAffineBinaryOpExpr(expr.getKind(),
-                                 simplifySemiAffine(binaryExpr.getLHS()),
-                                 simplifySemiAffine(binaryExpr.getRHS()));
+    return getAffineBinaryOpExpr(
+        expr.getKind(),
+        simplifySemiAffine(binaryExpr.getLHS(), numDims, numSymbols),
+        simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols));
   }
   // Check if the simplification of the second operand is a symbol, and the
   // first operand is divisible by it. If the operation is a modulo, a constant
@@ -471,10 +549,14 @@ static AffineExpr simplifySemiAffine(AffineExpr expr) {
   case AffineExprKind::CeilDiv:
   case AffineExprKind::Mod: {
     AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
-    AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
-    AffineSymbolExpr symbolExpr =
-        dyn_cast<AffineSymbolExpr>(simplifySemiAffine(binaryExpr.getRHS()));
+    AffineExpr sLHS =
+        simplifySemiAffine(binaryExpr.getLHS(), numDims, numSymbols);
+    AffineExpr sRHS =
+        simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols);
+    if (isModOfModSubtraction(sLHS, sRHS, numDims, numSymbols))
+      return getAffineConstantExpr(0, expr.getContext());
+    AffineSymbolExpr symbolExpr = dyn_cast<AffineSymbolExpr>(
+        simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols));
     if (!symbolExpr)
       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
     unsigned symbolPos = symbolExpr.getPosition();
@@ -1415,7 +1497,7 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
                                     unsigned numSymbols) {
   // Simplify semi-affine expressions separately.
   if (!expr.isPureAffine())
-    expr = simplifySemiAffine(expr);
+    expr = simplifySemiAffine(expr, numDims, numSymbols);
 
   SimpleAffineExprFlattener flattener(numDims, numSymbols);
   flattener.walkPostOrder(expr);
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index f734f37b0dc1a97..7c0930eedc85683 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1439,3 +1439,16 @@ func.func @max.oneval(%arg0: index) -> index {
   // CHECK: return %arg0 : index
   return %max: index
 }
+
+// -----
+
+// CHECK-LABEL: func @mod_of_mod(
+//       CHECK:   %[[c0:.*]] = arith.constant 0
+//       CHECK:   return %[[c0]], %[[c0]]
+func.func @mod_of_mod(%lb: index, %ub: index, %step: index) -> (index, index) {
+  // Simplify: (ub - ub % step) % step == 0
+  %0 = affine.apply affine_map<()[s0, s1] -> ((s0 - (s0 mod s1)) mod s1)> ()[%ub, %step]
+  // Simplify: (ub - (ub - lb) % step - lb) % step == 0
+  %1 = affine.apply affine_map<()[s0, s1, s2] -> ((s0 - ((s0 - s2) mod s1) - s2) mod s1)> ()[%ub, %step, %lb]
+  return %0, %1 : index, index
+}
diff --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
index 59b824d4ca26209..a01a7995d6189e1 100644
--- a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
@@ -68,7 +68,10 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @test_loops_do_not_get_peeled() {
+func.func @test_loop_peeling_not_beneficial() {
+  // Loop peeling is not beneficial because the step size already divides
+  // ub - lb evenly. lb, ub and step are constant in this test case and the
+  // "fast path" is exercised.
   %lb = arith.constant 0 : index
   %ub = arith.constant 40 : index
   %step = arith.constant 5 : index
@@ -87,3 +90,48 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @test_loop_peeling_not_beneficial_already_peeled(%lb: index, %ub: index, %step: index) {
+  // Loop peeling is not beneficial because the step size already divides
+  // ub - lb evenly. This test case exercises the "slow path".
+  %new_ub = affine.apply affine_map<()[s0, s1, s2] -> (s1 - (s1 - s0) mod s2)>()[%lb, %ub, %step]
+  scf.for %i = %lb to %new_ub step %step {
+    arith.addi %i, %i : index
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    // expected-error @below {{failed to peel}}
+    transform.loop.peel %1 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @test_loop_peeling_not_beneficial_already_peeled_lb_zero(%ub: index, %step: index) {
+  // Loop peeling is not beneficial because the step size already divides
+  // ub - lb evenly. This test case exercises the "slow path".
+  %lb = arith.constant 0 : index
+  %new_ub = affine.apply affine_map<()[s1, s2] -> (s1 - s1 mod s2)>()[%ub, %step]
+  scf.for %i = %lb to %new_ub step %step {
+    arith.addi %i, %i : index
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    // expected-error @below {{failed to peel}}
+    transform.loop.peel %1 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list