[Mlir-commits] [mlir] 6005a1d - [mlir][scf] Match any constants instead of arith.constant
Jeff Niu
llvmlistbot at llvm.org
Wed Oct 12 18:02:05 PDT 2022
Author: Jeff Niu
Date: 2022-10-12T18:01:57-07:00
New Revision: 6005a1d8af32608a0f2ddddb0eaddbe39ec3271b
URL: https://github.com/llvm/llvm-project/commit/6005a1d8af32608a0f2ddddb0eaddbe39ec3271b
DIFF: https://github.com/llvm/llvm-project/commit/6005a1d8af32608a0f2ddddb0eaddbe39ec3271b.diff
LOG: [mlir][scf] Match any constants instead of arith.constant
By matching `arith.constant` specifically, SCF canonicalizers/folders
are incompatible with other kinds of constants. Use the generic
matchers instead.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D135517
Added:
Modified:
mlir/lib/Dialect/SCF/IR/SCF.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 071814fcbed85..b3e84f0d6271a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -299,9 +299,9 @@ void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
}
LogicalResult ForOp::verify() {
- if (auto cst = getStep().getDefiningOp<arith::ConstantIndexOp>())
- if (cst.value() <= 0)
- return emitOpError("constant step operand must be positive");
+ IntegerAttr step;
+ if (matchPattern(getStep(), m_Constant(&step)) && step.getInt() <= 0)
+ return emitOpError("constant step operand must be positive");
auto opNumResults = getNumResults();
if (opNumResults == 0)
@@ -719,11 +719,10 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
/// Returns llvm::None when the difference between two AffineValueMap is
/// dynamic.
static Optional<int64_t> computeConstDiff(Value l, Value u) {
- auto clb = l.getDefiningOp<arith::ConstantOp>();
- auto cub = u.getDefiningOp<arith::ConstantOp>();
- if (cub && clb) {
- llvm::APInt lbValue = clb.getValue().cast<IntegerAttr>().getValue();
- llvm::APInt ubValue = cub.getValue().cast<IntegerAttr>().getValue();
+ IntegerAttr clb, cub;
+ if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
+ llvm::APInt lbValue = clb.getValue();
+ llvm::APInt ubValue = cub.getValue();
return (ubValue - lbValue).getSExtValue();
}
@@ -763,13 +762,13 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
return success();
}
- auto step = op.getStep().getDefiningOp<arith::ConstantOp>();
- if (!step)
+ IntegerAttr step;
+ if (!matchPattern(op.getStep(), m_Constant(&step)))
return failure();
// If the loop is known to have 1 iteration, inline its body and remove the
// loop.
- llvm::APInt stepValue = step.getValue().cast<IntegerAttr>().getValue();
+ llvm::APInt stepValue = step.getValue();
if (stepValue.sge(*
diff )) {
SmallVector<Value, 4> blockArgs;
blockArgs.reserve(op.getNumIterOperands() + 1);
@@ -1674,11 +1673,11 @@ struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
LogicalResult matchAndRewrite(IfOp op,
PatternRewriter &rewriter) const override {
- auto constant = op.getCondition().getDefiningOp<arith::ConstantOp>();
- if (!constant)
+ BoolAttr condition;
+ if (!matchPattern(op.getCondition(), m_Constant(&condition)))
return failure();
- if (constant.getValue().cast<BoolAttr>().getValue())
+ if (condition.getValue())
replaceOpWithRegion(rewriter, op, op.getThenRegion());
else if (!op.getElseRegion().empty())
replaceOpWithRegion(rewriter, op, op.getElseRegion());
@@ -1777,7 +1776,7 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
PatternRewriter &rewriter) const override {
// Early exit if the condition is constant since replacing a constant
// in the body with another constant isn't a simplification.
- if (op.getCondition().getDefiningOp<arith::ConstantOp>())
+ if (matchPattern(op.getCondition(), m_Constant()))
return failure();
bool changed = false;
@@ -1881,25 +1880,23 @@ struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
continue;
}
- auto trueYield = trueResult.getDefiningOp<arith::ConstantOp>();
- if (!trueYield)
+ BoolAttr trueYield, falseYield;
+ if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
+ !matchPattern(falseResult, m_Constant(&falseYield)))
continue;
- if (!trueYield.getType().isInteger(1))
- continue;
-
- auto falseYield = falseResult.getDefiningOp<arith::ConstantOp>();
- if (!falseYield)
- continue;
-
- bool trueVal = trueYield.getValue().cast<BoolAttr>().getValue();
- bool falseVal = falseYield.getValue().cast<BoolAttr>().getValue();
+ bool trueVal = trueYield.getValue();
+ bool falseVal = falseYield.getValue();
if (!trueVal && falseVal) {
if (!opResult.use_empty()) {
+ Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
Value notCond = rewriter.create<arith::XOrIOp>(
op.getLoc(), op.getCondition(),
- rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)));
+ constDialect
+ ->materializeConstant(rewriter,
+ rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
+ op.getLoc())
+ ->getResult(0));
opResult.replaceAllUsesWith(notCond);
changed = true;
}
More information about the Mlir-commits
mailing list