[Mlir-commits] [mlir] NFC. Move out and expose affine expression simplification utility out of AffineOps lib (PR #69813)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 20 22:44:05 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-affine

Author: Uday Bondhugula (bondhugula)

<details>
<summary>Changes</summary>

Move the trivial affine expression simplification utility out of the AffineOps library and expose it from libIR. Users of such methods shouldn't have to rely on the AffineOps dialect. For example, with this change, the method can now be used from lib/Analysis/ (FlatLinearConstraints) as well as from AffineOps dialect canonicalization.

This way, those users won't need to depend on AffineOps just to perform simple simplifications of affine expressions.

---
Full diff: https://github.com/llvm/llvm-project/pull/69813.diff


3 Files Affected:

- (modified) mlir/include/mlir/IR/AffineExpr.h (+14) 
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+14-101) 
- (modified) mlir/lib/IR/AffineExpr.cpp (+80) 


``````````diff
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 8ced8770591ee8c..69e02c94ef2708d 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -353,6 +353,20 @@ void bindSymbolsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
     e = getAffineSymbolExpr(idx++, ctx);
 }
 
+/// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
+/// the constant lower and upper bounds for its inputs provided in
+/// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
+/// bound can't be computed. This lightweight method (no `Analysis` dependency)
+/// handles sums of products with constant coefficients, i.e., c0*d0 + c1*d1 +
+/// c2*s0 + ... + c_n, plus a top-level floordiv, ceildiv, or mod by a positive
+/// constant, which is bounded by recursing on its LHS. Other floordiv, ceildiv,
+/// mod, or semi-affine expressions lead to std::nullopt being returned.
+std::optional<int64_t>
+getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
+                      ArrayRef<std::optional<int64_t>> constLowerBounds,
+                      ArrayRef<std::optional<int64_t>> constUpperBounds,
+                      bool isUpper);
+
 } // namespace mlir
 
 namespace llvm {
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index f2b3171c1ab837b..fe869171ffbd2a6 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -700,93 +700,6 @@ static std::optional<int64_t> getUpperBound(Value iv) {
   return forOp.getConstantUpperBound() - 1;
 }
 
-/// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
-/// the constant lower and upper bounds for its inputs provided in
-/// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
-/// bound can't be computed. This method only handles simple sum of product
-/// expressions (w.r.t constant coefficients) so as to not depend on anything
-/// heavyweight in `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 +
-/// ... + c_n are handled. Expressions involving floordiv, ceildiv, mod or
-/// semi-affine ones will lead std::nullopt being returned.
-static std::optional<int64_t>
-getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
-                ArrayRef<std::optional<int64_t>> constLowerBounds,
-                ArrayRef<std::optional<int64_t>> constUpperBounds,
-                bool isUpper) {
-  // Handle divs and mods.
-  if (auto binOpExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
-    // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
-    // can compute an upper bound.
-    if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
-      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
-      if (!rhsConst || rhsConst.getValue() < 1)
-        return std::nullopt;
-      auto bound = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
-                                   constLowerBounds, constUpperBounds, isUpper);
-      if (!bound)
-        return std::nullopt;
-      return mlir::floorDiv(*bound, rhsConst.getValue());
-    }
-    if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
-      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
-      if (rhsConst && rhsConst.getValue() >= 1) {
-        auto bound =
-            getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
-                            constLowerBounds, constUpperBounds, isUpper);
-        if (!bound)
-          return std::nullopt;
-        return mlir::ceilDiv(*bound, rhsConst.getValue());
-      }
-      return std::nullopt;
-    }
-    if (binOpExpr.getKind() == AffineExprKind::Mod) {
-      // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
-      // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
-      // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
-      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
-      if (rhsConst && rhsConst.getValue() >= 1) {
-        int64_t rhsConstVal = rhsConst.getValue();
-        auto lb = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
-                                  constLowerBounds, constUpperBounds,
-                                  /*isUpper=*/false);
-        auto ub = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
-                                  constLowerBounds, constUpperBounds, isUpper);
-        if (ub && lb &&
-            floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
-          return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
-        return isUpper ? rhsConstVal - 1 : 0;
-      }
-    }
-  }
-  // Flatten the expression.
-  SimpleAffineExprFlattener flattener(numDims, numSymbols);
-  flattener.walkPostOrder(expr);
-  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
-  // TODO: Handle local variables. We can get hold of flattener.localExprs and
-  // get bound on the local expr recursively.
-  if (flattener.numLocals > 0)
-    return std::nullopt;
-  int64_t bound = 0;
-  // Substitute the constant lower or upper bound for the dimensional or
-  // symbolic input depending on `isUpper` to determine the bound.
-  for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
-    if (flattenedExpr[i] > 0) {
-      auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
-      if (!constBound)
-        return std::nullopt;
-      bound += *constBound * flattenedExpr[i];
-    } else if (flattenedExpr[i] < 0) {
-      auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
-      if (!constBound)
-        return std::nullopt;
-      bound += *constBound * flattenedExpr[i];
-    }
-  }
-  // Constant term.
-  bound += flattenedExpr.back();
-  return bound;
-}
-
 /// Determine a constant upper bound for `expr` if one exists while exploiting
 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
 /// is guaranteed to be less than or equal to it.
@@ -805,9 +718,9 @@ static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
   if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
     return constExpr.getValue();
 
-  return getBoundForExpr(expr, numDims, numSymbols, constLowerBounds,
-                         constUpperBounds,
-                         /*isUpper=*/true);
+  return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
+                               constUpperBounds,
+                               /*isUpper=*/true);
 }
 
 /// Determine a constant lower bound for `expr` if one exists while exploiting
@@ -829,9 +742,9 @@ static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
   if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
     lowerBound = constExpr.getValue();
   } else {
-    lowerBound = getBoundForExpr(expr, numDims, numSymbols, constLowerBounds,
-                                 constUpperBounds,
-                                 /*isUpper=*/false);
+    lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
+                                       constLowerBounds, constUpperBounds,
+                                       /*isUpper=*/false);
   }
   return lowerBound;
 }
@@ -970,14 +883,14 @@ static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
       lowerBounds.push_back(constExpr.getValue());
       upperBounds.push_back(constExpr.getValue());
     } else {
-      lowerBounds.push_back(getBoundForExpr(e, map.getNumDims(),
-                                            map.getNumSymbols(),
-                                            constLowerBounds, constUpperBounds,
-                                            /*isUpper=*/false));
-      upperBounds.push_back(getBoundForExpr(e, map.getNumDims(),
-                                            map.getNumSymbols(),
-                                            constLowerBounds, constUpperBounds,
-                                            /*isUpper=*/true));
+      lowerBounds.push_back(
+          getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(),
+                                constLowerBounds, constUpperBounds,
+                                /*isUpper=*/false));
+      upperBounds.push_back(
+          getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(),
+                                constLowerBounds, constUpperBounds,
+                                /*isUpper=*/true));
     }
   }
 
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 7eccbca4e6e7a1a..4b7ec89a842bd65 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1438,3 +1438,83 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
   assert(flattener.operandExprStack.empty());
   return simplifiedExpr;
 }
+
+std::optional<int64_t> mlir::getBoundForAffineExpr(
+    AffineExpr expr, unsigned numDims, unsigned numSymbols,
+    ArrayRef<std::optional<int64_t>> constLowerBounds,
+    ArrayRef<std::optional<int64_t>> constUpperBounds, bool isUpper) {
+  // Handle a top-level floordiv, ceildiv, or mod by a positive constant.
+  if (auto binOpExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
+    // floordiv and ceildiv by a positive constant are non-decreasing in their
+    // LHS, so the requested bound is the LHS's bound divided by that constant.
+    if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (!rhsConst || rhsConst.getValue() < 1)
+        return std::nullopt;
+      auto bound =
+          getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                constLowerBounds, constUpperBounds, isUpper);
+      if (!bound)
+        return std::nullopt;
+      return mlir::floorDiv(*bound, rhsConst.getValue());
+    }
+    if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (rhsConst && rhsConst.getValue() >= 1) {
+        auto bound =
+            getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                  constLowerBounds, constUpperBounds, isUpper);
+        if (!bound)
+          return std::nullopt;
+        return mlir::ceilDiv(*bound, rhsConst.getValue());
+      }
+      return std::nullopt;
+    }
+    if (binOpExpr.getKind() == AffineExprKind::Mod) {
+      // lhs mod c lies in [0, c - 1]. If lb <= lhs <= ub and lb floordiv c ==
+      // ub floordiv c (same "interval"), lb mod c <= lhs mod c <= ub mod c.
+      // Both LHS bounds are needed regardless of `isUpper`.
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (rhsConst && rhsConst.getValue() >= 1) {
+        int64_t rhsConstVal = rhsConst.getValue();
+        auto lb = getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                        constLowerBounds, constUpperBounds,
+                                        /*isUpper=*/false);
+        auto ub = getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                        constLowerBounds, constUpperBounds,
+                                        /*isUpper=*/true);
+        if (ub && lb &&
+            floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
+          return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
+        return isUpper ? rhsConstVal - 1 : 0;
+      }
+    }
+  }
+  // General case: flatten into a sum of products with constant coefficients.
+  SimpleAffineExprFlattener flattener(numDims, numSymbols);
+  flattener.walkPostOrder(expr);
+  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
+  // TODO: Handle local variables. We can get hold of flattener.localExprs and
+  // get bound on the local expr recursively.
+  if (flattener.numLocals > 0)
+    return std::nullopt;
+  int64_t bound = 0;
+  // Substitute the constant lower or upper bound for the dimensional or
+  // symbolic input depending on `isUpper` to determine the bound.
+  for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
+    if (flattenedExpr[i] > 0) {
+      auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
+      if (!constBound)
+        return std::nullopt;
+      bound += *constBound * flattenedExpr[i];
+    } else if (flattenedExpr[i] < 0) {
+      auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
+      if (!constBound)
+        return std::nullopt;
+      bound += *constBound * flattenedExpr[i];
+    }
+  }
+  // Constant term.
+  bound += flattenedExpr.back();
+  return bound;
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/69813


More information about the Mlir-commits mailing list