[Mlir-commits] [mlir] 12874e9 - [mlir][NFC] Add helper for common pattern of replaceAllUsesExcept

Sean Silva llvmlistbot at llvm.org
Thu May 13 12:42:18 PDT 2021


Author: Sean Silva
Date: 2021-05-13T12:42:10-07:00
New Revision: 12874e93a15219ccfaff42a0536b2b5368c6f304

URL: https://github.com/llvm/llvm-project/commit/12874e93a15219ccfaff42a0536b2b5368c6f304
DIFF: https://github.com/llvm/llvm-project/commit/12874e93a15219ccfaff42a0536b2b5368c6f304.diff

LOG: [mlir][NFC] Add helper for common pattern of replaceAllUsesExcept

This covers the extremely common case of replacing all uses of a Value
with the result of a newly created op that itself uses the original Value
(so that op's own use must be left untouched).

This should also be slightly more efficient than the
`SmallPtrSet<Operation *, 1>{op}` idiom used previously, since it avoids
constructing a temporary set just to hold a single excluded user.

Differential Revision: https://reviews.llvm.org/D102373

Added: 
    

Modified: 
    mlir/include/mlir/IR/Value.h
    mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
    mlir/lib/IR/Value.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 2a0bd67c1cb0..bd80b2b582d1 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -166,6 +166,11 @@ class Value {
   replaceAllUsesExcept(Value newValue,
                        const SmallPtrSetImpl<Operation *> &exceptions) const;
 
+  /// Replace all uses of 'this' value with 'newValue', updating anything in the
+  /// IR that uses 'this' to use the other value instead except if the user is
+  /// 'exceptedUser'.
+  void replaceAllUsesExcept(Value newValue, Operation *exceptedUser) const;
+
   /// Replace all uses of 'this' value with 'newValue' if the given callback
   /// returns true.
   void replaceUsesWithIf(Value newValue,

diff  --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp
index 8653bcf2ad63..1a785a03df76 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp
@@ -72,7 +72,7 @@ void mlir::normalizeAffineParallel(AffineParallelOp op) {
     applyOperands.push_back(iv);
     applyOperands.append(symbolOperands.begin(), symbolOperands.end());
     auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands);
-    iv.replaceAllUsesExcept(apply, SmallPtrSet<Operation *, 1>{apply});
+    iv.replaceAllUsesExcept(apply, apply);
   }
 
   SmallVector<int64_t, 8> newSteps(op.getNumDims(), 1);
@@ -181,8 +181,7 @@ static void normalizeAffineFor(AffineForOp op) {
   AffineMap ivMap = AffineMap::get(origLbMap.getNumDims() + 1,
                                    origLbMap.getNumSymbols(), newIVExpr);
   Operation *newIV = opBuilder.create<AffineApplyOp>(loc, ivMap, lbOperands);
-  op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0),
-                                            SmallPtrSet<Operation *, 1>{newIV});
+  op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV);
 }
 
 namespace {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 79903c06883d..4c06d3dc3504 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -191,8 +191,7 @@ static LinalgOp fuse(OpBuilder &builder, LinalgOp producer,
       AffineApplyOp applyOp = builder.create<AffineApplyOp>(
           indexOp.getLoc(), index + offset,
           ValueRange{indexOp.getResult(), loopRanges[indexOp.dim()].offset});
-      indexOp.getResult().replaceAllUsesExcept(
-          applyOp, SmallPtrSet<Operation *, 1>{applyOp});
+      indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
     }
   }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 0479ab654311..bdc1d7097ccd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -155,8 +155,7 @@ transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
     AffineApplyOp applyOp = b.create<AffineApplyOp>(
         indexOp.getLoc(), index + iv,
         ValueRange{indexOp.getResult(), ivs[rangeIndex->second]});
-    indexOp.getResult().replaceAllUsesExcept(
-        applyOp.getResult(), SmallPtrSet<Operation *, 1>{applyOp});
+    indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
   }
 }
 

diff  --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
index cdb4afa929cf..8282c0771f30 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
@@ -121,8 +121,7 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
     Value inner_index = std::get<0>(ivs);
     AddIOp newIndex =
         b.create<AddIOp>(op.getLoc(), std::get<0>(ivs), std::get<1>(ivs));
-    inner_index.replaceAllUsesExcept(
-        newIndex, SmallPtrSet<Operation *, 1>{newIndex.getOperation()});
+    inner_index.replaceAllUsesExcept(newIndex, newIndex);
   }
 
   op.erase();

diff  --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index e28ab9ba470d..a4baa9311001 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -63,12 +63,23 @@ void Value::replaceAllUsesWith(Value newValue) const {
 /// listed in 'exceptions' .
 void Value::replaceAllUsesExcept(
     Value newValue, const SmallPtrSetImpl<Operation *> &exceptions) const {
-  for (auto &use : llvm::make_early_inc_range(getUses())) {
+  for (OpOperand &use : llvm::make_early_inc_range(getUses())) {
     if (exceptions.count(use.getOwner()) == 0)
       use.set(newValue);
   }
 }
 
+/// Replace all uses of 'this' value with 'newValue', updating anything in the
+/// IR that uses 'this' to use the other value instead except if the user is
+/// 'exceptedUser'.
+void Value::replaceAllUsesExcept(Value newValue,
+                                 Operation *exceptedUser) const {
+  for (OpOperand &use : llvm::make_early_inc_range(getUses())) {
+    if (use.getOwner() != exceptedUser)
+      use.set(newValue);
+  }
+}
+
 /// Replace all uses of 'this' value with 'newValue' if the given callback
 /// returns true.
 void Value::replaceUsesWithIf(Value newValue,


        


More information about the Mlir-commits mailing list