[llvm-branch-commits] [flang] [mlir] [mlir][Transforms] Support 1:N mappings in `ConversionValueMapping` (PR #116524)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Dec 20 05:59:39 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-func

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit updates the internal `ConversionValueMapping` data structure in the dialect conversion driver to support 1:N replacements. This is the last major commit for adding 1:N support to the dialect conversion driver.

Since #116470, the infrastructure has supported 1:N replacements, but `ConversionValueMapping` still stored 1:1 value mappings. To bridge that gap, the driver inserted temporary argument materializations (converting N SSA values into 1 value). This is no longer the case: the driver no longer builds argument materializations at all. (They will be deleted from the type converter after some time, when we delete the old 1:N dialect conversion driver.)

Note for LLVM integration: Replace all occurrences of `addArgumentMaterialization` (except for 1:N dialect conversion passes) with `addSourceMaterialization`.

Depends on #117513.
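
To make the 1:N idea concrete, here is a minimal, self-contained sketch of a value mapping whose keys and values are both vectors. It is not the code from this patch: `ToyValue`, `ToyValueVector`, and `ToyValueMapping` are hypothetical names, plain integers stand in for SSA values, `std::map` stands in for `DenseMap`, and the type-based filtering (`desiredTypes`) and any logic in the truncated part of the diff are deliberately left out.

```cpp
#include <cassert>
#include <map>
#include <vector>

// Hypothetical stand-in for mlir::Value. Real SSA values also carry a type,
// which this sketch ignores.
using ToyValue = int;
using ToyValueVector = std::vector<ToyValue>;

// Simplified model of a mapping that can replace one value with N values.
class ToyValueMapping {
public:
  // Record that `oldVals` was replaced by `newVals`.
  // Assumption: callers never insert a cyclic mapping.
  void map(const ToyValueVector &oldVals, const ToyValueVector &newVals) {
    mapping[oldVals] = newVals;
  }

  // Repeatedly replace each value with its mapped values until no value in
  // the vector has a mapping anymore. Unmapped values are returned as-is.
  ToyValueVector lookupOrDefault(ToyValueVector from) const {
    while (true) {
      ToyValueVector next;
      for (ToyValue v : from) {
        auto it = mapping.find(ToyValueVector{v});
        if (it != mapping.end())
          next.insert(next.end(), it->second.begin(), it->second.end());
        else
          next.push_back(v);
      }
      if (next == from)
        return from; // Fixed point reached: nothing was replaced.
      from = next;
    }
  }

private:
  std::map<ToyValueVector, ToyValueVector> mapping;
};
```

With this model, mapping value `1` to `{2, 3}` and then value `3` to `{4}` makes a lookup of `{1}` resolve to `{2, 4}`, without any intermediate N:1 packing step of the kind argument materializations used to provide.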


---

Patch is 46.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116524.diff


14 Files Affected:

- (modified) flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp (-1) 
- (modified) mlir/docs/DialectConversion.md (+5-30) 
- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+7-11) 
- (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+6-8) 
- (modified) mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp (-1) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp (-1) 
- (modified) mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp (-1) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp (-3) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (-1) 
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+206-226) 
- (modified) mlir/test/Transforms/test-legalizer.mlir (+2-5) 
- (modified) mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp (+1-1) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (-1) 
- (modified) mlir/test/lib/Transforms/TestDialectConversion.cpp (-1) 


``````````diff
diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index 1bb91d252529f0..104ae7408b80c1 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -172,7 +172,6 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
     addConversion([&](TypeDescType ty) {
       return TypeDescType::get(convertType(ty.getOfTy()));
     });
-    addArgumentMaterialization(materializeProcedure);
     addSourceMaterialization(materializeProcedure);
     addTargetMaterialization(materializeProcedure);
   }
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 3168f5e13c7515..abacd5a82c61eb 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -242,19 +242,6 @@ cannot. These materializations are used by the conversion framework to ensure
 type safety during the conversion process. There are several types of
 materializations depending on the situation.
 
-*   Argument Materialization
-
-    -   An argument materialization is used when converting the type of a block
-        argument during a [signature conversion](#region-signature-conversion).
-        The new block argument types are specified in a `SignatureConversion`
-        object. An original block argument can be converted into multiple
-        block arguments, which is not supported everywhere in the dialect
-        conversion. (E.g., adaptors support only a single replacement value for
-        each original value.) Therefore, an argument materialization is used to
-        convert potentially multiple new block arguments back into a single SSA
-        value. An argument materialization is also used when replacing an op
-        result with multiple values.
-
 *   Source Materialization
 
     -   A source materialization is used when a value was replaced with a value
@@ -343,17 +330,6 @@ class TypeConverter {
   /// Materialization functions must be provided when a type conversion may
   /// persist after the conversion has finished.
 
-  /// This method registers a materialization that will be called when
-  /// converting (potentially multiple) block arguments that were the result of
-  /// a signature conversion of a single block argument, to a single SSA value
-  /// with the old argument type.
-  template <typename FnT,
-            typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
-  void addArgumentMaterialization(FnT &&callback) {
-    argumentMaterializations.emplace_back(
-        wrapMaterialization<T>(std::forward<FnT>(callback)));
-  }
-
   /// This method registers a materialization that will be called when
   /// converting a replacement value back to its original source type.
   /// This is used when some uses of the original value persist beyond the main
@@ -406,12 +382,11 @@ done explicitly via a conversion pattern.
 To convert the types of block arguments within a Region, a custom hook on the
 `ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook
 uses a provided type converter to apply type conversions to all blocks of a
-given region. As noted above, the conversions performed by this method use the
-argument materialization hook on the `TypeConverter`. This hook also takes an
-optional `TypeConverter::SignatureConversion` parameter that applies a custom
-conversion to the entry block of the region. The types of the entry block
-arguments are often tied semantically to the operation, e.g.,
-`func::FuncOp`, `AffineForOp`, etc.
+given region. This hook also takes an optional
+`TypeConverter::SignatureConversion` parameter that applies a custom conversion
+to the entry block of the region. The types of the entry block arguments are
+often tied semantically to the operation, e.g., `func::FuncOp`, `AffineForOp`,
+etc.
 
 To convert the signature of just one given block, the
 `applySignatureConversion` hook can be used.
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 28150e886913e3..9a6975dcf8dfae 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -181,6 +181,10 @@ class TypeConverter {
   /// converting (potentially multiple) block arguments that were the result of
   /// a signature conversion of a single block argument, to a single SSA value
   /// with the old block argument type.
+  ///
+  /// Note: Argument materializations are used only with the 1:N dialect
+  /// conversion driver. The 1:N dialect conversion driver will be removed soon
+  /// and so will be argument materializations.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addArgumentMaterialization(FnT &&callback) {
@@ -880,15 +884,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
   void replaceOp(Operation *op, Operation *newOp) override;
 
   /// Replace the given operation with the new value ranges. The number of op
-  /// results and value ranges must match. If an original SSA value is replaced
-  /// by multiple SSA values (i.e., a value range has more than 1 element), the
-  /// conversion driver will insert an argument materialization to convert the
-  /// N SSA values back into 1 SSA value of the original type. The given
-  /// operation is erased.
-  ///
-  /// Note: The argument materialization is a workaround until we have full 1:N
-  /// support in the dialect conversion. (It is going to disappear from both
-  /// `replaceOpWithMultiple` and `applySignatureConversion`.)
+  /// results and value ranges must match. The given  operation is erased.
   void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
 
   /// PatternRewriter hook for erasing a dead operation. The uses of this
@@ -1285,8 +1281,8 @@ struct ConversionConfig {
   // represented at the moment.
   RewriterBase::Listener *listener = nullptr;
 
-  /// If set to "true", the dialect conversion attempts to build source/target/
-  /// argument materializations through the type converter API in lieu of
+  /// If set to "true", the dialect conversion attempts to build source/target
+  /// materializations through the type converter API in lieu of
   /// "builtin.unrealized_conversion_cast ops". The conversion process fails if
   /// at least one materialization could not be built.
   ///
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index e2ab0ed6f66cc5..d27b557736c924 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -189,9 +189,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
                                           UnrankedMemRefType resultType,
                                           ValueRange inputs, Location loc) {
-    // An argument materialization must return a value of type
-    // `resultType`, so insert a cast from the memref descriptor type
-    // (!llvm.struct) to the original memref type.
+    // A source materialization must return a value of type `resultType`, so
+    // insert a cast from the memref descriptor type (!llvm.struct) to the
+    // original memref type.
     Value packed =
         packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
     if (!packed)
@@ -223,7 +223,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   auto rankedMemRefMaterialization = [&](OpBuilder &builder,
                                          MemRefType resultType,
                                          ValueRange inputs, Location loc) {
-    // An argument materialization must return a value of type `resultType`,
+    // A source materialization must return a value of type `resultType`,
     // so insert a cast from the memref descriptor type (!llvm.struct) to the
     // original memref type.
     Value packed =
@@ -234,11 +234,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
         .getResult(0);
   };
 
-  // Argument materializations convert from the new block argument types
-  // (multiple SSA values that make up a memref descriptor) back to the
+  // Source materializations convert from the new block argument types
+  // (e.g., multiple SSA values that make up a memref descriptor) back to the
   // original block argument type.
-  addArgumentMaterialization(unrakedMemRefMaterialization);
-  addArgumentMaterialization(rankedMemRefMaterialization);
   addSourceMaterialization(unrakedMemRefMaterialization);
   addSourceMaterialization(rankedMemRefMaterialization);
 
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 0b3a494794f3f5..72c8fd0f324850 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) {
 
   converter.addSourceMaterialization(materializeAsUnrealizedCast);
   converter.addTargetMaterialization(materializeAsUnrealizedCast);
-  converter.addArgumentMaterialization(materializeAsUnrealizedCast);
 }
 
 /// Get an unsigned integer or size data type corresponding to \p ty.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index af38485291182f..61bc5022893741 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -154,7 +154,6 @@ class DetensorizeTypeConverter : public TypeConverter {
     });
 
     addSourceMaterialization(sourceMaterializationCallback);
-    addArgumentMaterialization(sourceMaterializationCallback);
   }
 };
 
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 61912722662830..71b88d1be1b05b 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -56,7 +56,6 @@ class QuantizedTypeConverter : public TypeConverter {
     addConversion(convertQuantizedType);
     addConversion(convertTensorType);
 
-    addArgumentMaterialization(materializeConversion);
     addSourceMaterialization(materializeConversion);
     addTargetMaterialization(materializeConversion);
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index 834e3634cc130d..8bbb2cac5efdf3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 
   // Required by scf.for 1:N type conversion.
   addSourceMaterialization(materializeTuple);
-
-  // Required as a workaround until we have full 1:N support.
-  addArgumentMaterialization(materializeTuple);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224f..68535ae5a7a5c6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
 
     return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
   };
-  typeConverter.addArgumentMaterialization(materializeCast);
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
   target.markUnknownOpDynamicallyLegal(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 51686646a0a2fc..ea169a1df42b6a 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Iterators.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
@@ -53,6 +54,55 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
   });
 }
 
+/// Given two insertion points in the same block, choose the later one.
+static OpBuilder::InsertPoint
+chooseLaterInsertPointInBlock(OpBuilder::InsertPoint a,
+                              OpBuilder::InsertPoint b) {
+  assert(a.getBlock() == b.getBlock() && "expected same block");
+  Block *block = a.getBlock();
+  if (a.getPoint() == block->begin())
+    return b;
+  if (b.getPoint() == block->begin())
+    return a;
+  if (a.getPoint()->isBeforeInBlock(&*b.getPoint()))
+    return b;
+  return a;
+}
+
+/// Helper function that chooses the insertion point among the two given ones
+/// that is later.
+// TODO: Extend DominanceInfo API to work with block iterators.
+static OpBuilder::InsertPoint chooseLaterInsertPoint(OpBuilder::InsertPoint a,
+                                                     OpBuilder::InsertPoint b) {
+  // Case 1: Same block.
+  if (a.getBlock() == b.getBlock())
+    return chooseLaterInsertPointInBlock(a, b);
+
+  // Case 2: Different block, but same region.
+  if (a.getBlock()->getParent() == b.getBlock()->getParent()) {
+    DominanceInfo domInfo;
+    if (domInfo.properlyDominates(a.getBlock(), b.getBlock()))
+      return b;
+    if (domInfo.properlyDominates(b.getBlock(), a.getBlock()))
+      return a;
+    // Neither of the two blocks dominante each other.
+    llvm_unreachable("unable to find valid insertion point");
+  }
+
+  // Case 3: b's region contains a: choose a.
+  if (Operation *aParent = b.getBlock()->getParent()->findAncestorOpInRegion(
+          *a.getPoint()->getParentOp()))
+    return a;
+
+  // Case 4: a's region contains b: choose b.
+  if (Operation *bParent = a.getBlock()->getParent()->findAncestorOpInRegion(
+          *b.getPoint()->getParentOp()))
+    return b;
+
+  // Neither of the two operations contain each other.
+  llvm_unreachable("unable to find valid insertion point");
+}
+
 /// Helper function that computes an insertion point where the given value is
 /// defined and can be used without a dominance violation.
 static OpBuilder::InsertPoint computeInsertPoint(Value value) {
@@ -63,11 +113,36 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) {
   return OpBuilder::InsertPoint(insertBlock, insertPt);
 }
 
+/// Helper function that computes an insertion point where the given values are
+/// defined and can be used without a dominance violation.
+static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
+  assert(!vals.empty() && "expected at least one value");
+  OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
+  for (Value v : vals.drop_front())
+    pt = chooseLaterInsertPoint(pt, computeInsertPoint(v));
+  return pt;
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionValueMapping
 //===----------------------------------------------------------------------===//
 
+/// A vector of SSA values, optimized for the most common case of a single
+/// value.
+using ValueVector = SmallVector<Value, 1>;
+
 namespace {
+
+/// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
+struct ValueVectorMapInfo {
+  static ValueVector getEmptyKey() { return ValueVector{}; }
+  static ValueVector getTombstoneKey() { return ValueVector{}; }
+  static ::llvm::hash_code getHashValue(ValueVector val) {
+    return ::llvm::hash_combine_range(val.begin(), val.end());
+  }
+  static bool isEqual(ValueVector LHS, ValueVector RHS) { return LHS == RHS; }
+};
+
 /// This class wraps a IRMapping to provide recursive lookup
 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
 struct ConversionValueMapping {
@@ -75,68 +150,103 @@ struct ConversionValueMapping {
   /// false positives.
   bool isMappedTo(Value value) const { return mappedTo.contains(value); }
 
-  /// Lookup the most recently mapped value with the desired type in the
+  /// Lookup the most recently mapped values with the desired types in the
   /// mapping.
   ///
   /// Special cases:
-  /// - If the desired type is "null", simply return the most recently mapped
-  ///   value.
-  /// - If there is no mapping to the desired type, also return the most
-  ///   recently mapped value.
-  /// - If there is no mapping for the given value at all, return the given
-  ///   value.
-  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
-
-  /// Lookup a mapped value within the map, or return null if a mapping does not
-  /// exist. If a mapping exists, this follows the same behavior of
-  /// `lookupOrDefault`.
-  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
+  /// - If the desired type range is empty, simply return the most recently
+  ///   mapped values.
+  /// - If there is no mapping to the desired types, also return the most
+  ///   recently mapped values.
+  /// - If there is no mapping for the given values at all, return the given
+  ///   values.
+  ValueVector lookupOrDefault(ValueVector from,
+                              TypeRange desiredTypes = {}) const;
+
+  /// Lookup the given values within the map, or return an empty vector if the
+  /// values are not mapped. If they are mapped, this follows the same behavior
+  /// as `lookupOrDefault`.
+  ValueVector lookupOrNull(const ValueVector &from,
+                           TypeRange desiredTypes = {}) const;
 
   /// Map a value to the one provided.
-  void map(Value oldVal, Value newVal) {
+  void map(const ValueVector &oldVal, const ValueVector &newVal) {
     LLVM_DEBUG({
-      for (Value it = newVal; it; it = mapping.lookupOrNull(it))
-        assert(it != oldVal && "inserting cyclic mapping");
+      ValueVector next = newVal;
+      while (true) {
+        assert(next != oldVal && "inserting cyclic mapping");
+        auto it = mapping.find(next);
+        if (it == mapping.end())
+          break;
+        next = it->second;
+      }
     });
-    mapping.map(oldVal, newVal);
-    mappedTo.insert(newVal);
+    mapping[oldVal] = newVal;
+    for (Value v : newVal)
+      mappedTo.insert(v);
   }
 
-  /// Drop the last mapping for the given value.
-  void erase(Value value) { mapping.erase(value); }
+  /// Drop the last mapping for the given values.
+  void erase(ValueVector value) { mapping.erase(value); }
 
 private:
   /// Current value mappings.
-  IRMapping mapping;
+  DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> mapping;
 
   /// All SSA values that are mapped to. May contain false positives.
   DenseSet<Value> mappedTo;
 };
 } // namespace
 
-Value ConversionValueMapping::lookupOrDefault(Value from,
-                                              Type desiredType) const {
-  // Try to find the deepest value that has the desired type. If there is no
-  // such value, simply return the deepest value.
-  Value desiredValue;
+ValueVector
+ConversionValueMapping::lookupOrDefault(ValueVector from,
+                                        TypeRange desiredTypes) const {
+  // Try to find the deepest values that have the desired types. If there is no
+  // such mapping, simply return the deepest values.
+  ValueVector desiredValue;
   do {
-    if (!desiredType || from.getType() == desiredType)
+    // Store the current value if the types match.
+    if (desiredTypes.empty() || TypeRange(from) == desiredTypes)
       desiredValue = from;
 
-    Value mappedValue = mapping.lookupOrNull(from);
-    if (!mappedValue)
+    // If possible, Replace each value with (one or multiple) mapped values.
+    ValueVector next;
+    for (Value v : from) {
+      auto it = mapping.find({v});
+      if (it != mapping.end()) {
+        llvm::append_range(next, it->second);
+      } else {
+        next.push_back(v);
+      }
+    }
+    if (next != from) {
+      // If at least one value was replaced, continue the lookup from there.
+      from = next;
+    ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/116524

