[Mlir-commits] [mlir] 7ecc921 - [mlir][vector] Fix incorrect API usage in RewritePatterns

Matthias Springer llvmlistbot at llvm.org
Thu Mar 2 05:04:58 PST 2023


Author: Matthias Springer
Date: 2023-03-02T13:58:37+01:00
New Revision: 7ecc921deb1551cad4920a773e519074f974d593

URL: https://github.com/llvm/llvm-project/commit/7ecc921deb1551cad4920a773e519074f974d593
DIFF: https://github.com/llvm/llvm-project/commit/7ecc921deb1551cad4920a773e519074f974d593.diff

LOG: [mlir][vector] Fix incorrect API usage in RewritePatterns

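Incorrect API usage was detected by D144552: these patterns modified IR directly (e.g. `Value::replaceAllUsesWith` and in-place operand assignment) instead of going through the `PatternRewriter` (`rewriter.replaceAllUsesWith`, `rewriter.updateRootInPlace`). As a result, the rewriter was not notified of the changes.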

Differential Revision: https://reviews.llvm.org/D145153

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2a81cf673dc40..a6e0b4ba52055 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4201,7 +4201,9 @@ class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
         writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
     while (defWrite) {
       if (checkSameValueWAW(writeOp, defWrite)) {
-        writeToModify.getSourceMutable().assign(defWrite.getSource());
+        rewriter.updateRootInPlace(writeToModify, [&]() {
+          writeToModify.getSourceMutable().assign(defWrite.getSource());
+        });
         return success();
       }
       if (!isDisjointTransferIndices(

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index fe3cad78e5422..22c8ea611b36f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -657,7 +657,8 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
     Operation *newOp = cloneOpWithOperandsAndTypes(
         rewriter, loc, elementWise, newOperands,
         {newWarpOp.getResult(operandIndex).getType()});
-    newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
+                                newOp->getResult(0));
     return success();
   }
 };
@@ -695,7 +696,7 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
     Location loc = warpOp.getLoc();
     rewriter.setInsertionPointAfter(warpOp);
     Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
-    warpOp.getResult(operandIndex).replaceAllUsesWith(distConstant);
+    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
     return success();
   }
 };
@@ -759,7 +760,7 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
         read.getLoc(), distributedVal.getType(), read.getSource(), indices,
         read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
         read.getInBoundsAttr());
-    distributedVal.replaceAllUsesWith(newRead);
+    rewriter.replaceAllUsesWith(distributedVal, newRead);
     return success();
   }
 };
@@ -855,7 +856,7 @@ struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
     }
     if (!valForwarded)
       return failure();
-    warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
+    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
     return success();
   }
 };
@@ -880,7 +881,8 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
     rewriter.setInsertionPointAfter(newWarpOp);
     Value broadcasted = rewriter.create<vector::BroadcastOp>(
         loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
-    newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+                                broadcasted);
     return success();
   }
 };
@@ -936,7 +938,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
       // Extract from distributed vector.
       Value newExtract = rewriter.create<vector::ExtractOp>(
           loc, distributedVec, extractOp.getPosition());
-      newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+                                  newExtract);
       return success();
     }
 
@@ -973,7 +976,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // Extract from distributed vector.
     Value newExtract = rewriter.create<vector::ExtractOp>(
         loc, distributedVec, extractOp.getPosition());
-    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+                                newExtract);
     return success();
   }
 };
@@ -1031,7 +1035,8 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
         newExtract =
             rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
       }
-      newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+                                  newExtract);
       return success();
     }
 
@@ -1056,7 +1061,7 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // Shuffle the extracted value to all lanes.
     Value shuffled = warpShuffleFromIdxFn(
         loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
-    newWarpOp->getResult(operandNumber).replaceAllUsesWith(shuffled);
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
     return success();
   }
 
@@ -1104,7 +1109,8 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
       // Broadcast: Simply move the vector.inserelement op out.
       Value newInsert = rewriter.create<vector::InsertElementOp>(
           loc, newSource, distributedVec, newPos);
-      newWarpOp->getResult(operandNumber).replaceAllUsesWith(newInsert);
+      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+                                  newInsert);
       return success();
     }
 
@@ -1138,7 +1144,7 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
                   builder.create<scf::YieldOp>(loc, distributedVec);
                 })
             .getResult(0);
-    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult);
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
     return success();
   }
 };
@@ -1184,7 +1190,8 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
       Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
       Value newResult = rewriter.create<vector::InsertOp>(
           loc, distributedSrc, distributedDest, insertOp.getPosition());
-      newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult);
+      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+                                  newResult);
       return success();
     }
 
@@ -1263,7 +1270,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
                       .getResult(0);
     }
 
-    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult);
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
     return success();
   }
 };
@@ -1400,8 +1407,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
     rewriter.eraseOp(forOp);
     // Replace the warpOp result coming from the original ForOp.
     for (const auto &res : llvm::enumerate(resultIdx)) {
-      newWarpOp.getResult(res.value())
-          .replaceAllUsesWith(newForOp.getResult(res.index()));
+      rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
+                                  newForOp.getResult(res.index()));
       newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
     }
     newForOp.walk([&](Operation *op) {
@@ -1494,7 +1501,7 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
           rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
           newWarpOp.getResult(newRetIndices[1]));
     }
-    newWarpOp.getResult(operandIndex).replaceAllUsesWith(fullReduce);
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
     return success();
   }
 


        


More information about the Mlir-commits mailing list