[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Make `rewriterImpl` private in `IRRewrite` (PR #84865)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Mar 11 19:59:12 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

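This commit makes the `rewriterImpl` member of `IRRewrite` private and adds a `getMapping()` accessor. This ensures that only the conversion value mapping and the dialect conversion configuration can be accessed from an IR rewrite object. As part of this change, `BlockTypeConversionRewrite::materializeLiveConversions` now takes the `OpBuilder` to use as a parameter (the caller passes its rewriter) instead of constructing one with `rewriterImpl` as its listener.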


---
Full diff: https://github.com/llvm/llvm-project/pull/84865.diff


1 File Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+23-18) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e4a022b7a0288b..dbdfaeeeb28d4b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -232,6 +232,9 @@ class IRRewrite {
 
   const ConversionConfig &getConfig() const;
 
+  ConversionValueMapping &getMapping();
+
+private:
   const Kind kind;
   ConversionPatternRewriterImpl &rewriterImpl;
 };
@@ -470,7 +473,8 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   /// live users, using the provided `findLiveUser` to search for a user that
   /// survives the conversion process.
   LogicalResult
-  materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
+  materializeLiveConversions(OpBuilder &builder,
+                             function_ref<Operation *(Value)> findLiveUser);
 
   void commit(RewriterBase &rewriter) override;
 
@@ -1035,6 +1039,8 @@ const ConversionConfig &IRRewrite::getConfig() const {
   return rewriterImpl.config;
 }
 
+ConversionValueMapping &IRRewrite::getMapping() { return rewriterImpl.mapping; }
+
 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
   // Inform the listener about all IR modifications that have already taken
   // place: References to the original block have been replaced with the new
@@ -1049,8 +1055,7 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
        llvm::zip_equal(origBlock->getArguments(), argInfo)) {
     // Handle the case of a 1->0 value mapping.
     if (!info) {
-      if (Value newArg =
-              rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+      if (Value newArg = getMapping().lookupOrNull(origArg, origArg.getType()))
         rewriter.replaceAllUsesWith(origArg, newArg);
       continue;
     }
@@ -1061,8 +1066,8 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
 
     // If the argument is still used, replace it with the generated cast.
     if (!origArg.use_empty()) {
-      rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
-                                               castValue, origArg.getType()));
+      rewriter.replaceAllUsesWith(
+          origArg, getMapping().lookupOrDefault(castValue, origArg.getType()));
     }
   }
 }
@@ -1072,23 +1077,23 @@ void BlockTypeConversionRewrite::rollback() {
 }
 
 LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
-    function_ref<Operation *(Value)> findLiveUser) {
+    OpBuilder &builder, function_ref<Operation *(Value)> findLiveUser) {
+  OpBuilder::InsertionGuard g(builder);
+  builder.setInsertionPointToStart(block);
+
   // Process the remapping for each of the original arguments.
   for (auto it : llvm::enumerate(origBlock->getArguments())) {
     BlockArgument origArg = it.value();
-    // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
-    OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
-    builder.setInsertionPointToStart(block);
 
     // If the type of this argument changed and the argument is still live, we
     // need to materialize a conversion.
-    if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+    if (getMapping().lookupOrNull(origArg, origArg.getType()))
       continue;
     Operation *liveUser = findLiveUser(origArg);
     if (!liveUser)
       continue;
 
-    Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
+    Value replacementValue = getMapping().lookupOrDefault(origArg);
     bool isDroppedArg = replacementValue == origArg;
     if (!isDroppedArg)
       builder.setInsertionPointAfterValue(replacementValue);
@@ -1113,13 +1118,13 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
           << "see existing live user here: " << *liveUser;
       return failure();
     }
-    rewriterImpl.mapping.map(origArg, newArg);
+    getMapping().map(origArg, newArg);
   }
   return success();
 }
 
 void ReplaceAllUsesRewrite::commit(RewriterBase &rewriter) {
-  Value repl = rewriterImpl.mapping.lookupOrNull(value);
+  Value repl = getMapping().lookupOrNull(value);
   assert(repl && "expected that value is mapped");
 
   if (isa<BlockArgument>(repl)) {
@@ -1138,7 +1143,7 @@ void ReplaceAllUsesRewrite::commit(RewriterBase &rewriter) {
   });
 }
 
-void ReplaceAllUsesRewrite::rollback() { rewriterImpl.mapping.erase(value); }
+void ReplaceAllUsesRewrite::rollback() { getMapping().erase(value); }
 
 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   auto *listener = dyn_cast_or_null<RewriterBase::ForwardingListener>(
@@ -1147,7 +1152,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   // Compute replacement values.
   SmallVector<Value> replacements =
       llvm::map_to_vector(op->getResults(), [&](OpResult result) {
-        return rewriterImpl.mapping.lookupOrNull(result, result.getType());
+        return getMapping().lookupOrNull(result, result.getType());
       });
 
   // Notify the listener that the operation is about to be replaced.
@@ -1179,7 +1184,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
 
 void ReplaceOperationRewrite::rollback() {
   for (auto result : op->getResults())
-    rewriterImpl.mapping.erase(result);
+    getMapping().erase(result);
 }
 
 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
@@ -1198,7 +1203,7 @@ void CreateOperationRewrite::rollback() {
 void UnresolvedMaterializationRewrite::rollback() {
   if (getMaterializationKind() == MaterializationKind::Target) {
     for (Value input : op->getOperands())
-      rewriterImpl.mapping.erase(input);
+      getMapping().erase(input);
   }
   op->erase();
 }
@@ -2721,7 +2726,7 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
     if (auto *blockTypeConversionRewrite =
             dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
       if (failed(blockTypeConversionRewrite->materializeLiveConversions(
-              findLiveUser)))
+              rewriter, findLiveUser)))
         return failure();
   }
   return success();

``````````

</details>


https://github.com/llvm/llvm-project/pull/84865


More information about the llvm-branch-commits mailing list