[Mlir-commits] [mlir] 61f3775 - [mlir][SCF] Fix incorrect API usage in RewritePatterns

Matthias Springer llvmlistbot at llvm.org
Mon Feb 27 00:43:20 PST 2023


Author: Matthias Springer
Date: 2023-02-27T09:36:14+01:00
New Revision: 61f37758048c14ed13b4545cbdf1f0d12496c237

URL: https://github.com/llvm/llvm-project/commit/61f37758048c14ed13b4545cbdf1f0d12496c237
DIFF: https://github.com/llvm/llvm-project/commit/61f37758048c14ed13b4545cbdf1f0d12496c237.diff

LOG: [mlir][SCF] Fix incorrect API usage in RewritePatterns

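Incorrect API usage was detected by D144552. Several SCF rewrite patterns
modified IR directly (e.g. `Value::replaceAllUsesWith`, `setOperands`,
`setAttr`, and mutable-operand `assign`) instead of going through the
rewriter. These modifications now use `RewriterBase::replaceAllUsesWith`
and `updateRootInPlace`, and a `ValueRange` overload of
`RewriterBase::replaceAllUsesWith` is added for convenience. The SCF
canonicalization test file is also split into more `// -----` sections,
with a `@make_i32` declaration added where a new section needs it.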

Differential Revision: https://reviews.llvm.org/D144636

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
    mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index fce6633bb7105..acc7e3ef8e1dc 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -546,6 +546,11 @@ class RewriterBase : public OpBuilder {
       updateRootInPlace(op, [&]() { operand.set(to); });
     }
   }
+  void replaceAllUsesWith(ValueRange from, ValueRange to) {
+    assert(from.size() == to.size() && "incorrect number of replacements");
+    for (auto it : llvm::zip(from, to))
+      replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+  }
 
   /// Find uses of `from` and replace them with `to` if the `functor` returns
   /// true. It also marks every modified uses and notifies the rewriter that an

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index f8fd2016cd9a9..84caca9806e22 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1441,21 +1441,23 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
         failed(foldDynamicIndexList(rewriter, mixedStep)))
       return failure();
 
-    SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
-    SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
-    dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
-                               staticLowerBound);
-    op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
-    op.setStaticLowerBound(staticLowerBound);
-
-    dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
-                               staticUpperBound);
-    op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
-    op.setStaticUpperBound(staticUpperBound);
-
-    dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
-    op.getDynamicStepMutable().assign(dynamicStep);
-    op.setStaticStep(staticStep);
+    rewriter.updateRootInPlace(op, [&]() {
+      SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
+      SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
+      dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
+                                 staticLowerBound);
+      op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
+      op.setStaticLowerBound(staticLowerBound);
+
+      dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
+                                 staticUpperBound);
+      op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
+      op.setStaticUpperBound(staticUpperBound);
+
+      dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
+      op.getDynamicStepMutable().assign(dynamicStep);
+      op.setStaticStep(staticStep);
+    });
     return success();
   }
 };
@@ -3073,7 +3075,8 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
                 op.getLoc(), term.getCondition().getType(),
                 rewriter.getBoolAttr(true));
 
-          std::get<1>(yieldedAndBlockArgs).replaceAllUsesWith(constantTrue);
+          rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
+                                      constantTrue);
           replaced = true;
         }
       }

diff  --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index 5a9e0217edbda..8863b0833d3e7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -78,7 +78,8 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
     // Rewrite uses of the for-loop block arguments to the new while-loop
     // "after" arguments
     for (const auto &barg : enumerate(forOp.getBody(0)->getArguments()))
-      barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index()));
+      rewriter.replaceAllUsesWith(barg.value(),
+                                  afterBlock->getArgument(barg.index()));
 
     // Inline for-loop body operations into 'after' region.
     for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
@@ -88,7 +89,8 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
     for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
       SmallVector<Value> yieldOperands = yieldOp.getOperands();
       yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
-      yieldOp->setOperands(yieldOperands);
+      rewriter.updateRootInPlace(
+          yieldOp, [&]() { yieldOp->setOperands(yieldOperands); });
     }
 
     // We cannot do a direct replacement of the forOp since the while op returns
@@ -96,7 +98,8 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
     // carried in the set of iterargs). Instead, rewrite uses of the forOp
     // results.
     for (const auto &arg : llvm::enumerate(forOp.getResults()))
-      arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1));
+      rewriter.replaceAllUsesWith(arg.value(),
+                                  whileOp.getResult(arg.index() + 1));
 
     rewriter.eraseOp(forOp);
     return success();

diff  --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 82c2223b6eec0..3ee440da66dcd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -144,7 +144,7 @@ static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
   b.setInsertionPointAfter(forOp);
   partialIteration = cast<ForOp>(b.clone(*forOp.getOperation()));
   partialIteration.getLowerBoundMutable().assign(splitBound);
-  forOp.replaceAllUsesWith(partialIteration->getResults());
+  b.replaceAllUsesWith(forOp.getResults(), partialIteration->getResults());
   partialIteration.getInitArgsMutable().assign(forOp->getResults());
 
   // Set new upper loop bound.
@@ -221,11 +221,13 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
     if (failed(peelAndCanonicalizeForLoop(rewriter, forOp, partialIteration)))
       return failure();
     // Apply label, so that the same loop is not rewritten a second time.
-    partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
+    rewriter.updateRootInPlace(partialIteration, [&]() {
+      partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
+      partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
+    });
     rewriter.updateRootInPlace(forOp, [&]() {
       forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
     });
-    partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
     return success();
   }
 

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index d3dfd16ba0442..a3ce8a63d4c9f 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -302,6 +302,7 @@ func.func @to_select_same_val(%cond: i1) -> (index, index) {
 // CHECK:           [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
 // CHECK:           return [[V0]], [[C1]] : index, index
 
+// -----
 
 func.func @to_select_with_body(%cond: i1) -> index {
   %c0 = arith.constant 0 : index
@@ -323,6 +324,7 @@ func.func @to_select_with_body(%cond: i1) -> index {
 // CHECK:             "test.op"() : () -> ()
 // CHECK:           }
 // CHECK:           return [[V0]] : index
+
 // -----
 
 func.func @to_select2(%cond: i1) -> (index, index) {
@@ -363,6 +365,10 @@ func.func @for_yields_2(%lb : index, %ub : index, %step : index) -> i32 {
 //  CHECK-NEXT:     %[[R:.*]] = call @make_i32() : () -> i32
 //  CHECK-NEXT:     return %[[R]] : i32
 
+// -----
+
+func.func private @make_i32() -> i32
+
 func.func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) {
   %a = call @make_i32() : () -> (i32)
   %b = call @make_i32() : () -> (i32)
@@ -523,6 +529,8 @@ func.func @merge_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8)
   return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
 }
 
+// -----
+
 // CHECK-LABEL: @merge_yielding_nested_if_nv1
 // CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
 func.func @merge_yielding_nested_if_nv1(%arg0: i1, %arg1: i1) {
@@ -547,6 +555,8 @@ func.func @merge_yielding_nested_if_nv1(%arg0: i1, %arg1: i1) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @merge_yielding_nested_if_nv2
 // CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
 func.func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
@@ -571,6 +581,8 @@ func.func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
   return %r : i32
 }
 
+// -----
+
 // CHECK-LABEL: @merge_fail_yielding_nested_if
 // CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
 func.func @merge_fail_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8) {
@@ -1125,6 +1137,8 @@ func.func @while_unused_result() -> i32 {
 // CHECK-NEXT:         }
 // CHECK-NEXT:         return %[[res]] : i32
 
+// -----
+
 // CHECK-LABEL: @while_cmp_lhs
 func.func @while_cmp_lhs(%arg0 : i32) {
   %0 = scf.while () : () -> i32 {
@@ -1152,6 +1166,8 @@ func.func @while_cmp_lhs(%arg0 : i32) {
 // CHECK-NEXT:           scf.yield
 // CHECK-NEXT:         }
 
+// -----
+
 // CHECK-LABEL: @while_cmp_rhs
 func.func @while_cmp_rhs(%arg0 : i32) {
   %0 = scf.while () : () -> i32 {
@@ -1210,6 +1226,7 @@ func.func @combineIfs(%arg0 : i1, %arg2: i64) -> (i32, i32) {
 // CHECK-NEXT:     }
 // CHECK-NEXT:     return %[[res]]#0, %[[res]]#1 : i32, i32
 
+// -----
 
 // CHECK-LABEL: @combineIfs2
 func.func @combineIfs2(%arg0 : i1, %arg2: i64) -> i32 {
@@ -1236,6 +1253,7 @@ func.func @combineIfs2(%arg0 : i1, %arg2: i64) -> i32 {
 // CHECK-NEXT:     }
 // CHECK-NEXT:     return %[[res]] : i32
 
+// -----
 
 // CHECK-LABEL: @combineIfs3
 func.func @combineIfs3(%arg0 : i1, %arg2: i64) -> i32 {
@@ -1262,6 +1280,8 @@ func.func @combineIfs3(%arg0 : i1, %arg2: i64) -> i32 {
 // CHECK-NEXT:     }
 // CHECK-NEXT:     return %[[res]] : i32
 
+// -----
+
 // CHECK-LABEL: @combineIfs4
 func.func @combineIfs4(%arg0 : i1, %arg2: i64) {
   scf.if %arg0 {
@@ -1280,6 +1300,8 @@ func.func @combineIfs4(%arg0 : i1, %arg2: i64) {
 // CHECK-NEXT:       "test.secondCodeTrue"() : () -> ()
 // CHECK-NEXT:     }
 
+// -----
+
 // CHECK-LABEL: @combineIfsUsed
 // CHECK-SAME: %[[arg0:.+]]: i1
 func.func @combineIfsUsed(%arg0 : i1, %arg2: i64) -> (i32, i32) {
@@ -1310,6 +1332,8 @@ func.func @combineIfsUsed(%arg0 : i1, %arg2: i64) -> (i32, i32) {
 // CHECK-NEXT:     }
 // CHECK-NEXT:     return %[[res]]#0, %[[res]]#1 : i32, i32
 
+// -----
+
 // CHECK-LABEL: @combineIfsNot
 // CHECK-SAME: %[[arg0:.+]]: i1
 func.func @combineIfsNot(%arg0 : i1, %arg2: i64) {
@@ -1332,6 +1356,8 @@ func.func @combineIfsNot(%arg0 : i1, %arg2: i64) {
 // CHECK-NEXT:       "test.secondCodeTrue"() : () -> ()
 // CHECK-NEXT:     }
 
+// -----
+
 // CHECK-LABEL: @combineIfsNot2
 // CHECK-SAME: %[[arg0:.+]]: i1
 func.func @combineIfsNot2(%arg0 : i1, %arg2: i64) {
@@ -1353,6 +1379,7 @@ func.func @combineIfsNot2(%arg0 : i1, %arg2: i64) {
 // CHECK-NEXT:     } else {
 // CHECK-NEXT:       "test.firstCodeTrue"() : () -> ()
 // CHECK-NEXT:     }
+
 // -----
 
 // CHECK-LABEL: func @propagate_into_execute_region
@@ -1403,7 +1430,6 @@ func.func @execute_region_elim() {
 // CHECK-NEXT:       "test.bar"(%[[VAL]]) : (i64) -> ()
 // CHECK-NEXT:     }
 
-
 // -----
 
 // CHECK-LABEL: func @func_execute_region_elim
@@ -1439,7 +1465,6 @@ func.func @func_execute_region_elim() {
 // CHECK:     "test.bar"(%[[z]])
 // CHECK:     return
 
-
 // -----
 
 // CHECK-LABEL: func @func_execute_region_elim_multi_yield


        


More information about the Mlir-commits mailing list