[Mlir-commits] [mlir] ae9e1d1 - [mlir][SparseTensor] Fix incorrect API usage in RewritePatterns
Matthias Springer
llvmlistbot at llvm.org
Thu Mar 2 09:04:31 PST 2023
Author: Matthias Springer
Date: 2023-03-02T17:59:57+01:00
New Revision: ae9e1d1df46a50a6748514ee1d7d85e7fa81890d
URL: https://github.com/llvm/llvm-project/commit/ae9e1d1df46a50a6748514ee1d7d85e7fa81890d
DIFF: https://github.com/llvm/llvm-project/commit/ae9e1d1df46a50a6748514ee1d7d85e7fa81890d.diff
LOG: [mlir][SparseTensor] Fix incorrect API usage in RewritePatterns
Incorrect API usage was detected by D144552.
Differential Revision: https://reviews.llvm.org/D145166
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index b128669a707f8..0663bd927fe5e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -444,7 +444,7 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
- op->setOperand(0, convert);
+ rewriter.updateRootInPlace(op, [&]() { op->setOperand(0, convert); });
return success();
}
if (encDst) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index bc05137bcac47..1772eef57bdcc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -546,7 +546,7 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
rewriter.setInsertionPointToStart(forOpNew.getBody());
} else {
- forOp.setStep(step);
+ rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); });
rewriter.setInsertionPoint(yield);
}
vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
@@ -575,10 +575,11 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
// Now do some relinking (last one is not completely type safe
// but all bad ones are removed right away). This also folds away
// nop broadcast operations.
- forOp.getResult(0).replaceAllUsesWith(vres);
- forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
- forOp.getRegionIterArg(0).replaceAllUsesWith(
- forOpNew.getRegionIterArg(0));
+ rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
+ rewriter.replaceAllUsesWith(forOp.getInductionVar(),
+ forOpNew.getInductionVar());
+ rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
+ forOpNew.getRegionIterArg(0));
rewriter.eraseOp(forOp);
}
return true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index d8fcac07466d8..575b50577e413 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -838,9 +838,12 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
return genIndexValue(env, indexOp.getDim());
if (def->getBlock() == block) {
- for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
- def->setOperand(
- i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+ for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
+ rewriter.updateRootInPlace(def, [&]() {
+ def->setOperand(
+ i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+ });
+ }
}
}
return e;
@@ -1615,7 +1618,8 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
auto dstTp = RankedTensorType::get(srcTp.getShape(),
srcTp.getElementType(), dstEnc);
auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
- env.op()->setOperand(tensor, convert);
+ rewriter.updateRootInPlace(
+ env.op(), [&]() { env.op()->setOperand(tensor, convert); });
rewriter.setInsertionPointAfter(env.op());
rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert);
return success();
More information about the Mlir-commits mailing list