[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: add `originalType` param to materializations (PR #112128)

Matthias Springer llvmlistbot at llvm.org
Sun Oct 13 05:07:22 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/112128

>From 80bc49fb4b995929d149521f7b162de086ea68bb Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 13 Oct 2024 13:58:42 +0200
Subject: [PATCH] [mlir][Transforms] Dialect conversion: add originalType param
 to materializations v2

---
 .../mlir/Transforms/DialectConversion.h       | 108 ++++++++++++------
 .../Transforms/Utils/DialectConversion.cpp    |  83 +++++++++++---
 2 files changed, 136 insertions(+), 55 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 65e279e046e886..45ad6f8586daae 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -138,7 +138,8 @@ class TypeConverter {
   };
 
   /// Register a conversion function. A conversion function must be convertible
-  /// to any of the following forms(where `T` is a class derived from `Type`:
+  /// to any of the following forms (where `T` is a class derived from `Type`):
+  ///
   ///   * std::optional<Type>(T)
   ///     - This form represents a 1-1 type conversion. It should return nullptr
   ///       or `std::nullopt` to signify failure. If `std::nullopt` is returned,
@@ -151,15 +152,7 @@ class TypeConverter {
   ///       existing value are expected to be removed during conversion. If
   ///       `std::nullopt` is returned, the converter is allowed to try another
   ///       conversion function to perform the conversion.
-  ///   * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &,
-  ///                                  ArrayRef<Type>)
-  ///     - This form represents a 1-N type conversion supporting recursive
-  ///       types. The first two arguments and the return value are the same as
-  ///       for the regular 1-N form. The third argument is contains is the
-  ///       "call stack" of the recursive conversion: it contains the list of
-  ///       types currently being converted, with the current type being the
-  ///       last one. If it is present more than once in the list, the
-  ///       conversion concerns a recursive type.
+  ///
   /// Note: When attempting to convert a type, e.g. via 'convertType', the
   ///       mostly recently added conversions will be invoked first.
   template <typename FnT, typename T = typename llvm::function_traits<
@@ -178,6 +171,9 @@ class TypeConverter {
   /// it failed but other materialization can be attempted, and `nullptr` on
   /// unrecoverable failure. Materialization functions must be provided when a
   /// type conversion may persist after the conversion has finished.
+  ///
+  /// Note: Target materializations may optionally accept an additional Type
+  /// parameter, which is the original type of the SSA value.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
@@ -203,11 +199,22 @@ class TypeConverter {
 
   /// This method registers a materialization that will be called when
   /// converting an illegal (source) value to a legal (target) type.
+  ///
+  /// Note: For target materializations, users can optionally take the original
+  /// type. This type may be different from the type of the input. For example,
+  /// let's assume that a conversion pattern "P1" replaced an SSA value "v1"
+  /// (type "t1") with "v2" (type "t2"). Then a different conversion pattern
+  /// "P2" matches an op that has "v1" as an operand. Let's furthermore assume
+  /// that "P2" determines that the legalized type of "t1" is "t3", which may
+  /// be different from "t2". In this example, the target materialization
+  /// will be invoked with: outputType = "t3", inputs = "v2",
+  /// originalType = "t1". Note that the original type "t1" cannot be recovered
+  /// from just "t3" and "v2"; that's why the originalType parameter exists.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addTargetMaterialization(FnT &&callback) {
     targetMaterializations.emplace_back(
-        wrapMaterialization<T>(std::forward<FnT>(callback)));
+        wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
   }
 
   /// Register a conversion function for attributes within types. Type
@@ -303,21 +310,12 @@ class TypeConverter {
   /// `add*Materialization` for more information on the context for these
   /// methods.
   Value materializeArgumentConversion(OpBuilder &builder, Location loc,
-                                      Type resultType,
-                                      ValueRange inputs) const {
-    return materializeConversion(argumentMaterializations, builder, loc,
-                                 resultType, inputs);
-  }
+                                      Type resultType, ValueRange inputs) const;
   Value materializeSourceConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) const {
-    return materializeConversion(sourceMaterializations, builder, loc,
-                                 resultType, inputs);
-  }
+                                    Type resultType, ValueRange inputs) const;
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) const {
-    return materializeConversion(targetMaterializations, builder, loc,
-                                 resultType, inputs);
-  }
+                                    Type resultType, ValueRange inputs,
+                                    Type originalType = {}) const;
 
   /// Convert an attribute present `attr` from within the type `type` using
   /// the registered conversion functions. If no applicable conversion has been
@@ -333,21 +331,23 @@ class TypeConverter {
   using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
       Type, SmallVectorImpl<Type> &)>;
 
-  /// The signature of the callback used to materialize a conversion.
+  /// The signature of the callback used to materialize a source/argument
+  /// conversion.
+  ///
+  /// Arguments: builder, result type, inputs, location
   using MaterializationCallbackFn = std::function<std::optional<Value>(
       OpBuilder &, Type, ValueRange, Location)>;
 
+  /// The signature of the callback used to materialize a target conversion.
+  ///
+  /// Arguments: builder, result type, inputs, location, original type
+  using TargetMaterializationCallbackFn = std::function<std::optional<Value>(
+      OpBuilder &, Type, ValueRange, Location, Type)>;
+
   /// The signature of the callback used to convert a type attribute.
   using TypeAttributeConversionCallbackFn =
       std::function<AttributeConversionResult(Type, Attribute)>;
 
-  /// Attempt to materialize a conversion using one of the provided
-  /// materialization functions.
-  Value
-  materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
-                        OpBuilder &builder, Location loc, Type resultType,
-                        ValueRange inputs) const;
-
   /// Generate a wrapper for the given callback. This allows for accepting
   /// different callback forms, that all compose into a single version.
   /// With callback of form: `std::optional<Type>(T)`
@@ -388,9 +388,10 @@ class TypeConverter {
     cachedMultiConversions.clear();
   }
 
-  /// Generate a wrapper for the given materialization callback. The callback
-  /// may take any subclass of `Type` and the wrapper will check for the target
-  /// type to be of the expected class before calling the callback.
+  /// Generate a wrapper for the given argument/source materialization
+  /// callback. The callback may take any subclass of `Type` and the
+  /// wrapper will check for the target type to be of the expected class
+  /// before calling the callback.
   template <typename T, typename FnT>
   MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
@@ -402,6 +403,41 @@ class TypeConverter {
     };
   }
 
+  /// Generate a wrapper for the given target materialization callback.
+  /// The callback may take any subclass of `Type` and the wrapper will check
+  /// for the target type to be of the expected class before calling the
+  /// callback.
+  ///
+  /// With callback of form:
+  /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
+  template <typename T, typename FnT>
+  std::enable_if_t<
+      std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
+      TargetMaterializationCallbackFn>
+  wrapTargetMaterialization(FnT &&callback) const {
+    return [callback = std::forward<FnT>(callback)](
+               OpBuilder &builder, Type resultType, ValueRange inputs,
+               Location loc, Type originalType) -> std::optional<Value> {
+      if (T derivedType = dyn_cast<T>(resultType))
+        return callback(builder, derivedType, inputs, loc, originalType);
+      return std::nullopt;
+    };
+  }
+  /// With callback of form:
+  /// `Value(OpBuilder &, T, ValueRange, Location)`
+  template <typename T, typename FnT>
+  std::enable_if_t<
+      std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
+      TargetMaterializationCallbackFn>
+  wrapTargetMaterialization(FnT &&callback) const {
+    return wrapTargetMaterialization<T>(
+        [callback = std::forward<FnT>(callback)](
+            OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
+            Type originalType) -> std::optional<Value> {
+          return callback(builder, resultType, inputs, loc);
+        });
+  }
+
   /// Generate a wrapper for the given memory space conversion callback. The
   /// callback may take any subclass of `Attribute` and the wrapper will check
   /// for the target attribute to be of the expected class before calling the
@@ -434,7 +470,7 @@ class TypeConverter {
   /// The list of registered materialization functions.
   SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
   SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
-  SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
+  SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
 
   /// The list of registered type attribute conversion functions.
   SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 97dd3ab1f48293..1baddd881f6aa2 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -683,10 +683,10 @@ enum MaterializationKind {
 /// conversion.
 class UnresolvedMaterializationRewrite : public OperationRewrite {
 public:
-  UnresolvedMaterializationRewrite(
-      ConversionPatternRewriterImpl &rewriterImpl,
-      UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
-      MaterializationKind kind = MaterializationKind::Target);
+  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                                   UnrealizedConversionCastOp op,
+                                   const TypeConverter *converter,
+                                   MaterializationKind kind, Type originalType);
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -708,11 +708,18 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
     return converterAndKind.getInt();
   }
 
+  /// Return the original type of the SSA value.
+  Type getOriginalType() const { return originalType; }
+
 private:
   /// The corresponding type converter to use when resolving this
   /// materialization, and the kind of this materialization.
   llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
       converterAndKind;
+
+  /// The original type of the SSA value. Only used for target
+  /// materializations.
+  Type originalType;
 };
 } // namespace
 
@@ -808,6 +815,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   Value buildUnresolvedMaterialization(MaterializationKind kind,
                                        OpBuilder::InsertPoint ip, Location loc,
                                        ValueRange inputs, Type outputType,
+                                       Type originalType,
                                        const TypeConverter *converter);
 
   //===--------------------------------------------------------------------===//
@@ -1034,9 +1042,12 @@ void CreateOperationRewrite::rollback() {
 
 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
     ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
-    const TypeConverter *converter, MaterializationKind kind)
+    const TypeConverter *converter, MaterializationKind kind, Type originalType)
     : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-      converterAndKind(converter, kind) {
+      converterAndKind(converter, kind), originalType(originalType) {
+  // Parenthesize the check so the message is and-ed with the whole condition.
+  assert((!originalType || kind == MaterializationKind::Target) &&
+         "original type is valid only for target materializations");
   rewriterImpl.unresolvedMaterializations[op] = this;
 }
 
@@ -1139,7 +1150,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       Value castValue = buildUnresolvedMaterialization(
           MaterializationKind::Target, computeInsertPoint(newOperand),
           operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
-          currentTypeConverter);
+          /*originalType=*/origType, currentTypeConverter);
       mapping.map(newOperand, castValue);
       newOperand = castValue;
     }
@@ -1255,7 +1266,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
           MaterializationKind::Source,
           OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
           /*inputs=*/ValueRange(),
-          /*outputType=*/origArgType, converter);
+          /*outputType=*/origArgType, /*originalType=*/Type(), converter);
       mapping.map(origArg, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
       continue;
@@ -1280,7 +1291,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     Value argMat = buildUnresolvedMaterialization(
         MaterializationKind::Argument,
         OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-        /*inputs=*/replArgs, origArgType, converter);
+        /*inputs=*/replArgs, /*outputType=*/origArgType,
+        /*originalType=*/Type(), converter);
     mapping.map(origArg, argMat);
 
     Type legalOutputType;
@@ -1299,7 +1311,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     if (legalOutputType && legalOutputType != origArgType) {
       Value targetMat = buildUnresolvedMaterialization(
           MaterializationKind::Target, computeInsertPoint(argMat),
-          origArg.getLoc(), argMat, legalOutputType, converter);
+          origArg.getLoc(), /*inputs=*/argMat, /*outputType=*/legalOutputType,
+          /*originalType=*/origArgType, converter);
       mapping.map(argMat, targetMat);
     }
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1322,7 +1335,12 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// of input operands.
 Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-    ValueRange inputs, Type outputType, const TypeConverter *converter) {
+    ValueRange inputs, Type outputType, Type originalType,
+    const TypeConverter *converter) {
+  assert(!originalType ||
+         kind == MaterializationKind::Target &&
+             "original type is valid only for target materializations");
+
   // Avoid materializing an unnecessary cast.
   if (inputs.size() == 1 && inputs.front().getType() == outputType)
     return inputs.front();
@@ -1333,7 +1351,8 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
-  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
+  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+                                                  originalType);
   return convertOp.getResult(0);
 }
 
@@ -1381,7 +1400,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
       newValue = buildUnresolvedMaterialization(
           MaterializationKind::Source, computeInsertPoint(result),
           result.getLoc(), /*inputs=*/ValueRange(),
-          /*outputType=*/result.getType(), currentTypeConverter);
+          /*outputType=*/result.getType(), /*originalType=*/Type(),
+          currentTypeConverter);
     }
 
     // Remap, and check for any result type changes.
@@ -2408,7 +2428,8 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
       [[fallthrough]];
     case MaterializationKind::Target:
       newMaterialization = converter->materializeTargetConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
+          rewriter, op->getLoc(), outputType, inputOperands,
+          rewrite->getOriginalType());
       break;
     case MaterializationKind::Source:
       newMaterialization = converter->materializeSourceConversion(
@@ -2565,7 +2586,7 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
           MaterializationKind::Source, computeInsertPoint(newValue),
           originalValue.getLoc(),
           /*inputs=*/newValue, /*outputType=*/originalValue.getType(),
-          converter);
+          /*originalType=*/Type(), converter);
       rewriterImpl.mapping.map(originalValue, castValue);
       inverseMapping[castValue].push_back(originalValue);
       llvm::erase(inverseMapping[newValue], originalValue);
@@ -2787,15 +2808,39 @@ TypeConverter::convertSignatureArgs(TypeRange types,
   return success();
 }
 
-Value TypeConverter::materializeConversion(
-    ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
-    Location loc, Type resultType, ValueRange inputs) const {
-  for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
+Value TypeConverter::materializeArgumentConversion(OpBuilder &builder,
+                                                   Location loc,
+                                                   Type resultType,
+                                                   ValueRange inputs) const {
+  for (const MaterializationCallbackFn &fn :
+       llvm::reverse(argumentMaterializations))
+    if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
+      return *result;
+  return nullptr;
+}
+
+Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
+                                                 Location loc, Type resultType,
+                                                 ValueRange inputs) const {
+  for (const MaterializationCallbackFn &fn :
+       llvm::reverse(sourceMaterializations))
     if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
       return *result;
   return nullptr;
 }
 
+Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
+                                                 Location loc, Type resultType,
+                                                 ValueRange inputs,
+                                                 Type originalType) const {
+  for (const TargetMaterializationCallbackFn &fn :
+       llvm::reverse(targetMaterializations))
+    if (std::optional<Value> result =
+            fn(builder, resultType, inputs, loc, originalType))
+      return *result;
+  return nullptr;
+}
+
 std::optional<TypeConverter::SignatureConversion>
 TypeConverter::convertBlockSignature(Block *block) const {
   SignatureConversion conversion(block->getNumArguments());



More information about the Mlir-commits mailing list