[Mlir-commits] [mlir] ae9e1d1 - [mlir][SparseTensor] Fix incorrect API usage in RewritePatterns

Matthias Springer llvmlistbot at llvm.org
Thu Mar 2 09:04:31 PST 2023


Author: Matthias Springer
Date: 2023-03-02T17:59:57+01:00
New Revision: ae9e1d1df46a50a6748514ee1d7d85e7fa81890d

URL: https://github.com/llvm/llvm-project/commit/ae9e1d1df46a50a6748514ee1d7d85e7fa81890d
DIFF: https://github.com/llvm/llvm-project/commit/ae9e1d1df46a50a6748514ee1d7d85e7fa81890d.diff

LOG: [mlir][SparseTensor] Fix incorrect API usage in RewritePatterns

Incorrect API usage (IR modified directly inside rewrite patterns instead of through the rewriter) was detected by D144552.

Differential Revision: https://reviews.llvm.org/D145166

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index b128669a707f8..0663bd927fe5e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -444,7 +444,7 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
       auto denseTp =
           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
       auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
-      op->setOperand(0, convert);
+      rewriter.updateRootInPlace(op, [&]() { op->setOperand(0, convert); });
       return success();
     }
     if (encDst) {

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index bc05137bcac47..1772eef57bdcc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -546,7 +546,7 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
           forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
       rewriter.setInsertionPointToStart(forOpNew.getBody());
     } else {
-      forOp.setStep(step);
+      rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); });
       rewriter.setInsertionPoint(yield);
     }
     vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
@@ -575,10 +575,11 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
         // Now do some relinking (last one is not completely type safe
         // but all bad ones are removed right away). This also folds away
         // nop broadcast operations.
-        forOp.getResult(0).replaceAllUsesWith(vres);
-        forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
-        forOp.getRegionIterArg(0).replaceAllUsesWith(
-            forOpNew.getRegionIterArg(0));
+        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
+        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
+                                    forOpNew.getInductionVar());
+        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
+                                    forOpNew.getRegionIterArg(0));
         rewriter.eraseOp(forOp);
       }
       return true;

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index d8fcac07466d8..575b50577e413 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -838,9 +838,12 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
       return genIndexValue(env, indexOp.getDim());
     if (def->getBlock() == block) {
-      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
-        def->setOperand(
-            i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
+        rewriter.updateRootInPlace(def, [&]() {
+          def->setOperand(
+              i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+        });
+      }
     }
   }
   return e;
@@ -1615,7 +1618,8 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
       auto dstTp = RankedTensorType::get(srcTp.getShape(),
                                          srcTp.getElementType(), dstEnc);
       auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
-      env.op()->setOperand(tensor, convert);
+      rewriter.updateRootInPlace(
+          env.op(), [&]() { env.op()->setOperand(tensor, convert); });
       rewriter.setInsertionPointAfter(env.op());
       rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert);
       return success();


        


More information about the Mlir-commits mailing list