[llvm-branch-commits] [mlir] [mlir][Transforms] Dialect conversion: Unify materialization of value replacements (PR #108381)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Sep 12 05:55:40 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/108381

PR #106760 aligned the handling of dropped block arguments and dropped op results. As a result, the two helper functions (`legalizeConvertedArgumentTypes` and `legalizeConvertedOpResultTypes`) that insert source materializations for uses of replaced block arguments / op results that survived the conversion are now almost identical. This PR merges the two functions and moves the implementation directly into `finalize`.

This PR simplifies the code base and improves the efficiency a bit: previously, `finalize` iterated over `ConversionPatternRewriterImpl::rewrites` twice. Now, only one iteration is needed.

>From 1f215ac7861a76f653c9911a31bf484a5fd6dac4 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 12 Sep 2024 14:49:23 +0200
Subject: [PATCH] [mlir][Transforms] Dialect conversion: Unify materialization
 of value replacements

PR #106760 aligned the handling of dropped block arguments and dropped op results. As a result, the two helper functions (`legalizeConvertedArgumentTypes` and `legalizeConvertedOpResultTypes`) that insert source materializations for uses of replaced block arguments / op results that survived the conversion are now almost identical. This PR merges the two functions and moves the implementation directly into `finalize`.

This PR simplifies the code base and improves the efficiency a bit: previously, `finalize` iterated over `ConversionPatternRewriterImpl::rewrites` twice. Now, only one iteration is needed.
---
 .../Transforms/Utils/DialectConversion.cpp    | 134 ++++++------------
 .../VectorToSPIRV/vector-to-spirv.mlir        |   4 +-
 2 files changed, 44 insertions(+), 94 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ed15b571f01883..0556b4ab833c30 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2336,17 +2336,6 @@ struct OperationConverter {
   /// remaining artifacts and complete the conversion.
   LogicalResult finalize(ConversionPatternRewriter &rewriter);
 
-  /// Legalize the types of converted block arguments.
-  LogicalResult
-  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
-                                 ConversionPatternRewriterImpl &rewriterImpl);
-
-  /// Legalize the types of converted op results.
-  LogicalResult legalizeConvertedOpResultTypes(
-      ConversionPatternRewriter &rewriter,
-      ConversionPatternRewriterImpl &rewriterImpl,
-      DenseMap<Value, SmallVector<Value>> &inverseMapping);
-
   /// Dialect conversion configuration.
   ConversionConfig config;
 
@@ -2510,19 +2499,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   return success();
 }
 
-LogicalResult
-OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
-  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
-  if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
-    return failure();
-  DenseMap<Value, SmallVector<Value>> inverseMapping =
-      rewriterImpl.mapping.getInverse();
-  if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
-                                            inverseMapping)))
-    return failure();
-  return success();
-}
-
 /// Finds a user of the given value, or of any other value that the given value
 /// replaced, that was not replaced in the conversion process.
 static Operation *findLiveUserOfReplaced(
@@ -2546,87 +2522,61 @@ static Operation *findLiveUserOfReplaced(
   return nullptr;
 }
 
-LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
-    ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl,
-    DenseMap<Value, SmallVector<Value>> &inverseMapping) {
-  // Process requested operation replacements.
-  for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
-    auto *opReplacement =
-        dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
-    if (!opReplacement)
-      continue;
-    Operation *op = opReplacement->getOperation();
-    for (OpResult result : op->getResults()) {
-      // If the type of this op result changed and the result is still live,
-      // we need to materialize a conversion.
-      if (rewriterImpl.mapping.lookupOrNull(result, result.getType()))
+/// Helper function that returns the replaced values and the type converter if
+/// the given rewrite object is an "operation replacement" or a "block type
+/// conversion" (which corresponds to a "block replacement"). Otherwise, return
+/// an empty ValueRange and a null type converter pointer.
+static std::pair<ValueRange, const TypeConverter *>
+getReplacedValues(IRRewrite *rewrite) {
+  if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
+    return std::make_pair(opRewrite->getOperation()->getResults(),
+                          opRewrite->getConverter());
+  if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
+    return std::make_pair(blockRewrite->getOrigBlock()->getArguments(),
+                          blockRewrite->getConverter());
+  return std::make_pair(ValueRange(), nullptr);
+}
+
+LogicalResult
+OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
+  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+  DenseMap<Value, SmallVector<Value>> inverseMapping =
+      rewriterImpl.mapping.getInverse();
+
+  // Process requested value replacements.
+  for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
+    ValueRange replacedValues;
+    const TypeConverter *converter;
+    std::tie(replacedValues, converter) =
+        getReplacedValues(rewriterImpl.rewrites[i].get());
+    for (Value originalValue : replacedValues) {
+      // If the type of this value changed and the value is still live, we need
+      // to materialize a conversion.
+      if (rewriterImpl.mapping.lookupOrNull(originalValue,
+                                            originalValue.getType()))
         continue;
       Operation *liveUser =
-          findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
+          findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
       if (!liveUser)
         continue;
 
-      // Legalize this result.
-      Value newValue = rewriterImpl.mapping.lookupOrNull(result);
+      // Legalize this value replacement.
+      Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
       assert(newValue && "replacement value not found");
       Value castValue = rewriterImpl.buildUnresolvedMaterialization(
-          MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
-          /*inputs=*/newValue, /*outputType=*/result.getType(),
-          opReplacement->getConverter());
-      rewriterImpl.mapping.map(result, castValue);
-      inverseMapping[castValue].push_back(result);
-      llvm::erase(inverseMapping[newValue], result);
+          MaterializationKind::Source, computeInsertPoint(newValue),
+          originalValue.getLoc(),
+          /*inputs=*/newValue, /*outputType=*/originalValue.getType(),
+          converter);
+      rewriterImpl.mapping.map(originalValue, castValue);
+      inverseMapping[castValue].push_back(originalValue);
+      llvm::erase(inverseMapping[newValue], originalValue);
     }
   }
 
   return success();
 }
 
-LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
-    ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl) {
-  // Functor used to check if all users of a value will be dead after
-  // conversion.
-  // TODO: This should probably query the inverse mapping, same as in
-  // `legalizeConvertedOpResultTypes`.
-  auto findLiveUser = [&](Value val) {
-    auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
-      return rewriterImpl.isOpIgnored(user);
-    });
-    return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
-  };
-  // Note: `rewrites` may be reallocated as the loop is running.
-  for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
-       ++i) {
-    auto &rewrite = rewriterImpl.rewrites[i];
-    if (auto *blockTypeConversionRewrite =
-            dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
-      // Process the remapping for each of the original arguments.
-      for (Value origArg :
-           blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
-        // If the type of this argument changed and the argument is still live,
-        // we need to materialize a conversion.
-        if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
-          continue;
-        Operation *liveUser = findLiveUser(origArg);
-        if (!liveUser)
-          continue;
-
-        Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
-        assert(replacementValue && "replacement value not found");
-        Value repl = rewriterImpl.buildUnresolvedMaterialization(
-            MaterializationKind::Source, computeInsertPoint(replacementValue),
-            origArg.getLoc(), /*inputs=*/replacementValue,
-            /*outputType=*/origArg.getType(),
-            blockTypeConversionRewrite->getConverter());
-        rewriterImpl.mapping.map(origArg, repl);
-      }
-    }
-  }
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Reconcile Unrealized Casts
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index d8570bdaf4247f..25ec5d0159bd5d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -558,8 +558,8 @@ func.func @deinterleave(%a: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
 
 // CHECK-LABEL: func @deinterleave_scalar
 // CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
-//       CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
-//       CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
+//   CHECK-DAG: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+//   CHECK-DAG: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
 //   CHECK-DAG: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
 //   CHECK-DAG: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
 //       CHECK: return %[[CAST0]], %[[CAST1]]



More information about the llvm-branch-commits mailing list