[Mlir-commits] [mlir] [mlir][Transforms] Merge 1:1 and 1:N type converters (PR #113032)

Matthias Springer llvmlistbot at llvm.org
Wed Oct 23 09:58:16 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/113032

>From 44db7a9efdbd33f5a40723a96cbe9f66f68e0970 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 19 Oct 2024 12:05:13 +0200
Subject: [PATCH 1/2] [mlir][Transforms] Merge 1:1 and 1:N type converters

---
 .../Dialect/SparseTensor/Transforms/Passes.h  |  2 +-
 .../mlir/Transforms/DialectConversion.h       | 56 ++++++++++++++-----
 .../mlir/Transforms/OneToNTypeConversion.h    | 45 +--------------
 .../ArmSME/Transforms/VectorLegalization.cpp  |  2 +-
 .../Transforms/Utils/DialectConversion.cpp    | 24 ++++++--
 .../Transforms/Utils/OneToNTypeConversion.cpp | 44 +++++----------
 .../TestOneToNTypeConversionPass.cpp          | 18 ++++--
 7 files changed, 93 insertions(+), 98 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 6ccbc40bdd6034..2e9c297f20182a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -150,7 +150,7 @@ std::unique_ptr<Pass> createLowerForeachToSCFPass();
 //===----------------------------------------------------------------------===//
 
 /// Type converter for iter_space and iterator.
-struct SparseIterationTypeConverter : public OneToNTypeConverter {
+struct SparseIterationTypeConverter : public TypeConverter {
   SparseIterationTypeConverter();
 };
 
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5ff36160dd6162..37da03bbe386e9 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -173,7 +173,9 @@ class TypeConverter {
   /// conversion has finished.
   ///
   /// Note: Target materializations may optionally accept an additional Type
-  /// parameter, which is the original type of the SSA value.
+  /// parameter, which is the original type of the SSA value. Furthermore, `T`
+  /// can be a TypeRange; in that case, the function must return a
+  /// SmallVector<Value>.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
@@ -210,6 +212,9 @@ class TypeConverter {
   /// will be invoked with: outputType = "t3", inputs = "v2",
   // originalType = "t1". Note  that the original type "t1" cannot be recovered
   /// from just "t3" and "v2"; that's why the originalType parameter exists.
+  ///
+  /// Note: During a 1:N conversion, the result types can be a TypeRange. In
+  /// that case the materialization produces a SmallVector<Value>.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addTargetMaterialization(FnT &&callback) {
@@ -316,6 +321,11 @@ class TypeConverter {
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
                                     Type resultType, ValueRange inputs,
                                     Type originalType = {}) const;
+  SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
+                                                 Location loc,
+                                                 TypeRange resultType,
+                                                 ValueRange inputs,
+                                                 Type originalType = {}) const;
 
   /// Convert an attribute present `attr` from within the type `type` using
   /// the registered conversion functions. If no applicable conversion has been
@@ -340,9 +350,9 @@ class TypeConverter {
 
   /// The signature of the callback used to materialize a target conversion.
   ///
-  /// Arguments: builder, result type, inputs, location, original type
-  using TargetMaterializationCallbackFn =
-      std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>;
+  /// Arguments: builder, result types, inputs, location, original type
+  using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
+      OpBuilder &, TypeRange, ValueRange, Location, Type)>;
 
   /// The signature of the callback used to convert a type attribute.
   using TypeAttributeConversionCallbackFn =
@@ -409,22 +419,40 @@ class TypeConverter {
   /// callback.
   ///
   /// With callback of form:
-  /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
+  /// - Value(OpBuilder &, T, ValueRange, Location, Type)
+  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
   template <typename T, typename FnT>
   std::enable_if_t<
       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
       TargetMaterializationCallbackFn>
   wrapTargetMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               OpBuilder &builder, Type resultType, ValueRange inputs,
-               Location loc, Type originalType) -> Value {
-      if (T derivedType = dyn_cast<T>(resultType))
-        return callback(builder, derivedType, inputs, loc, originalType);
-      return Value();
+               OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+               Location loc, Type originalType) -> SmallVector<Value> {
+      SmallVector<Value> result;
+      if constexpr (std::is_same<T, TypeRange>::value) {
+        // This is a 1:N target materialization. Return the produces values
+        // directly.
+        result = callback(builder, resultTypes, inputs, loc, originalType);
+      } else {
+        // This is a 1:1 target materialization. Invoke it only if the result
+        // type class of the callback matches the requested result type.
+        if (T derivedType = dyn_cast<T>(resultTypes.front())) {
+          // 1:1 materializations produce single values, but we store 1:N
+          // target materialization functions in the type converter. Wrap the
+          // result value in a SmallVector<Value>.
+          std::optional<Value> val =
+              callback(builder, derivedType, inputs, loc, originalType);
+          if (val.has_value() && *val)
+            result.push_back(*val);
+        }
+      }
+      return result;
     };
   }
   /// With callback of form:
-  /// `Value(OpBuilder &, T, ValueRange, Location)`
+  /// - Value(OpBuilder &, T, ValueRange, Location)
+  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
   template <typename T, typename FnT>
   std::enable_if_t<
       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
@@ -432,9 +460,9 @@ class TypeConverter {
   wrapTargetMaterialization(FnT &&callback) const {
     return wrapTargetMaterialization<T>(
         [callback = std::forward<FnT>(callback)](
-            OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
-            Type originalType) -> Value {
-          return callback(builder, resultType, inputs, loc);
+            OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
+            Type originalType) {
+          return callback(builder, resultTypes, inputs, loc);
         });
   }
 
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..7b4dd65cbff7b2 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -33,49 +33,6 @@
 
 namespace mlir {
 
-/// Extends `TypeConverter` with 1:N target materializations. Such
-/// materializations have to provide the "reverse" of 1:N type conversions,
-/// i.e., they need to materialize N values with target types into one value
-/// with a source type (which isn't possible in the base class currently).
-class OneToNTypeConverter : public TypeConverter {
-public:
-  /// Callback that expresses user-provided materialization logic from the given
-  /// value to N values of the given types. This is useful for expressing target
-  /// materializations for 1:N type conversions, which materialize one value in
-  /// a source type as N values in target types.
-  using OneToNMaterializationCallbackFn =
-      std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
-                                                      Value, Location)>;
-
-  /// Creates the mapping of the given range of original types to target types
-  /// of the conversion and stores that mapping in the given (signature)
-  /// conversion. This function simply calls
-  /// `TypeConverter::convertSignatureArgs` and exists here with a different
-  /// name to reflect the broader semantic.
-  LogicalResult computeTypeMapping(TypeRange types,
-                                   SignatureConversion &result) const {
-    return convertSignatureArgs(types, result);
-  }
-
-  /// Applies one of the user-provided 1:N target materializations. If several
-  /// exists, they are tried out in the reverse order in which they have been
-  /// added until the first one succeeds. If none succeeds, the functions
-  /// returns `std::nullopt`.
-  std::optional<SmallVector<Value>>
-  materializeTargetConversion(OpBuilder &builder, Location loc,
-                              TypeRange resultTypes, Value input) const;
-
-  /// Adds a 1:N target materialization to the converter. Such materializations
-  /// build IR that converts N values with target types into 1 value of the
-  /// source type.
-  void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
-    oneToNTargetMaterializations.emplace_back(std::move(callback));
-  }
-
-private:
-  SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
-};
-
 /// Stores a 1:N mapping of types and provides several useful accessors. This
 /// class extends `SignatureConversion`, which already supports 1:N type
 /// mappings but lacks some accessors into the mapping as well as access to the
@@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
 /// not fail if some ops or types remain unconverted (i.e., the conversion is
 /// only "partial").
 LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
                              const FrozenRewritePatternSet &patterns);
 
 /// Add a pattern to the given pattern list to convert the signature of a
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 4968c4fc463d04..e908a536e6fb27 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -921,7 +921,7 @@ struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
     auto *context = &getContext();
-    OneToNTypeConverter converter;
+    TypeConverter converter;
     RewritePatternSet patterns(context);
     converter.addConversion([](Type type) { return type; });
     converter.addConversion(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3cfcaa965f3546..bf969e74e8bfe0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2831,11 +2831,27 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
                                                  Location loc, Type resultType,
                                                  ValueRange inputs,
                                                  Type originalType) const {
+  SmallVector<Value> result = materializeTargetConversion(
+      builder, loc, TypeRange(resultType), inputs, originalType);
+  if (result.empty())
+    return nullptr;
+  assert(result.size() == 1 && "requested 1:1 materialization, but callback "
+                               "produced 1:N materialization");
+  return result.front();
+}
+
+SmallVector<Value> TypeConverter::materializeTargetConversion(
+    OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
+    Type originalType) const {
   for (const TargetMaterializationCallbackFn &fn :
-       llvm::reverse(targetMaterializations))
-    if (Value result = fn(builder, resultType, inputs, loc, originalType))
-      return result;
-  return nullptr;
+       llvm::reverse(targetMaterializations)) {
+    SmallVector<Value> result =
+        fn(builder, resultTypes, inputs, loc, originalType);
+    if (result.empty())
+      continue;
+    return result;
+  }
+  return {};
 }
 
 std::optional<TypeConverter::SignatureConversion>
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 19e29d48623e04..c208716891ef1f 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -17,20 +17,6 @@
 using namespace llvm;
 using namespace mlir;
 
-std::optional<SmallVector<Value>>
-OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
-                                                 Location loc,
-                                                 TypeRange resultTypes,
-                                                 Value input) const {
-  for (const OneToNMaterializationCallbackFn &fn :
-       llvm::reverse(oneToNTargetMaterializations)) {
-    if (std::optional<SmallVector<Value>> result =
-            fn(builder, resultTypes, input, loc))
-      return *result;
-  }
-  return std::nullopt;
-}
-
 TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
   TypeRange convertedTypes = getConvertedTypes();
   if (auto mapping = getInputMapping(originalTypeNo))
@@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion(
 LogicalResult
 OneToNConversionPattern::matchAndRewrite(Operation *op,
                                          PatternRewriter &rewriter) const {
-  auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+  auto *typeConverter = getTypeConverter();
 
   // Construct conversion mapping for results.
   Operation::result_type_range originalResultTypes = op->getResultTypes();
   OneToNTypeMapping resultMapping(originalResultTypes);
-  if (failed(typeConverter->computeTypeMapping(originalResultTypes,
-                                               resultMapping)))
+  if (failed(typeConverter->convertSignatureArgs(originalResultTypes,
+                                                 resultMapping)))
     return failure();
 
   // Construct conversion mapping for operands.
   Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
   OneToNTypeMapping operandMapping(originalOperandTypes);
-  if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
-                                               operandMapping)))
+  if (failed(typeConverter->convertSignatureArgs(originalOperandTypes,
+                                                 operandMapping)))
     return failure();
 
   // Cast operands to target types.
@@ -318,7 +304,7 @@ namespace mlir {
 // inserted by this pass are annotated with a string attribute that also
 // documents which kind of the cast (source, argument, or target).
 LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
                              const FrozenRewritePatternSet &patterns) {
 #ifndef NDEBUG
   // Remember existing unrealized casts. This data structure is only used in
@@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
       // Target materialization.
       assert(!areOperandTypesLegal && areResultsTypesLegal &&
              operands.size() == 1 && "found unexpected target cast");
-      std::optional<SmallVector<Value>> maybeResults =
-          typeConverter.materializeTargetConversion(
-              rewriter, castOp->getLoc(), resultTypes, operands.front());
-      if (!maybeResults) {
+      materializedResults = typeConverter.materializeTargetConversion(
+          rewriter, castOp->getLoc(), resultTypes, operands.front());
+      if (materializedResults.empty()) {
         emitError(castOp->getLoc())
             << "failed to create target materialization";
         return failure();
       }
-      materializedResults = maybeResults.value();
     } else {
       // Source and argument materializations.
       assert(areOperandTypesLegal && !areResultsTypesLegal &&
@@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
                                 const OneToNTypeMapping &resultMapping,
                                 ValueRange convertedOperands) const override {
     auto funcOp = cast<FunctionOpInterface>(op);
-    auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+    auto *typeConverter = getTypeConverter();
 
     // Construct mapping for function arguments.
     OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
-    if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
-                                                 argumentMapping)))
+    if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(),
+                                                   argumentMapping)))
       return failure();
 
     // Construct mapping for function results.
     OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
-    if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
-                                                 funcResultMapping)))
+    if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(),
+                                                   funcResultMapping)))
       return failure();
 
     // Nothing to do if the op doesn't have any non-identity conversions for its
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 5c03ac12d1e58c..b18dfd8bb22cb1 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
 ///
 /// This function has been copied (with small adaptions) from
 /// TestDecomposeCallGraphTypes.cpp.
-static std::optional<SmallVector<Value>>
-buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
-                        Location loc) {
+static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder,
+                                                  TypeRange resultTypes,
+                                                  ValueRange inputs,
+                                                  Location loc) {
+  if (inputs.size() != 1)
+    return {};
+  Value input = inputs.front();
+
   TupleType inputType = dyn_cast<TupleType>(input.getType());
   if (!inputType)
     return {};
@@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() {
   auto *context = &getContext();
 
   // Assemble type converter.
-  OneToNTypeConverter typeConverter;
+  TypeConverter typeConverter;
 
   typeConverter.addConversion([](Type type) { return type; });
   typeConverter.addConversion(
@@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() {
   typeConverter.addArgumentMaterialization(buildMakeTupleOp);
   typeConverter.addSourceMaterialization(buildMakeTupleOp);
   typeConverter.addTargetMaterialization(buildGetTupleElementOps);
+  // Test the other target materialization variant that takes the original type
+  // as an additional argument. This materialization function always fails.
+  typeConverter.addTargetMaterialization(
+      [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+         Location loc, Type originalType) -> SmallVector<Value> { return {}; });
 
   // Assemble patterns.
   RewritePatternSet patterns(context);

>From df341e92122f9c3d3a67b28583a45107a78268a8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 23 Oct 2024 09:36:45 -0700
Subject: [PATCH 2/2] Update mlir/include/mlir/Transforms/DialectConversion.h
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
---
 mlir/include/mlir/Transforms/DialectConversion.h | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 37da03bbe386e9..0638dfedd647b0 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -434,18 +434,19 @@ class TypeConverter {
         // This is a 1:N target materialization. Return the produces values
         // directly.
         result = callback(builder, resultTypes, inputs, loc, originalType);
-      } else {
+      } else if constexpr (std::is_assignable<Type, T>::value) {
         // This is a 1:1 target materialization. Invoke it only if the result
         // type class of the callback matches the requested result type.
         if (T derivedType = dyn_cast<T>(resultTypes.front())) {
           // 1:1 materializations produce single values, but we store 1:N
           // target materialization functions in the type converter. Wrap the
           // result value in a SmallVector<Value>.
-          std::optional<Value> val =
-              callback(builder, derivedType, inputs, loc, originalType);
-          if (val.has_value() && *val)
-            result.push_back(*val);
+          Value val = callback(builder, derivedType, inputs, loc, originalType);
+          if (val)
+            result.push_back(val);
         }
+      } else {
+        static_assert(false, "T must be a Type or a TypeRange");
       }
       return result;
     };



More information about the Mlir-commits mailing list