[Mlir-commits] [mlir] [mlir][Arith] Generalize and improve -int-range-optimizations (PR #94712)

Ivan Butygin llvmlistbot at llvm.org
Fri Jun 7 08:21:43 PDT 2024


================
@@ -24,155 +25,145 @@ using namespace mlir;
 using namespace mlir::arith;
 using namespace mlir::dataflow;
 
-/// Returns true if 2 integer ranges have intersection.
-static bool intersects(const ConstantIntRanges &lhs,
-                       const ConstantIntRanges &rhs) {
-  return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) &&
-           (lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax())));
+/// Patterned after SCCP
+static LogicalResult replaceWithConstant(DataFlowSolver &solver,
+                                         RewriterBase &rewriter,
+                                         OperationFolder &folder, Value value) {
+  auto *maybeInferredRange =
+      solver.lookupState<IntegerValueRangeLattice>(value);
+  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+    return failure();
+  const ConstantIntRanges &inferredRange =
+      maybeInferredRange->getValue().getValue();
+  std::optional<APInt> maybeConstValue = inferredRange.getConstantValue();
+  if (!maybeConstValue.has_value())
+    return failure();
+
+  Operation *maybeDefiningOp = value.getDefiningOp();
+  Dialect *valueDialect =
+      maybeDefiningOp ? maybeDefiningOp->getDialect()
+                      : value.getParentRegion()->getParentOp()->getDialect();
+  Attribute constAttr =
+      rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
+  Value constant = folder.getOrCreateConstant(
+      rewriter.getInsertionBlock(), valueDialect, constAttr, value.getType());
+  // Fall back to arith.constant if the dialect materializer doesn't know what
+  // to do with an integer constant.
+  if (!constant)
+    constant = folder.getOrCreateConstant(
+        rewriter.getInsertionBlock(),
+        rewriter.getContext()->getLoadedDialect<ArithDialect>(), constAttr,
+        value.getType());
+  if (!constant)
+    return failure();
+
+  rewriter.replaceAllUsesWith(value, constant);
+  return success();
 }
 
-static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (!intersects(lhs, rhs))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (!intersects(lhs, rhs))
-    return true;
-
-  return failure();
-}
-
-static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.smax().slt(rhs.smin()))
-    return true;
-
-  if (lhs.smin().sge(rhs.smax()))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.smax().sle(rhs.smin()))
-    return true;
-
-  if (lhs.smin().sgt(rhs.smax()))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleSlt(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleSle(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.umax().ult(rhs.umin()))
-    return true;
-
-  if (lhs.umin().uge(rhs.umax()))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.umax().ule(rhs.umin()))
-    return true;
-
-  if (lhs.umin().ugt(rhs.umax()))
-    return false;
-
+/// Rewrite each result of `op` that was inferred to be a constant integer by
+/// replacing its uses with that constant. Return success() if all results
+/// were thus replaced and the operation was erased.
+static LogicalResult foldResultsToConstants(DataFlowSolver &solver,
+                                            RewriterBase &rewriter,
+                                            OperationFolder &folder,
+                                            Operation &op) {
+  bool replacedAll = op.getNumResults() != 0;
+  for (Value res : op.getResults())
+    replacedAll &=
+        succeeded(replaceWithConstant(solver, rewriter, folder, res));
+
+  // If all of the results of the operation were replaced, try to erase
+  // the operation completely.
+  if (replacedAll && wouldOpBeTriviallyDead(&op)) {
+    assert(op.use_empty() && "expected all uses to be replaced");
+    rewriter.eraseOp(&op);
+    return success();
+  }
   return failure();
 }
 
-static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleUlt(std::move(rhs), std::move(lhs));
+/// This function is not adapted from existing code; its correctness relies on
+/// the overall tests of the integer range inference implementation.
+static LogicalResult deleteTrivialRemainder(DataFlowSolver &solver,
+                                            RewriterBase &rewriter,
+                                            Operation &op) {
+  if (!isa<RemSIOp, RemUIOp>(op))
+    return failure();
+  Value lhs = op.getOperand(0);
+  Value rhs = op.getOperand(1);
+  auto rhsConstVal = rhs.getDefiningOp<arith::ConstantIntOp>();
+  if (!rhsConstVal)
+    return failure();
+  int64_t modulus = rhsConstVal.value();
+  if (modulus <= 0)
+    return failure();
+  auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
+  if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
+    return failure();
+  const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
+  const APInt &min = llvm::isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
+  const APInt &max = llvm::isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
+  // The minima and maxima here are given as closed ranges, so both must be
+  // strictly less than the modulus.
+  if (min.isNegative() || min.uge(modulus))
+    return failure();
+  if (max.isNegative() || max.uge(modulus))
+    return failure();
+  if (!min.ule(max))
+    return failure();
+
+  // With all those conditions out of the way, we know that this invocation of
+  // a remainder is a noop because the input is strictly within the range
+  // [0, modulus), so get rid of it.
+  rewriter.replaceOp(&op, ValueRange{lhs});
+  return success();
 }
 
-static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleUle(std::move(rhs), std::move(lhs));
+static void doRewrites(DataFlowSolver &solver, MLIRContext *context,
+                       MutableArrayRef<Region> initialRegions) {
+  SmallVector<Block *> worklist;
----------------
Hardcode84 wrote:

I think you can just use `region->walk([](Block*){})` instead of all the manual worklist management (`walk` allows the callback to delete operations), although it does use recursion internally.

https://github.com/llvm/llvm-project/pull/94712


More information about the Mlir-commits mailing list