[Mlir-commits] [mlir] 7ecc921 - [mlir][vector] Fix incorrect API usage in RewritePatterns
Matthias Springer
llvmlistbot at llvm.org
Thu Mar 2 05:04:58 PST 2023
Author: Matthias Springer
Date: 2023-03-02T13:58:37+01:00
New Revision: 7ecc921deb1551cad4920a773e519074f974d593
URL: https://github.com/llvm/llvm-project/commit/7ecc921deb1551cad4920a773e519074f974d593
DIFF: https://github.com/llvm/llvm-project/commit/7ecc921deb1551cad4920a773e519074f974d593.diff
LOG: [mlir][vector] Fix incorrect API usage in RewritePatterns
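Several vector rewrite patterns mutated IR directly (`Value::replaceAllUsesWith`, in-place operand assignment) instead of going through the `PatternRewriter` API (`rewriter.replaceAllUsesWith`, `rewriter.updateRootInPlace`). This incorrect API usage was detected by D144552.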
Differential Revision: https://reviews.llvm.org/D145153
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2a81cf673dc40..a6e0b4ba52055 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4201,7 +4201,9 @@ class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
while (defWrite) {
if (checkSameValueWAW(writeOp, defWrite)) {
- writeToModify.getSourceMutable().assign(defWrite.getSource());
+ rewriter.updateRootInPlace(writeToModify, [&]() {
+ writeToModify.getSourceMutable().assign(defWrite.getSource());
+ });
return success();
}
if (!isDisjointTransferIndices(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index fe3cad78e5422..22c8ea611b36f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -657,7 +657,8 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, elementWise, newOperands,
{newWarpOp.getResult(operandIndex).getType()});
- newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
+ newOp->getResult(0));
return success();
}
};
@@ -695,7 +696,7 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
Location loc = warpOp.getLoc();
rewriter.setInsertionPointAfter(warpOp);
Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
- warpOp.getResult(operandIndex).replaceAllUsesWith(distConstant);
+ rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
return success();
}
};
@@ -759,7 +760,7 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
read.getLoc(), distributedVal.getType(), read.getSource(), indices,
read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
read.getInBoundsAttr());
- distributedVal.replaceAllUsesWith(newRead);
+ rewriter.replaceAllUsesWith(distributedVal, newRead);
return success();
}
};
@@ -855,7 +856,7 @@ struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
if (!valForwarded)
return failure();
- warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
+ rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
return success();
}
};
@@ -880,7 +881,8 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
rewriter.setInsertionPointAfter(newWarpOp);
Value broadcasted = rewriter.create<vector::BroadcastOp>(
loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
- newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
+ rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+ broadcasted);
return success();
}
};
@@ -936,7 +938,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Extract from distributed vector.
Value newExtract = rewriter.create<vector::ExtractOp>(
loc, distributedVec, extractOp.getPosition());
- newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+ rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+ newExtract);
return success();
}
@@ -973,7 +976,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Extract from distributed vector.
Value newExtract = rewriter.create<vector::ExtractOp>(
loc, distributedVec, extractOp.getPosition());
- newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+ rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+ newExtract);
return success();
}
};
@@ -1031,7 +1035,8 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
newExtract =
rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
}
- newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+ rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+ newExtract);
return success();
}
@@ -1056,7 +1061,7 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Shuffle the extracted value to all lanes.
Value shuffled = warpShuffleFromIdxFn(
loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
- newWarpOp->getResult(operandNumber).replaceAllUsesWith(shuffled);
+ rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
return success();
}
@@ -1104,7 +1109,8 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Broadcast: Simply move the vector.inserelement op out.
Value newInsert = rewriter.create<vector::InsertElementOp>(
loc, newSource, distributedVec, newPos);
- newWarpOp->getResult(operandNumber).replaceAllUsesWith(newInsert);
+ rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+ newInsert);
return success();
}
@@ -1138,7 +1144,7 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
builder.create<scf::YieldOp>(loc, distributedVec);
})
.getResult(0);
- newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult);
+ rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
return success();
}
};
@@ -1184,7 +1190,8 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
Value newResult = rewriter.create<vector::InsertOp>(
loc, distributedSrc, distributedDest, insertOp.getPosition());
- newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult);
+ rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+ newResult);
return success();
}
@@ -1263,7 +1270,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
.getResult(0);
}
- newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult);
+ rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
return success();
}
};
@@ -1400,8 +1407,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
rewriter.eraseOp(forOp);
// Replace the warpOp result coming from the original ForOp.
for (const auto &res : llvm::enumerate(resultIdx)) {
- newWarpOp.getResult(res.value())
- .replaceAllUsesWith(newForOp.getResult(res.index()));
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
+ newForOp.getResult(res.index()));
newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
}
newForOp.walk([&](Operation *op) {
@@ -1494,7 +1501,7 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
newWarpOp.getResult(newRetIndices[1]));
}
- newWarpOp.getResult(operandIndex).replaceAllUsesWith(fullReduce);
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
return success();
}
More information about the Mlir-commits mailing list