[Mlir-commits] [mlir] 61f3775 - [mlir][SCF] Fix incorrect API usage in RewritePatterns
Matthias Springer
llvmlistbot at llvm.org
Mon Feb 27 00:43:20 PST 2023
Author: Matthias Springer
Date: 2023-02-27T09:36:14+01:00
New Revision: 61f37758048c14ed13b4545cbdf1f0d12496c237
URL: https://github.com/llvm/llvm-project/commit/61f37758048c14ed13b4545cbdf1f0d12496c237
DIFF: https://github.com/llvm/llvm-project/commit/61f37758048c14ed13b4545cbdf1f0d12496c237.diff
LOG: [mlir][SCF] Fix incorrect API usage in RewritePatterns
Incorrect API usage was detected by D144552.
Differential Revision: https://reviews.llvm.org/D144636
Added:
Modified:
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index fce6633bb7105..acc7e3ef8e1dc 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -546,6 +546,11 @@ class RewriterBase : public OpBuilder {
updateRootInPlace(op, [&]() { operand.set(to); });
}
}
+ void replaceAllUsesWith(ValueRange from, ValueRange to) {
+ assert(from.size() == to.size() && "incorrect number of replacements");
+ for (auto it : llvm::zip(from, to))
+ replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+ }
/// Find uses of `from` and replace them with `to` if the `functor` returns
/// true. It also marks every modified uses and notifies the rewriter that an
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index f8fd2016cd9a9..84caca9806e22 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1441,21 +1441,23 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
failed(foldDynamicIndexList(rewriter, mixedStep)))
return failure();
- SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
- SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
- dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
- staticLowerBound);
- op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
- op.setStaticLowerBound(staticLowerBound);
-
- dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
- staticUpperBound);
- op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
- op.setStaticUpperBound(staticUpperBound);
-
- dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
- op.getDynamicStepMutable().assign(dynamicStep);
- op.setStaticStep(staticStep);
+ rewriter.updateRootInPlace(op, [&]() {
+ SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
+ SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
+ dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
+ staticLowerBound);
+ op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
+ op.setStaticLowerBound(staticLowerBound);
+
+ dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
+ staticUpperBound);
+ op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
+ op.setStaticUpperBound(staticUpperBound);
+
+ dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
+ op.getDynamicStepMutable().assign(dynamicStep);
+ op.setStaticStep(staticStep);
+ });
return success();
}
};
@@ -3073,7 +3075,8 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
op.getLoc(), term.getCondition().getType(),
rewriter.getBoolAttr(true));
- std::get<1>(yieldedAndBlockArgs).replaceAllUsesWith(constantTrue);
+ rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
+ constantTrue);
replaced = true;
}
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index 5a9e0217edbda..8863b0833d3e7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -78,7 +78,8 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
// Rewrite uses of the for-loop block arguments to the new while-loop
// "after" arguments
for (const auto &barg : enumerate(forOp.getBody(0)->getArguments()))
- barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index()));
+ rewriter.replaceAllUsesWith(barg.value(),
+ afterBlock->getArgument(barg.index()));
// Inline for-loop body operations into 'after' region.
for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
@@ -88,7 +89,8 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
SmallVector<Value> yieldOperands = yieldOp.getOperands();
yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
- yieldOp->setOperands(yieldOperands);
+ rewriter.updateRootInPlace(
+ yieldOp, [&]() { yieldOp->setOperands(yieldOperands); });
}
// We cannot do a direct replacement of the forOp since the while op returns
@@ -96,7 +98,8 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
// carried in the set of iterargs). Instead, rewrite uses of the forOp
// results.
for (const auto &arg : llvm::enumerate(forOp.getResults()))
- arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1));
+ rewriter.replaceAllUsesWith(arg.value(),
+ whileOp.getResult(arg.index() + 1));
rewriter.eraseOp(forOp);
return success();
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 82c2223b6eec0..3ee440da66dcd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -144,7 +144,7 @@ static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
b.setInsertionPointAfter(forOp);
partialIteration = cast<ForOp>(b.clone(*forOp.getOperation()));
partialIteration.getLowerBoundMutable().assign(splitBound);
- forOp.replaceAllUsesWith(partialIteration->getResults());
+ b.replaceAllUsesWith(forOp.getResults(), partialIteration->getResults());
partialIteration.getInitArgsMutable().assign(forOp->getResults());
// Set new upper loop bound.
@@ -221,11 +221,13 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
if (failed(peelAndCanonicalizeForLoop(rewriter, forOp, partialIteration)))
return failure();
// Apply label, so that the same loop is not rewritten a second time.
- partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
+ rewriter.updateRootInPlace(partialIteration, [&]() {
+ partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
+ partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
+ });
rewriter.updateRootInPlace(forOp, [&]() {
forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
});
- partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
return success();
}
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index d3dfd16ba0442..a3ce8a63d4c9f 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -302,6 +302,7 @@ func.func @to_select_same_val(%cond: i1) -> (index, index) {
// CHECK: [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
// CHECK: return [[V0]], [[C1]] : index, index
+// -----
func.func @to_select_with_body(%cond: i1) -> index {
%c0 = arith.constant 0 : index
@@ -323,6 +324,7 @@ func.func @to_select_with_body(%cond: i1) -> index {
// CHECK: "test.op"() : () -> ()
// CHECK: }
// CHECK: return [[V0]] : index
+
// -----
func.func @to_select2(%cond: i1) -> (index, index) {
@@ -363,6 +365,10 @@ func.func @for_yields_2(%lb : index, %ub : index, %step : index) -> i32 {
// CHECK-NEXT: %[[R:.*]] = call @make_i32() : () -> i32
// CHECK-NEXT: return %[[R]] : i32
+// -----
+
+func.func private @make_i32() -> i32
+
func.func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) {
%a = call @make_i32() : () -> (i32)
%b = call @make_i32() : () -> (i32)
@@ -523,6 +529,8 @@ func.func @merge_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8)
return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
}
+// -----
+
// CHECK-LABEL: @merge_yielding_nested_if_nv1
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
func.func @merge_yielding_nested_if_nv1(%arg0: i1, %arg1: i1) {
@@ -547,6 +555,8 @@ func.func @merge_yielding_nested_if_nv1(%arg0: i1, %arg1: i1) {
return
}
+// -----
+
// CHECK-LABEL: @merge_yielding_nested_if_nv2
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
func.func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
@@ -571,6 +581,8 @@ func.func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
return %r : i32
}
+// -----
+
// CHECK-LABEL: @merge_fail_yielding_nested_if
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
func.func @merge_fail_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8) {
@@ -1125,6 +1137,8 @@ func.func @while_unused_result() -> i32 {
// CHECK-NEXT: }
// CHECK-NEXT: return %[[res]] : i32
+// -----
+
// CHECK-LABEL: @while_cmp_lhs
func.func @while_cmp_lhs(%arg0 : i32) {
%0 = scf.while () : () -> i32 {
@@ -1152,6 +1166,8 @@ func.func @while_cmp_lhs(%arg0 : i32) {
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
+// -----
+
// CHECK-LABEL: @while_cmp_rhs
func.func @while_cmp_rhs(%arg0 : i32) {
%0 = scf.while () : () -> i32 {
@@ -1210,6 +1226,7 @@ func.func @combineIfs(%arg0 : i1, %arg2: i64) -> (i32, i32) {
// CHECK-NEXT: }
// CHECK-NEXT: return %[[res]]#0, %[[res]]#1 : i32, i32
+// -----
// CHECK-LABEL: @combineIfs2
func.func @combineIfs2(%arg0 : i1, %arg2: i64) -> i32 {
@@ -1236,6 +1253,7 @@ func.func @combineIfs2(%arg0 : i1, %arg2: i64) -> i32 {
// CHECK-NEXT: }
// CHECK-NEXT: return %[[res]] : i32
+// -----
// CHECK-LABEL: @combineIfs3
func.func @combineIfs3(%arg0 : i1, %arg2: i64) -> i32 {
@@ -1262,6 +1280,8 @@ func.func @combineIfs3(%arg0 : i1, %arg2: i64) -> i32 {
// CHECK-NEXT: }
// CHECK-NEXT: return %[[res]] : i32
+// -----
+
// CHECK-LABEL: @combineIfs4
func.func @combineIfs4(%arg0 : i1, %arg2: i64) {
scf.if %arg0 {
@@ -1280,6 +1300,8 @@ func.func @combineIfs4(%arg0 : i1, %arg2: i64) {
// CHECK-NEXT: "test.secondCodeTrue"() : () -> ()
// CHECK-NEXT: }
+// -----
+
// CHECK-LABEL: @combineIfsUsed
// CHECK-SAME: %[[arg0:.+]]: i1
func.func @combineIfsUsed(%arg0 : i1, %arg2: i64) -> (i32, i32) {
@@ -1310,6 +1332,8 @@ func.func @combineIfsUsed(%arg0 : i1, %arg2: i64) -> (i32, i32) {
// CHECK-NEXT: }
// CHECK-NEXT: return %[[res]]#0, %[[res]]#1 : i32, i32
+// -----
+
// CHECK-LABEL: @combineIfsNot
// CHECK-SAME: %[[arg0:.+]]: i1
func.func @combineIfsNot(%arg0 : i1, %arg2: i64) {
@@ -1332,6 +1356,8 @@ func.func @combineIfsNot(%arg0 : i1, %arg2: i64) {
// CHECK-NEXT: "test.secondCodeTrue"() : () -> ()
// CHECK-NEXT: }
+// -----
+
// CHECK-LABEL: @combineIfsNot2
// CHECK-SAME: %[[arg0:.+]]: i1
func.func @combineIfsNot2(%arg0 : i1, %arg2: i64) {
@@ -1353,6 +1379,7 @@ func.func @combineIfsNot2(%arg0 : i1, %arg2: i64) {
// CHECK-NEXT: } else {
// CHECK-NEXT: "test.firstCodeTrue"() : () -> ()
// CHECK-NEXT: }
+
// -----
// CHECK-LABEL: func @propagate_into_execute_region
@@ -1403,7 +1430,6 @@ func.func @execute_region_elim() {
// CHECK-NEXT: "test.bar"(%[[VAL]]) : (i64) -> ()
// CHECK-NEXT: }
-
// -----
// CHECK-LABEL: func @func_execute_region_elim
@@ -1439,7 +1465,6 @@ func.func @func_execute_region_elim() {
// CHECK: "test.bar"(%[[z]])
// CHECK: return
-
// -----
// CHECK-LABEL: func @func_execute_region_elim_multi_yield
More information about the Mlir-commits mailing list