[Mlir-commits] [mlir] [mlir][Arith] Generalize and improve -int-range-optimizations (PR #94712)

Ivan Butygin llvmlistbot at llvm.org
Fri Jun 7 14:25:13 PDT 2024


================
@@ -120,52 +85,92 @@ class DataFlowListener : public RewriterBase::Listener {
   DataFlowSolver &s;
 };
 
-struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
+/// Rewrite any results of `op` that were inferred to be constant integers
+/// and replace their uses with that constant. The pattern matches when at
+/// least one used result or region block argument is inferred constant. If
+/// all results were thus replaced and `op` is trivially dead, `op` is erased;
+/// otherwise, any region block arguments are also replaced with their
+/// constant values.
+struct MaterializeKnownConstantValues : public RewritePattern {
+  MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
+      : RewritePattern(Pattern::MatchAnyOpTypeTag(), 1, context), solver(s) {}
 
-  ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
-      : OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
-
-  LogicalResult matchAndRewrite(arith::CmpIOp op,
-                                PatternRewriter &rewriter) const override {
-    auto *lhsResult =
-        solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
-    if (!lhsResult || lhsResult->getValue().isUninitialized())
+  LogicalResult match(Operation *op) const override {
+    if (matchPattern(op, m_Constant()))
       return failure();
 
-    auto *rhsResult =
-        solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
-    if (!rhsResult || rhsResult->getValue().isUninitialized())
-      return failure();
+    auto needsReplacing = [&](Value v) {
+      return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
+    };
+    bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
+    if (op->getNumRegions() == 0)
+      return success(hasConstantResults);
+    bool hasConstantRegionArgs = false;
+    for (Region &region : op->getRegions()) {
+      for (Block &block : region.getBlocks()) {
+        hasConstantRegionArgs |=
+            llvm::any_of(block.getArguments(), needsReplacing);
+      }
+    }
+    return success(hasConstantResults || hasConstantRegionArgs);
+  }
 
-    using HandlerFunc =
-        FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
-    std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
-        handlers{};
-    using Pred = arith::CmpIPredicate;
-    handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
-    handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
-    handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
-    handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
-    handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
-    handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
-    handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
-    handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
-    handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
-    handlers[static_cast<size_t>(Pred::uge)] = &handleUge;
-
-    HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
-    if (!handler)
-      return failure();
+  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+    bool replacedAll = (op->getNumResults() != 0);
+    for (Value v : op->getResults())
+      replacedAll &= succeeded(maybeReplaceWithConstant(solver, rewriter, v));
+    if (replacedAll && isOpTriviallyDead(op)) {
+      rewriter.eraseOp(op);
+      return;
+    }
+
+    for (Region &region : op->getRegions()) {
+      for (Block &block : region.getBlocks()) {
+        for (BlockArgument &arg : block.getArguments()) {
+          PatternRewriter::InsertionGuard guard(rewriter);
+          rewriter.setInsertionPointToStart(&block);
+          (void)maybeReplaceWithConstant(solver, rewriter, arg);
+        }
+      }
+    }
+  }
+
+private:
+  DataFlowSolver &solver;
+};
 
-    ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
-    ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
-    FailureOr<bool> result = handler(lhsValue, rhsValue);
+template <typename RemOp>
+struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
+  DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
+      : OpRewritePattern<RemOp>(context), solver(s) {}
 
-    if (failed(result))
+  LogicalResult matchAndRewrite(RemOp op,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = op.getOperand(0);
+    Value rhs = op.getOperand(1);
+    auto rhsConstVal = rhs.getDefiningOp<arith::ConstantIntOp>();
----------------
Hardcode84 wrote:

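`getConstantIntValue`? Unlike `getDefiningOp<arith::ConstantIntOp>()`, which only matches signless-integer constants, it also handles `index`-typed constants (e.g. `arith.constant 4 : index`) and any other constant-like op that folds to an integer attribute.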

https://github.com/llvm/llvm-project/pull/94712


More information about the Mlir-commits mailing list