[Mlir-commits] [mlir] 4506de1 - NFC. Move out and expose affine expression simplification utility out of AffineOps lib (#69813)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 26 01:04:57 PDT 2023


Author: Uday Bondhugula
Date: 2023-10-26T13:34:54+05:30
New Revision: 4506de1e2bfe9352331d75fe6d3c89f3d2f5287c

URL: https://github.com/llvm/llvm-project/commit/4506de1e2bfe9352331d75fe6d3c89f3d2f5287c
DIFF: https://github.com/llvm/llvm-project/commit/4506de1e2bfe9352331d75fe6d3c89f3d2f5287c.diff

LOG: NFC. Move out and expose affine expression simplification utility out of AffineOps lib (#69813)

Move the trivial affine expression simplification utility out of the
AffineOps library and expose it from libIR. Users of such methods shouldn't
have to rely on the AffineOps dialect. For example, with this change, the
method can now be used from lib/Analysis/ (FlatLinearConstraints) as well as
from AffineOps dialect canonicalization.

This way, those users won't need to depend on AffineOps for some
simplification of affine expressions.

Added: 
    

Modified: 
    mlir/include/mlir/IR/AffineExpr.h
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/IR/AffineExpr.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 8ced8770591ee8c..69e02c94ef2708d 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -353,6 +353,20 @@ void bindSymbolsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
     e = getAffineSymbolExpr(idx++, ctx);
 }
 
+/// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
+/// the constant lower and upper bounds for its inputs provided in
+/// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
+/// bound can't be computed. This method only handles simple sum of product
+/// expressions (w.r.t constant coefficients) so as to not depend on anything
+/// heavyweight in `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 +
+/// ... + c_n are handled. Expressions involving floordiv, ceildiv, mod or
+/// semi-affine ones will lead to std::nullopt being returned.
+std::optional<int64_t>
+getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
+                      ArrayRef<std::optional<int64_t>> constLowerBounds,
+                      ArrayRef<std::optional<int64_t>> constUpperBounds,
+                      bool isUpper);
+
 } // namespace mlir
 
 namespace llvm {

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 4d79c458889d2a9..ba4285bd52394f3 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -700,93 +700,6 @@ static std::optional<int64_t> getUpperBound(Value iv) {
   return forOp.getConstantUpperBound() - 1;
 }
 
-/// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
-/// the constant lower and upper bounds for its inputs provided in
-/// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
-/// bound can't be computed. This method only handles simple sum of product
-/// expressions (w.r.t constant coefficients) so as to not depend on anything
-/// heavyweight in `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 +
-/// ... + c_n are handled. Expressions involving floordiv, ceildiv, mod or
-/// semi-affine ones will lead std::nullopt being returned.
-static std::optional<int64_t>
-getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
-                ArrayRef<std::optional<int64_t>> constLowerBounds,
-                ArrayRef<std::optional<int64_t>> constUpperBounds,
-                bool isUpper) {
-  // Handle divs and mods.
-  if (auto binOpExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
-    // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
-    // can compute an upper bound.
-    if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
-      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
-      if (!rhsConst || rhsConst.getValue() < 1)
-        return std::nullopt;
-      auto bound = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
-                                   constLowerBounds, constUpperBounds, isUpper);
-      if (!bound)
-        return std::nullopt;
-      return mlir::floorDiv(*bound, rhsConst.getValue());
-    }
-    if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
-      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
-      if (rhsConst && rhsConst.getValue() >= 1) {
-        auto bound =
-            getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
-                            constLowerBounds, constUpperBounds, isUpper);
-        if (!bound)
-          return std::nullopt;
-        return mlir::ceilDiv(*bound, rhsConst.getValue());
-      }
-      return std::nullopt;
-    }
-    if (binOpExpr.getKind() == AffineExprKind::Mod) {
-      // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
-      // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
-      // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
-      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
-      if (rhsConst && rhsConst.getValue() >= 1) {
-        int64_t rhsConstVal = rhsConst.getValue();
-        auto lb = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
-                                  constLowerBounds, constUpperBounds,
-                                  /*isUpper=*/false);
-        auto ub = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
-                                  constLowerBounds, constUpperBounds, isUpper);
-        if (ub && lb &&
-            floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
-          return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
-        return isUpper ? rhsConstVal - 1 : 0;
-      }
-    }
-  }
-  // Flatten the expression.
-  SimpleAffineExprFlattener flattener(numDims, numSymbols);
-  flattener.walkPostOrder(expr);
-  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
-  // TODO: Handle local variables. We can get hold of flattener.localExprs and
-  // get bound on the local expr recursively.
-  if (flattener.numLocals > 0)
-    return std::nullopt;
-  int64_t bound = 0;
-  // Substitute the constant lower or upper bound for the dimensional or
-  // symbolic input depending on `isUpper` to determine the bound.
-  for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
-    if (flattenedExpr[i] > 0) {
-      auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
-      if (!constBound)
-        return std::nullopt;
-      bound += *constBound * flattenedExpr[i];
-    } else if (flattenedExpr[i] < 0) {
-      auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
-      if (!constBound)
-        return std::nullopt;
-      bound += *constBound * flattenedExpr[i];
-    }
-  }
-  // Constant term.
-  bound += flattenedExpr.back();
-  return bound;
-}
-
 /// Determine a constant upper bound for `expr` if one exists while exploiting
 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
 /// is guaranteed to be less than or equal to it.
@@ -805,9 +718,9 @@ static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
   if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
     return constExpr.getValue();
 
-  return getBoundForExpr(expr, numDims, numSymbols, constLowerBounds,
-                         constUpperBounds,
-                         /*isUpper=*/true);
+  return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
+                               constUpperBounds,
+                               /*isUpper=*/true);
 }
 
 /// Determine a constant lower bound for `expr` if one exists while exploiting
@@ -829,9 +742,9 @@ static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
   if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
     lowerBound = constExpr.getValue();
   } else {
-    lowerBound = getBoundForExpr(expr, numDims, numSymbols, constLowerBounds,
-                                 constUpperBounds,
-                                 /*isUpper=*/false);
+    lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
+                                       constLowerBounds, constUpperBounds,
+                                       /*isUpper=*/false);
   }
   return lowerBound;
 }
@@ -970,14 +883,14 @@ static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
       lowerBounds.push_back(constExpr.getValue());
       upperBounds.push_back(constExpr.getValue());
     } else {
-      lowerBounds.push_back(getBoundForExpr(e, map.getNumDims(),
-                                            map.getNumSymbols(),
-                                            constLowerBounds, constUpperBounds,
-                                            /*isUpper=*/false));
-      upperBounds.push_back(getBoundForExpr(e, map.getNumDims(),
-                                            map.getNumSymbols(),
-                                            constLowerBounds, constUpperBounds,
-                                            /*isUpper=*/true));
+      lowerBounds.push_back(
+          getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(),
+                                constLowerBounds, constUpperBounds,
+                                /*isUpper=*/false));
+      upperBounds.push_back(
+          getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(),
+                                constLowerBounds, constUpperBounds,
+                                /*isUpper=*/true));
     }
   }
 

diff  --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 7eccbca4e6e7a1a..4b7ec89a842bd65 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1438,3 +1438,83 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
   assert(flattener.operandExprStack.empty());
   return simplifiedExpr;
 }
+
+std::optional<int64_t> mlir::getBoundForAffineExpr(
+    AffineExpr expr, unsigned numDims, unsigned numSymbols,
+    ArrayRef<std::optional<int64_t>> constLowerBounds,
+    ArrayRef<std::optional<int64_t>> constUpperBounds, bool isUpper) {
+  // Handle divs and mods.
+  if (auto binOpExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
+    // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
+    // can compute an upper bound.
+    if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (!rhsConst || rhsConst.getValue() < 1)
+        return std::nullopt;
+      auto bound =
+          getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                constLowerBounds, constUpperBounds, isUpper);
+      if (!bound)
+        return std::nullopt;
+      return mlir::floorDiv(*bound, rhsConst.getValue());
+    }
+    if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (rhsConst && rhsConst.getValue() >= 1) {
+        auto bound =
+            getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                  constLowerBounds, constUpperBounds, isUpper);
+        if (!bound)
+          return std::nullopt;
+        return mlir::ceilDiv(*bound, rhsConst.getValue());
+      }
+      return std::nullopt;
+    }
+    if (binOpExpr.getKind() == AffineExprKind::Mod) {
+      // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
+      // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
+      // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (rhsConst && rhsConst.getValue() >= 1) {
+        int64_t rhsConstVal = rhsConst.getValue();
+        auto lb = getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                        constLowerBounds, constUpperBounds,
+                                        /*isUpper=*/false);
+        auto ub =
+            getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                  constLowerBounds, constUpperBounds, isUpper);
+        if (ub && lb &&
+            floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
+          return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
+        return isUpper ? rhsConstVal - 1 : 0;
+      }
+    }
+  }
+  // Flatten the expression.
+  SimpleAffineExprFlattener flattener(numDims, numSymbols);
+  flattener.walkPostOrder(expr);
+  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
+  // TODO: Handle local variables. We can get hold of flattener.localExprs and
+  // get bound on the local expr recursively.
+  if (flattener.numLocals > 0)
+    return std::nullopt;
+  int64_t bound = 0;
+  // Substitute the constant lower or upper bound for the dimensional or
+  // symbolic input depending on `isUpper` to determine the bound.
+  for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
+    if (flattenedExpr[i] > 0) {
+      auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
+      if (!constBound)
+        return std::nullopt;
+      bound += *constBound * flattenedExpr[i];
+    } else if (flattenedExpr[i] < 0) {
+      auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
+      if (!constBound)
+        return std::nullopt;
+      bound += *constBound * flattenedExpr[i];
+    }
+  }
+  // Constant term.
+  bound += flattenedExpr.back();
+  return bound;
+}


        


More information about the Mlir-commits mailing list