[Mlir-commits] [mlir] [mlir][Arith] Generalize and improve -int-range-optimizations (PR #94712)
Ivan Butygin
llvmlistbot at llvm.org
Fri Jun 7 14:25:13 PDT 2024
================
@@ -120,52 +85,92 @@ class DataFlowListener : public RewriterBase::Listener {
DataFlowSolver &s;
};
-struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
+/// Rewrite any results of `op` that were inferred to be constant integers into
+/// constants and replace their uses with that constant. Return success() if all
+/// results were thus replaced and the operation is erased. Also replace any block
+/// arguments with their constant values.
+struct MaterializeKnownConstantValues : public RewritePattern {
+ MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
+ : RewritePattern(Pattern::MatchAnyOpTypeTag(), 1, context), solver(s) {}
- ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
- : OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
-
- LogicalResult matchAndRewrite(arith::CmpIOp op,
- PatternRewriter &rewriter) const override {
- auto *lhsResult =
- solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
- if (!lhsResult || lhsResult->getValue().isUninitialized())
+ LogicalResult match(Operation *op) const override {
+ if (matchPattern(op, m_Constant()))
return failure();
- auto *rhsResult =
- solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
- if (!rhsResult || rhsResult->getValue().isUninitialized())
- return failure();
+ auto needsReplacing = [&](Value v) {
+ return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
+ };
+ bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
+ if (op->getNumRegions() == 0)
+ return success(hasConstantResults);
+ bool hasConstantRegionArgs = false;
+ for (Region &region : op->getRegions()) {
+ for (Block &block : region.getBlocks()) {
+ hasConstantRegionArgs |=
+ llvm::any_of(block.getArguments(), needsReplacing);
+ }
+ }
+ return success(hasConstantResults || hasConstantRegionArgs);
+ }
- using HandlerFunc =
- FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
- std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
- handlers{};
- using Pred = arith::CmpIPredicate;
- handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
- handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
- handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
- handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
- handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
- handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
- handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
- handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
- handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
- handlers[static_cast<size_t>(Pred::uge)] = &handleUge;
-
- HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
- if (!handler)
- return failure();
+ void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+ bool replacedAll = (op->getNumResults() != 0);
+ for (Value v : op->getResults())
+ replacedAll &= succeeded(maybeReplaceWithConstant(solver, rewriter, v));
+ if (replacedAll && isOpTriviallyDead(op)) {
+ rewriter.eraseOp(op);
+ return;
+ }
+
+ for (Region &region : op->getRegions()) {
+ for (Block &block : region.getBlocks()) {
+ for (BlockArgument &arg : block.getArguments()) {
+ PatternRewriter::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(&block);
+ (void)maybeReplaceWithConstant(solver, rewriter, arg);
+ }
+ }
+ }
+ }
+
+private:
+ DataFlowSolver &solver;
+};
- ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
- ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
- FailureOr<bool> result = handler(lhsValue, rhsValue);
+template <typename RemOp>
+struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
+ DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
+ : OpRewritePattern<RemOp>(context), solver(s) {}
- if (failed(result))
+ LogicalResult matchAndRewrite(RemOp op,
+ PatternRewriter &rewriter) const override {
+ Value lhs = op.getOperand(0);
+ Value rhs = op.getOperand(1);
+ auto rhsConstVal = rhs.getDefiningOp<arith::ConstantIntOp>();
----------------
Hardcode84 wrote:
`getConstantIntValue`?
https://github.com/llvm/llvm-project/pull/94712
More information about the Mlir-commits
mailing list