[Mlir-commits] [mlir] 6005a1d - [mlir][scf] Match any constants instead of arith.constant

Jeff Niu llvmlistbot at llvm.org
Wed Oct 12 18:02:05 PDT 2022


Author: Jeff Niu
Date: 2022-10-12T18:01:57-07:00
New Revision: 6005a1d8af32608a0f2ddddb0eaddbe39ec3271b

URL: https://github.com/llvm/llvm-project/commit/6005a1d8af32608a0f2ddddb0eaddbe39ec3271b
DIFF: https://github.com/llvm/llvm-project/commit/6005a1d8af32608a0f2ddddb0eaddbe39ec3271b.diff

LOG: [mlir][scf] Match any constants instead of arith.constant

Because they matched `arith.constant` specifically, the SCF canonicalizers and
folders did not fire on constants produced by other constant-like ops. Use the
generic constant matchers (`m_Constant`) instead.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D135517

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/IR/SCF.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 071814fcbed85..b3e84f0d6271a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -299,9 +299,9 @@ void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
 }
 
 LogicalResult ForOp::verify() {
-  if (auto cst = getStep().getDefiningOp<arith::ConstantIndexOp>())
-    if (cst.value() <= 0)
-      return emitOpError("constant step operand must be positive");
+  IntegerAttr step;
+  if (matchPattern(getStep(), m_Constant(&step)) && step.getInt() <= 0)
+    return emitOpError("constant step operand must be positive");
 
   auto opNumResults = getNumResults();
   if (opNumResults == 0)
@@ -719,11 +719,10 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
 /// Returns llvm::None when the difference between two AffineValueMap is
 /// dynamic.
 static Optional<int64_t> computeConstDiff(Value l, Value u) {
-  auto clb = l.getDefiningOp<arith::ConstantOp>();
-  auto cub = u.getDefiningOp<arith::ConstantOp>();
-  if (cub && clb) {
-    llvm::APInt lbValue = clb.getValue().cast<IntegerAttr>().getValue();
-    llvm::APInt ubValue = cub.getValue().cast<IntegerAttr>().getValue();
+  IntegerAttr clb, cub;
+  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
+    llvm::APInt lbValue = clb.getValue();
+    llvm::APInt ubValue = cub.getValue();
     return (ubValue - lbValue).getSExtValue();
   }
 
@@ -763,13 +762,13 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
       return success();
     }
 
-    auto step = op.getStep().getDefiningOp<arith::ConstantOp>();
-    if (!step)
+    IntegerAttr step;
+    if (!matchPattern(op.getStep(), m_Constant(&step)))
       return failure();
 
     // If the loop is known to have 1 iteration, inline its body and remove the
     // loop.
-    llvm::APInt stepValue = step.getValue().cast<IntegerAttr>().getValue();
+    llvm::APInt stepValue = step.getValue();
     if (stepValue.sge(*
diff )) {
       SmallVector<Value, 4> blockArgs;
       blockArgs.reserve(op.getNumIterOperands() + 1);
@@ -1674,11 +1673,11 @@ struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
 
   LogicalResult matchAndRewrite(IfOp op,
                                 PatternRewriter &rewriter) const override {
-    auto constant = op.getCondition().getDefiningOp<arith::ConstantOp>();
-    if (!constant)
+    BoolAttr condition;
+    if (!matchPattern(op.getCondition(), m_Constant(&condition)))
       return failure();
 
-    if (constant.getValue().cast<BoolAttr>().getValue())
+    if (condition.getValue())
       replaceOpWithRegion(rewriter, op, op.getThenRegion());
     else if (!op.getElseRegion().empty())
       replaceOpWithRegion(rewriter, op, op.getElseRegion());
@@ -1777,7 +1776,7 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
                                 PatternRewriter &rewriter) const override {
     // Early exit if the condition is constant since replacing a constant
     // in the body with another constant isn't a simplification.
-    if (op.getCondition().getDefiningOp<arith::ConstantOp>())
+    if (matchPattern(op.getCondition(), m_Constant()))
       return failure();
 
     bool changed = false;
@@ -1881,25 +1880,23 @@ struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
         continue;
       }
 
-      auto trueYield = trueResult.getDefiningOp<arith::ConstantOp>();
-      if (!trueYield)
+      BoolAttr trueYield, falseYield;
+      if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
+          !matchPattern(falseResult, m_Constant(&falseYield)))
         continue;
 
-      if (!trueYield.getType().isInteger(1))
-        continue;
-
-      auto falseYield = falseResult.getDefiningOp<arith::ConstantOp>();
-      if (!falseYield)
-        continue;
-
-      bool trueVal = trueYield.getValue().cast<BoolAttr>().getValue();
-      bool falseVal = falseYield.getValue().cast<BoolAttr>().getValue();
+      bool trueVal = trueYield.getValue();
+      bool falseVal = falseYield.getValue();
       if (!trueVal && falseVal) {
         if (!opResult.use_empty()) {
+          Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
           Value notCond = rewriter.create<arith::XOrIOp>(
               op.getLoc(), op.getCondition(),
-              rewriter.create<arith::ConstantOp>(
-                  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)));
+              constDialect
+                  ->materializeConstant(rewriter,
+                                        rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
+                                        op.getLoc())
+                  ->getResult(0));
           opResult.replaceAllUsesWith(notCond);
           changed = true;
         }


        


More information about the Mlir-commits mailing list