[Mlir-commits] [mlir] [mlir][affine] remove divide zero check when simplifer affineMap (#64622) (PR #68519)

Jakub Kuderski llvmlistbot at llvm.org
Mon Oct 30 07:57:33 PDT 2023


================
@@ -700,6 +708,94 @@ static std::optional<int64_t> getUpperBound(Value iv) {
   return forOp.getConstantUpperBound() - 1;
 }
 
+/// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
+/// the constant lower and upper bounds for its inputs provided in
+/// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
+/// bound can't be computed. This method only handles simple sum of product
+/// expressions (w.r.t constant coefficients) so as to not depend on anything
+/// heavyweight in `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 +
+/// ... + c_n are handled. Expressions involving floordiv, ceildiv, mod or
+/// semi-affine ones will lead to std::nullopt being returned.
+static std::optional<int64_t>
+getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
+                ArrayRef<std::optional<int64_t>> constLowerBounds,
+                ArrayRef<std::optional<int64_t>> constUpperBounds,
+                bool isUpper) {
+  // Handle divs and mods.
+  if (auto binOpExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
+    // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
+    // can compute an upper bound.
+    if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (!rhsConst || rhsConst.getValue() < 1)
+        return std::nullopt;
+      auto bound = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                   constLowerBounds, constUpperBounds, isUpper);
+      if (!bound)
+        return std::nullopt;
+      return mlir::floorDiv(*bound, rhsConst.getValue());
+    }
+    if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (rhsConst && rhsConst.getValue() >= 1) {
+        auto bound =
+            getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                            constLowerBounds, constUpperBounds, isUpper);
+        if (!bound)
+          return std::nullopt;
+        return mlir::ceilDiv(*bound, rhsConst.getValue());
+      }
+      return std::nullopt;
+    }
+    if (binOpExpr.getKind() == AffineExprKind::Mod) {
+      // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
+      // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
+      // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (rhsConst && rhsConst.getValue() >= 1) {
+        int64_t rhsConstVal = rhsConst.getValue();
+        auto lb = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                  constLowerBounds, constUpperBounds,
+                                  /*isUpper=*/false);
+        auto ub = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                  constLowerBounds, constUpperBounds, isUpper);
+        if (ub && lb &&
+            floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
+          return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
+        return isUpper ? rhsConstVal - 1 : 0;
+      }
+    }
+  }
+  // Flatten the expression.
+  SimpleAffineExprFlattener flattener(numDims, numSymbols);
+  auto flattenResult = flattener.walkPostOrder(expr);
+  assert(succeeded(flattenResult) && "affine expr containts poison expr");
----------------
kuhar wrote:

same here

https://github.com/llvm/llvm-project/pull/68519


More information about the Mlir-commits mailing list