[Mlir-commits] [mlir] [mlir][Arith] Generalize and improve -int-range-optimizations (PR #94712)
Ivan Butygin
llvmlistbot at llvm.org
Fri Jun 7 08:21:43 PDT 2024
================
@@ -24,155 +25,145 @@ using namespace mlir;
using namespace mlir::arith;
using namespace mlir::dataflow;
-/// Returns true if 2 integer ranges have intersection.
-static bool intersects(const ConstantIntRanges &lhs,
- const ConstantIntRanges &rhs) {
- return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) &&
- (lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax())));
+/// If the integer range that `solver` inferred for `value` is a single
+/// constant, materialize that constant and replace all uses of `value` with
+/// it. Patterned after SCCP.
+///
+/// The constant is requested from `folder` for `rewriter`'s current insertion
+/// block, first through the dialect of the op defining `value` (or, for a
+/// block argument, the dialect of the op owning its region), then, if that
+/// dialect cannot materialize it, through the `arith` dialect.
+///
+/// Returns failure() if no usable range is known, the range is not a single
+/// value, or no constant could be materialized; success() after the uses of
+/// `value` have been replaced.
+static LogicalResult replaceWithConstant(DataFlowSolver &solver,
+ RewriterBase &rewriter,
+ OperationFolder &folder, Value value) {
+ // No lattice state, or an uninitialized one: nothing is known about `value`.
+ auto *maybeInferredRange =
+ solver.lookupState<IntegerValueRangeLattice>(value);
+ if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+ return failure();
+ const ConstantIntRanges &inferredRange =
+ maybeInferredRange->getValue().getValue();
+ // Only ranges that collapse to exactly one value can be replaced.
+ std::optional<APInt> maybeConstValue = inferredRange.getConstantValue();
+ if (!maybeConstValue.has_value())
+ return failure();
+
+ // Block arguments have no defining op, so fall back to the dialect of the
+ // op that owns the enclosing region.
+ // NOTE(review): getDialect() may return null (e.g. for unregistered ops);
+ // confirm that getOrCreateConstant tolerates a null dialect.
+ Operation *maybeDefiningOp = value.getDefiningOp();
+ Dialect *valueDialect =
+ maybeDefiningOp ? maybeDefiningOp->getDialect()
+ : value.getParentRegion()->getParentOp()->getDialect();
+ // NOTE(review): getIntegerAttr presumably expects a scalar integer/index
+ // type; confirm shaped (vector/tensor) values never reach this point with
+ // a single-value range.
+ Attribute constAttr =
+ rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
+ Value constant = folder.getOrCreateConstant(
+ rewriter.getInsertionBlock(), valueDialect, constAttr, value.getType());
+ // Fall back to arith.constant if the dialect materializer doesn't know what
+ // to do with an integer constant.
+ if (!constant)
+ constant = folder.getOrCreateConstant(
+ rewriter.getInsertionBlock(),
+ rewriter.getContext()->getLoadedDialect<ArithDialect>(), constAttr,
+ value.getType());
+ if (!constant)
+ return failure();
+
+ // Note: the constant is materialized even if `value` currently has no uses.
+ rewriter.replaceAllUsesWith(value, constant);
+ return success();
}
-static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (!intersects(lhs, rhs))
- return false;
-
- return failure();
-}
-
-static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (!intersects(lhs, rhs))
- return true;
-
- return failure();
-}
-
-static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (lhs.smax().slt(rhs.smin()))
- return true;
-
- if (lhs.smin().sge(rhs.smax()))
- return false;
-
- return failure();
-}
-
-static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (lhs.smax().sle(rhs.smin()))
- return true;
-
- if (lhs.smin().sgt(rhs.smax()))
- return false;
-
- return failure();
-}
-
-static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- return handleSlt(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- return handleSle(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (lhs.umax().ult(rhs.umin()))
- return true;
-
- if (lhs.umin().uge(rhs.umax()))
- return false;
-
- return failure();
-}
-
-static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (lhs.umax().ule(rhs.umin()))
- return true;
-
- if (lhs.umin().ugt(rhs.umax()))
- return false;
-
+/// Replace the uses of every result of `op` that was inferred to be a constant
+/// integer with that constant. Returns success() only if every result was
+/// replaced and the then trivially-dead operation was erased. Returns
+/// failure() otherwise, including for ops with no results, even when some of
+/// the results were replaced.
+static LogicalResult foldResultsToConstants(DataFlowSolver &solver,
+ RewriterBase &rewriter,
+ OperationFolder &folder,
+ Operation &op) {
+ // Zero-result ops never count as "fully replaced", so they are never erased
+ // here.
+ bool replacedAll = op.getNumResults() != 0;
+ // `&=` rather than `&&` so every result is attempted even after one fails.
+ for (Value res : op.getResults())
+ replacedAll &=
+ succeeded(replaceWithConstant(solver, rewriter, folder, res));
+
+ // If all of the results of the operation were replaced, try to erase
+ // the operation completely.
+ if (replacedAll && wouldOpBeTriviallyDead(&op)) {
+ assert(op.use_empty() && "expected all uses to be replaced");
+ rewriter.eraseOp(&op);
+ return success();
+ }
return failure();
}
-static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- return handleUlt(std::move(rhs), std::move(lhs));
+/// If `op` is an `arith.remsi`/`arith.remui` whose divisor is a positive
+/// `arith.constant` integer, and whose dividend is known from the range
+/// analysis in `solver` to lie within [0, divisor), replace `op` with its
+/// dividend, since the remainder is then a no-op. Returns success() if `op`
+/// was replaced, failure() otherwise.
+///
+/// This rewrite has no upstream precedent; its correctness relies on the
+/// overall tests of the integer range inference implementation.
+static LogicalResult deleteTrivialRemainder(DataFlowSolver &solver,
+ RewriterBase &rewriter,
+ Operation &op) {
+ if (!isa<RemSIOp, RemUIOp>(op))
+ return failure();
+ Value lhs = op.getOperand(0);
+ Value rhs = op.getOperand(1);
+ // NOTE(review): only arith::ConstantIntOp divisors are matched; index-typed
+ // constants are presumably not (see arith::ConstantIndexOp). Confirm this is
+ // intended.
+ auto rhsConstVal = rhs.getDefiningOp<arith::ConstantIntOp>();
+ if (!rhsConstVal)
+ return failure();
+ // NOTE(review): value() narrows the attribute to int64_t; confirm behavior
+ // for integer types wider than 64 bits.
+ int64_t modulus = rhsConstVal.value();
+ if (modulus <= 0)
+ return failure();
+ auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
+ if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
+ return failure();
+ const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
+ // Read the bounds as unsigned for remui and as signed for remsi.
+ const APInt &min = llvm::isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
+ const APInt &max = llvm::isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
+ // The bounds are closed (inclusive), so both must be strictly less than the
+ // modulus. Rejecting a set sign bit also rejects large unsigned remui
+ // bounds, which is conservative.
+ if (min.isNegative() || min.uge(modulus))
+ return failure();
+ if (max.isNegative() || max.uge(modulus))
+ return failure();
+ // Defensive: reject an inverted (min > max) range.
+ if (!min.ule(max))
+ return failure();
+
+ // With all those conditions out of the way, we know that this invocation of
+ // a remainder is a no-op because the input is strictly within the range
+ // [0, modulus), so get rid of it.
+ rewriter.replaceOp(&op, ValueRange{lhs});
+ return success();
}
-static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- return handleUle(std::move(rhs), std::move(lhs));
+static void doRewrites(DataFlowSolver &solver, MLIRContext *context,
+ MutableArrayRef<Region> initialRegions) {
+ SmallVector<Block *> worklist;
----------------
Hardcode84 wrote:
I think you can just use `region->walk([](Block*){})` instead of all the manual worklist management (`walk` permits the callback to erase operations). Note, though, that `walk` is implemented recursively.
https://github.com/llvm/llvm-project/pull/94712
More information about the Mlir-commits
mailing list