[Mlir-commits] [mlir] [mlir][affine] remove divide zero check when simplifer affineMap (#64622) (PR #68519)
Jakub Kuderski
llvmlistbot at llvm.org
Fri Nov 3 12:22:35 PDT 2023
================
@@ -700,6 +708,100 @@ static std::optional<int64_t> getUpperBound(Value iv) {
return forOp.getConstantUpperBound() - 1;
}
+/// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
+/// the constant lower and upper bounds for its inputs provided in
+/// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
+/// bound can't be computed. This method only handles simple sum of product
+/// expressions (w.r.t constant coefficients) so as to not depend on anything
+/// heavyweight in `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 +
+/// ... + c_n are handled. Expressions involving floordiv, ceildiv, mod or
+/// semi-affine ones will lead to std::nullopt being returned.
+static std::optional<int64_t>
+getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
+ ArrayRef<std::optional<int64_t>> constLowerBounds,
+ ArrayRef<std::optional<int64_t>> constUpperBounds,
+ bool isUpper) {
+ // Handle divs and mods.
+ if (auto binOpExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
+ // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
+ // can compute an upper bound.
+ if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
+ auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+ if (!rhsConst || rhsConst.getValue() < 1)
+ return std::nullopt;
+ std::optional<int64_t> bound =
+ getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+ constLowerBounds, constUpperBounds, isUpper);
+ if (!bound)
+ return std::nullopt;
+ return mlir::floorDiv(*bound, rhsConst.getValue());
+ }
+ if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
+ auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+ if (rhsConst && rhsConst.getValue() >= 1) {
+ std::optional<int64_t> bound =
+ getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+ constLowerBounds, constUpperBounds, isUpper);
+ if (!bound)
+ return std::nullopt;
+ return mlir::ceilDiv(*bound, rhsConst.getValue());
+ }
+ return std::nullopt;
+ }
+ if (binOpExpr.getKind() == AffineExprKind::Mod) {
+ // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
+ // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
+ // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
+ auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+ if (rhsConst && rhsConst.getValue() >= 1) {
+ int64_t rhsConstVal = rhsConst.getValue();
+ std::optional<int64_t> lb =
+ getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+ constLowerBounds, constUpperBounds,
+ /*isUpper=*/false);
+ std::optional<int64_t> ub =
+ getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+ constLowerBounds, constUpperBounds, isUpper);
+ if (ub && lb &&
+ floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
+ return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
+ return isUpper ? rhsConstVal - 1 : 0;
+ }
+ }
+ }
+
+ // Flatten the expression.
+ SimpleAffineExprFlattener flattener(numDims, numSymbols);
+ auto flattenResult = flattener.walkPostOrder(expr);
+ // has poison expression
+ if (failed(flattenResult))
+ return std::nullopt;
+ ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
+ // TODO: Handle local variables. We can get hold of flattener.localExprs and
+ // get bound on the local expr recursively.
+ if (flattener.numLocals > 0)
+ return std::nullopt;
+ int64_t bound = 0;
+ // Substitute the constant lower or upper bound for the dimensional or
+ // symbolic input depending on `isUpper` to determine the bound.
+ for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
+ if (flattenedExpr[i] > 0) {
+ auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
+ if (!constBound)
+ return std::nullopt;
+ bound += *constBound * flattenedExpr[i];
+ } else if (flattenedExpr[i] < 0) {
+ auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
----------------
kuhar wrote:
Use the full type instead of `auto` when the type is not obvious from the RHS.
https://github.com/llvm/llvm-project/pull/68519
More information about the Mlir-commits mailing list