[flang-commits] [flang] [mlir] [mlir][Transforms] Dialect conversion: add `originalType` param to materialization (PR #112128)

Matthias Springer via flang-commits flang-commits at lists.llvm.org
Sun Oct 13 03:21:17 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/112128

This commit adds an `originalType` parameter to all materialization functions. Without this parameter, target materializations are underspecified.

Note: `originalType` is only needed for target materializations. For source/argument materializations, `originalType` always matches `outputType`. However, to keep the code base simple (i.e., reuse `MaterializationCallbackFn` for all three materializations), `originalType` is passed to all three materializations, even though it is only really needed for target materializations.

`originalType` is the original type of an SSA value. For argument materializations, it matches the original argument type (which is also the output type). For source materializations, it also matches the output type.
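The toy C++ sketch below shows one callback type being reused for every materialization kind. `OpBuilder`, `Location`, `Type`, `Value` and `ValueRange` are hypothetical stand-ins (types are plain strings), not the MLIR classes. The only part taken from this change is the argument order of `MaterializationCallbackFn`. `materialize` and `materializeSource` are illustrative helpers. The sketch shows that a caller of a source materialization can forward the result type as the original type.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins, not the MLIR classes: types are plain names.
struct OpBuilder {};
using Location = std::string;
using Type = std::string;
struct Value {
  Type type;
};
using ValueRange = std::vector<Value>;

// Same argument order as the new MaterializationCallbackFn:
// builder, location, result type, inputs, original type.
using MaterializationCallbackFn = std::function<std::optional<Value>(
    OpBuilder &, Location, Type, ValueRange, Type)>;

// Tries each registered callback in order until one succeeds.
std::optional<Value>
materialize(const std::vector<MaterializationCallbackFn> &callbacks,
            OpBuilder &builder, const Location &loc, const Type &resultType,
            const ValueRange &inputs, const Type &originalType) {
  for (const MaterializationCallbackFn &fn : callbacks)
    if (std::optional<Value> result =
            fn(builder, loc, resultType, inputs, originalType))
      return result;
  return std::nullopt;
}

// A source materialization converts back to the value's original type, so
// the caller passes the result type a second time as the original type.
std::optional<Value>
materializeSource(const std::vector<MaterializationCallbackFn> &callbacks,
                  OpBuilder &builder, const Location &loc,
                  const Type &resultType, const ValueRange &inputs) {
  return materialize(callbacks, builder, loc, resultType, inputs,
                     /*originalType=*/resultType);
}
```

An argument materialization would follow the same rule. Only target materializations need a caller that actually knows a different original type.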

For target materializations, consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the originalType parameter exists.
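The P1/P2 scenario can be modeled in a few lines of C++. This is a simplified sketch and does not reflect how the dialect conversion driver is implemented. Values and types are strings, `replacements` stands in for the driver's record of replaced values, and `remapOperand` and `TargetMaterializationFn` are hypothetical names. The point is that only the code that still knows `v1` can supply `t1`.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Toy model, not the MLIR API: a type is a name, a value is a name + type.
struct Value {
  std::string name;
  std::string type;
};

// Arguments in the order (outputType, inputs, originalType).
using TargetMaterializationFn = std::function<Value(
    const std::string &outputType, const std::vector<Value> &inputs,
    const std::string &originalType)>;

// Models pattern "P2" asking for a legal version of its operand `original`.
// `replacements` records what earlier patterns (such as "P1") replaced each
// value with, and `convertType` is the type conversion that P2 applies.
Value remapOperand(
    const Value &original, const std::map<std::string, Value> &replacements,
    const std::function<std::string(const std::string &)> &convertType,
    const TargetMaterializationFn &materialize) {
  auto it = replacements.find(original.name);
  Value current = it == replacements.end() ? original : it->second;
  std::string legalType = convertType(original.type);
  if (current.type == legalType)
    return current; // Already legal: no materialization needed.
  // Invoked with outputType = "t3", inputs = {"v2"}, originalType = "t1".
  // The callback could not derive "t1" from "t3" and "v2" on its own.
  return materialize(legalType, {current}, original.type);
}
```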

This commit also puts the `Location` parameter right after the `OpBuilder` parameter to be consistent with MLIR conventions.

This change is in preparation for merging the 1:1 and 1:N dialect conversion drivers. As part of that change, argument materializations will be removed: they are no longer needed, because they were only a workaround for the missing 1:N support in the dialect conversion. The new `originalType` parameter is needed when lowering MemRef to LLVM. During that lowering, MemRef function block arguments are replaced with the elements that make up a MemRef descriptor. The type converter is set up so that the legalized type of a MemRef type is an `!llvm.struct` that represents the MemRef descriptor. When the bare pointer calling convention is enabled, the function block arguments consist of just an LLVM pointer. In that case, a target materialization is invoked to construct a MemRef descriptor (output type = `!llvm.struct<...>`) from just the bare pointer (inputs = `!llvm.ptr`). The original MemRef type is required to construct the MemRef descriptor, because the static sizes, strides and offset cannot be inferred from the bare pointer alone.
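The toy sketch below shows why the bare pointer is not enough on its own. It assumes a statically shaped memref with the identity (row-major) layout and offset 0, and it uses a negative size to mean "dynamic". That marker is a toy convention and differs from MLIR's. `ToyMemRefType`, `ToyMemRefDescriptor` and `descriptorFromBarePtr` are hypothetical. They only mirror the fields of a ranked MemRef descriptor: allocated pointer, aligned pointer, offset, sizes and strides.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Toy stand-in for a ranked memref type with identity layout (not MLIR's).
struct ToyMemRefType {
  std::vector<int64_t> shape; // Negative entry = dynamic size (toy only).
};

// Toy stand-in for the fields of a ranked MemRef descriptor.
struct ToyMemRefDescriptor {
  const void *allocatedPtr;
  const void *alignedPtr;
  int64_t offset;
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
};

// Builds a descriptor from a bare pointer. The pointer carries no sizes,
// strides or offset, so they are taken from the original memref type.
// Returns std::nullopt if any size is dynamic.
std::optional<ToyMemRefDescriptor>
descriptorFromBarePtr(const void *barePtr, const ToyMemRefType &originalType) {
  ToyMemRefDescriptor desc{barePtr, barePtr, /*offset=*/0, originalType.shape,
                           {}};
  desc.strides.assign(originalType.shape.size(), 1);
  int64_t running = 1;
  // Row-major strides: innermost dimension has stride 1.
  for (std::size_t i = originalType.shape.size(); i-- > 0;) {
    if (originalType.shape[i] < 0)
      return std::nullopt; // Nothing in the pointer says what this size is.
    desc.strides[i] = running;
    running *= originalType.shape[i];
  }
  return desc;
}
```

For a 3x4 shape this yields sizes `[3, 4]` and strides `[4, 1]`. The same pointer paired with a shape that has a dynamic size cannot be described at all.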

Note for LLVM integration: For all argument/source/target materialization functions, move the `Location` parameter to the second position and add a `Type originalType` parameter at the end of the lambda's parameter list. No changes are needed to the body of the lambda. When your code base calls an argument/source materialization (`materializeArgumentConversion`/`materializeSourceConversion`), pass the output type as the original type. When it calls a target materialization (`materializeTargetConversion`), try to pass the original type of the SSA value, which may match `inputs.front().getType()`. If the original type cannot be recovered (which is unlikely), pass `Type()`.
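For code bases with many existing callbacks, a small adapter can forward the new argument order to an old-style callback while the lambdas are updated. This is a hypothetical helper, not part of MLIR, and it is shown with toy stand-in types only. `T` is spelled out explicitly instead of using an `auto` parameter, so the resulting lambda keeps a non-generic call operator. That matters because the registration functions deduce `T` from the callback's third argument (`arg_t<2>` in the diff below). With the real headers, one might write `converter.addTargetMaterialization(adaptLegacyMaterialization<mlir::Type>(oldCallback))`, but this has not been checked against MLIR.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins for the MLIR classes (hypothetical, illustration only).
struct OpBuilder {};
struct Location {
  std::string name;
};
struct Type {
  std::string name;
};
struct Value {
  Type type;
};
using ValueRange = std::vector<Value>;

// Hypothetical adapter: wraps a callback written for the old order
// (builder, resultType, inputs, loc) so it can be called with the new order
// (builder, loc, resultType, inputs, originalType). Old callbacks never
// needed `originalType`, so it is ignored.
template <typename T, typename OldFn>
auto adaptLegacyMaterialization(OldFn oldFn) {
  return [oldFn = std::move(oldFn)](OpBuilder &builder, Location loc,
                                    T resultType, ValueRange inputs,
                                    Type /*originalType*/) {
    return oldFn(builder, resultType, inputs, loc);
  };
}
```

Reordering the lambda parameters directly, as described in the note, remains the simpler long-term fix. Call sites still have to be updated to pass an original type.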

From 2c438b0a4f7d01b43ec91f2248b6f29fe8199793 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 13 Oct 2024 12:03:00 +0200
Subject: [PATCH] [mlir][Transforms] Dialect conversion: add `originalType`
 param to materialization

---
 .../lib/Optimizer/CodeGen/BoxedProcedure.cpp  |  4 +-
 .../mlir/Transforms/DialectConversion.h       | 56 +++++++++++++------
 .../mlir/Transforms/OneToNTypeConversion.h    |  7 ++-
 .../Conversion/AsyncToLLVM/AsyncToLLVM.cpp    |  4 +-
 .../ControlFlowToSPIRV/ControlFlowToSPIRV.cpp |  2 +-
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 22 ++++----
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  |  2 +-
 .../Transforms/EmulateUnsupportedFloats.cpp   | 12 ++--
 .../Bufferization/Transforms/Bufferize.cpp    | 10 ++--
 .../EmitC/Transforms/TypeConversions.cpp      |  4 +-
 .../Func/TransformOps/FuncTransformOps.cpp    |  5 +-
 .../Transforms/DecomposeCallGraphTypes.cpp    |  3 +-
 .../Dialect/Linalg/Transforms/Detensorize.cpp |  9 +--
 .../Transforms/ExtendToSupportedTypes.cpp     | 13 +++--
 .../Quant/Transforms/StripFuncQuantTypes.cpp  | 20 +++----
 .../Transforms/StructuralTypeConversions.cpp  |  2 +-
 .../Transforms/LowerABIAttributesPass.cpp     |  5 +-
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 12 ++--
 .../Transforms/SparseIterationToScf.cpp       |  6 +-
 .../Utils/SparseTensorDescriptor.cpp          |  6 +-
 .../TransformOps/TensorTransformOps.cpp       | 42 +++++++-------
 .../Tosa/Transforms/TosaTypeConverters.cpp    | 34 +++++------
 .../Vector/Transforms/VectorLinearize.cpp     |  4 +-
 .../Transforms/Utils/DialectConversion.cpp    | 52 ++++++++++-------
 .../Transforms/Utils/OneToNTypeConversion.cpp | 14 +++--
 .../TestOneToNTypeConversionPass.cpp          | 17 +++---
 .../lib/Dialect/Arith/TestEmulateWideInt.cpp  |  5 +-
 .../Func/TestDecomposeCallGraphTypes.cpp      | 10 ++--
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 20 ++++---
 .../TestTransformDialectExtension.cpp         |  6 +-
 .../lib/Transforms/TestDialectConversion.cpp  |  5 +-
 31 files changed, 238 insertions(+), 175 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index c536fd19fcc69a..f1b057eedb2340 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -173,9 +173,9 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
   }
 
   static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
-                                          BoxProcType type,
+                                          mlir::Location loc, BoxProcType type,
                                           mlir::ValueRange inputs,
-                                          mlir::Location loc) {
+                                          mlir::Type originalType) {
     assert(inputs.size() == 1);
     return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
                                      inputs[0]);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 65e279e046e886..f22599c4d4aabf 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -170,7 +170,7 @@ class TypeConverter {
 
   /// All of the following materializations require function objects that are
   /// convertible to the following form:
-  ///   `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
+  ///   `std::optional<Value>(OpBuilder &, Location, T, ValueRange, Type)`,
   /// where `T` is any subclass of `Type`. This function is responsible for
   /// creating an operation, using the OpBuilder and Location provided, that
   /// "casts" a range of values into a single value of the given type `T`. It
@@ -178,13 +178,19 @@ class TypeConverter {
   /// it failed but other materialization can be attempted, and `nullptr` on
   /// unrecoverable failure. Materialization functions must be provided when a
   /// type conversion may persist after the conversion has finished.
+  ///
+  /// The type that is provided as the fifth argument is the original type of
+  /// the value. For more details, see the documentation below.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
   /// a signature conversion of a single block argument, to a single SSA value
   /// with the old block argument type.
+  ///
+  /// Note: The original type matches the result type `T` for argument
+  /// materializations.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addArgumentMaterialization(FnT &&callback) {
     argumentMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -194,8 +200,11 @@ class TypeConverter {
   /// converting a legal replacement value back to an illegal source type.
   /// This is used when some uses of the original, illegal value must persist
   /// beyond the main conversion.
+  ///
+  /// Note: The original type matches the result type `T` for source
+  /// materializations.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addSourceMaterialization(FnT &&callback) {
     sourceMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -203,8 +212,19 @@ class TypeConverter {
 
   /// This method registers a materialization that will be called when
   /// converting an illegal (source) value to a legal (target) type.
+  ///
+  /// Note: For target materializations, the original type can be
+  /// different from the type of the input. For example, let's assume that a
+  /// conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2"
+  /// (type "t2"). Then a different conversion pattern "P2" matches an op that
+  /// has "v1" as an operand. Let's furthermore assume that "P2" determines
+  /// that the legalized type of "t1" is "t3", which may be different from
+  /// "t2". In this example, the target materialization callback will be
+  /// invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note
+  /// that the original type "t1" cannot be recovered from just "t3" and "v2";
+  /// that's why the originalType parameter exists.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addTargetMaterialization(FnT &&callback) {
     targetMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -303,20 +323,22 @@ class TypeConverter {
   /// `add*Materialization` for more information on the context for these
   /// methods.
   Value materializeArgumentConversion(OpBuilder &builder, Location loc,
-                                      Type resultType,
-                                      ValueRange inputs) const {
+                                      Type resultType, ValueRange inputs,
+                                      Type originalType) const {
     return materializeConversion(argumentMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
   Value materializeSourceConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) const {
+                                    Type resultType, ValueRange inputs,
+                                    Type originalType) const {
     return materializeConversion(sourceMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) const {
+                                    Type resultType, ValueRange inputs,
+                                    Type originalType) const {
     return materializeConversion(targetMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
 
   /// Convert an attribute present `attr` from within the type `type` using
@@ -334,8 +356,10 @@ class TypeConverter {
       Type, SmallVectorImpl<Type> &)>;
 
   /// The signature of the callback used to materialize a conversion.
+  ///
+  /// Arguments: builder, location, result type, inputs, original type
   using MaterializationCallbackFn = std::function<std::optional<Value>(
-      OpBuilder &, Type, ValueRange, Location)>;
+      OpBuilder &, Location, Type, ValueRange, Type)>;
 
   /// The signature of the callback used to convert a type attribute.
   using TypeAttributeConversionCallbackFn =
@@ -346,7 +370,7 @@ class TypeConverter {
   Value
   materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
                         OpBuilder &builder, Location loc, Type resultType,
-                        ValueRange inputs) const;
+                        ValueRange inputs, Type originalType) const;
 
   /// Generate a wrapper for the given callback. This allows for accepting
   /// different callback forms, that all compose into a single version.
@@ -394,10 +418,10 @@ class TypeConverter {
   template <typename T, typename FnT>
   MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               OpBuilder &builder, Type resultType, ValueRange inputs,
-               Location loc) -> std::optional<Value> {
+               OpBuilder &builder, Location loc, Type resultType,
+               ValueRange inputs, Type originalType) -> std::optional<Value> {
       if (T derivedType = dyn_cast<T>(resultType))
-        return callback(builder, derivedType, inputs, loc);
+        return callback(builder, loc, derivedType, inputs, originalType);
       return std::nullopt;
     };
   }
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..90f796fce576a9 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -44,8 +44,8 @@ class OneToNTypeConverter : public TypeConverter {
   /// materializations for 1:N type conversions, which materialize one value in
   /// a source type as N values in target types.
   using OneToNMaterializationCallbackFn =
-      std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
-                                                      Value, Location)>;
+      std::function<std::optional<SmallVector<Value>>(OpBuilder &, Location,
+                                                      TypeRange, Value, Type)>;
 
   /// Creates the mapping of the given range of original types to target types
   /// of the conversion and stores that mapping in the given (signature)
@@ -63,7 +63,8 @@ class OneToNTypeConverter : public TypeConverter {
   /// returns `std::nullopt`.
   std::optional<SmallVector<Value>>
   materializeTargetConversion(OpBuilder &builder, Location loc,
-                              TypeRange resultTypes, Value input) const;
+                              TypeRange resultTypes, Value input,
+                              Type originalType) const;
 
   /// Adds a 1:N target materialization to the converter. Such materializations
   /// build IR that converts N values with target types into 1 value of the
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 77603739137614..5b067647fff726 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -281,8 +281,8 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
 
     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
     // in patterns for other dialects.
-    auto addUnrealizedCast = [](OpBuilder &builder, Type type,
-                                ValueRange inputs, Location loc) {
+    auto addUnrealizedCast = [](OpBuilder &builder, Location loc, Type type,
+                                ValueRange inputs, Type originalType) {
       auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
       return std::optional<Value>(cast.getResult(0));
     };
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index d8f3e995109538..b0fc27e59f7501 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -46,7 +46,7 @@ static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
     Location loc = arg.getLoc();
     Value newArg = block.insertArgument(argNum, newTy, loc);
     Value convertedValue = converter.materializeSourceConversion(
-        builder, op->getLoc(), ty, newArg);
+        builder, op->getLoc(), ty, newArg, ty);
     if (!convertedValue) {
       return rewriter.notifyMatchFailure(
           op, llvm::formatv("failed to cast new argument {0} to type {1})",
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 5a92fa839e9847..66a0ce74f5841c 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -159,8 +159,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   // insert a target materialization from the original block argument type to
   // a legal type.
   addArgumentMaterialization(
-      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
-          Location loc) -> std::optional<Value> {
+      [&](OpBuilder &builder, Location loc, UnrankedMemRefType resultType,
+          ValueRange inputs, Type originalType) -> std::optional<Value> {
         if (inputs.size() == 1) {
           // Bare pointers are not supported for unranked memrefs because a
           // memref descriptor cannot be built just from a bare pointer.
@@ -174,9 +174,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
         return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
             .getResult(0);
       });
-  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
-                                 ValueRange inputs,
-                                 Location loc) -> std::optional<Value> {
+  addArgumentMaterialization([&](OpBuilder &builder, Location loc,
+                                 MemRefType resultType, ValueRange inputs,
+                                 Type originalType) -> std::optional<Value> {
     Value desc;
     if (inputs.size() == 1) {
       // This is a bare pointer. We allow bare pointers only for function entry
@@ -201,18 +201,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   });
   // Add generic source and target materializations to handle cases where
   // non-LLVM types persist after an LLVM conversion.
-  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs,
-                               Location loc) -> std::optional<Value> {
+  addSourceMaterialization([&](OpBuilder &builder, Location loc,
+                               Type resultType, ValueRange inputs,
+                               Type originalType) -> std::optional<Value> {
     if (inputs.size() != 1)
       return std::nullopt;
 
     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
         .getResult(0);
   });
-  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs,
-                               Location loc) -> std::optional<Value> {
+  addTargetMaterialization([&](OpBuilder &builder, Location loc,
+                               Type resultType, ValueRange inputs,
+                               Type originalType) -> std::optional<Value> {
     if (inputs.size() != 1)
       return std::nullopt;
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4bfa536cc8a44a..46acfdab96e648 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1185,7 +1185,7 @@ struct MemRefReshapeOpLowering
           Type indexType = getIndexType();
           if (dimSize.getType() != indexType)
             dimSize = typeConverter->materializeTargetConversion(
-                rewriter, loc, indexType, dimSize);
+                rewriter, loc, indexType, dimSize, dimSize.getType());
           assert(dimSize && "Invalid memref element type");
         }
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 836ebb65e7d17b..d57960169de217 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -97,12 +97,12 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
     // All other types legal
     return type;
   });
-  converter.addTargetMaterialization(
-      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
-        extFOp.setFastmath(arith::FastMathFlags::contract);
-        return extFOp;
-      });
+  converter.addTargetMaterialization([](OpBuilder &b, Location loc, Type target,
+                                        ValueRange input, Type originalType) {
+    auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+    extFOp.setFastmath(arith::FastMathFlags::contract);
+    return extFOp;
+  });
 }
 
 void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 875d8c40e92cc1..3378fe3ee6680d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -42,8 +42,9 @@ using namespace mlir::bufferization;
 // BufferizeTypeConverter
 //===----------------------------------------------------------------------===//
 
-static Value materializeToTensor(OpBuilder &builder, TensorType type,
-                                 ValueRange inputs, Location loc) {
+static Value materializeToTensor(OpBuilder &builder, Location loc,
+                                 TensorType type, ValueRange inputs,
+                                 Type originalType) {
   assert(inputs.size() == 1);
   assert(isa<BaseMemRefType>(inputs[0].getType()));
   return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
@@ -63,8 +64,9 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
   });
   addArgumentMaterialization(materializeToTensor);
   addSourceMaterialization(materializeToTensor);
-  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
-                              ValueRange inputs, Location loc) -> Value {
+  addTargetMaterialization([](OpBuilder &builder, Location loc,
+                              BaseMemRefType type, ValueRange inputs,
+                              Type originalType) -> Value {
     assert(inputs.size() == 1 && "expected exactly one input");
 
     if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 83de9b37974f67..1315805caa675f 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -17,9 +17,9 @@ using namespace mlir;
 namespace {
 
 std::optional<Value> materializeAsUnrealizedCast(OpBuilder &builder,
-                                                 Type resultType,
+                                                 Location loc, Type resultType,
                                                  ValueRange inputs,
-                                                 Location loc) {
+                                                 Type originalType) {
   if (inputs.size() != 1)
     return std::nullopt;
 
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 2728936bf33fd3..3b472293ef88b6 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -161,7 +161,7 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
        llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
     if (input.getType() != type) {
       Value newInput = converter.materializeSourceConversion(
-          rewriter, input.getLoc(), type, input);
+          rewriter, input.getLoc(), type, input, type);
       if (!newInput) {
         return emitDefiniteFailure() << "Failed to materialize conversion of "
                                      << input << " to type " << type;
@@ -180,7 +180,8 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
     Value convertedOutput = newOutput;
     if (output.getType() != newOutput.getType()) {
       convertedOutput = converter.materializeTargetConversion(
-          rewriter, output.getLoc(), output.getType(), newOutput);
+          rewriter, output.getLoc(), output.getType(), newOutput,
+          output.getType());
       if (!convertedOutput) {
         return emitDefiniteFailure()
                << "Failed to materialize conversion of " << newOutput
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index 357f993710a26a..557ef265c5b30c 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -177,7 +177,8 @@ struct DecomposeCallGraphTypesForCallOp
       } else {
         // Materialize a single Value to replace the original Value.
         Value materialized = getTypeConverter()->materializeArgumentConversion(
-            rewriter, op.getLoc(), op.getType(i), decomposedValues);
+            rewriter, op.getLoc(), op.getType(i), decomposedValues,
+            op.getType(i));
         replacedValues.push_back(materialized);
       }
     }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index af38485291182f..d579afaf076ebc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -28,8 +28,9 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::linalg;
 
-static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
-                                           ValueRange inputs, Location loc) {
+static Value sourceMaterializationCallback(OpBuilder &builder, Location loc,
+                                           Type type, ValueRange inputs,
+                                           Type originalType) {
   assert(inputs.size() == 1);
   auto inputType = inputs[0].getType();
   if (isa<TensorType>(inputType))
@@ -148,8 +149,8 @@ class DetensorizeTypeConverter : public TypeConverter {
     });
 
     // A tensor value is detensoried by extracting its element(s).
-    addTargetMaterialization([](OpBuilder &builder, Type type,
-                                ValueRange inputs, Location loc) -> Value {
+    addTargetMaterialization([](OpBuilder &builder, Location loc, Type type,
+                                ValueRange inputs, Type originalType) -> Value {
       return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
     });
 
diff --git a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
index a570ed5118ef0b..8172317cc5fcd6 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
@@ -71,12 +71,13 @@ void mlir::math::populateExtendToSupportedTypesTypeConverter(
 
         return std::nullopt;
       });
-  typeConverter.addTargetMaterialization(
-      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
-        extFOp.setFastmath(arith::FastMathFlags::contract);
-        return extFOp;
-      });
+  typeConverter.addTargetMaterialization([](OpBuilder &b, Location loc,
+                                            Type target, ValueRange input,
+                                            Type originalType) {
+    auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+    extFOp.setFastmath(arith::FastMathFlags::contract);
+    return extFOp;
+  });
 }
 
 void mlir::math::populateExtendToSupportedTypesConversionTarget(
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 8996eff61a39c0..65cf8cf8c37f7f 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -36,21 +36,22 @@ class QuantizedTypeConverter : public TypeConverter {
   static Type convertQuantizedType(QuantizedType quantizedType) {
     return quantizedType.getStorageType();
   }
-  
+
   static Type convertTensorType(TensorType tensorType) {
-    if (auto quantizedType = dyn_cast<QuantizedType>(tensorType.getElementType()))
+    if (auto quantizedType =
+            dyn_cast<QuantizedType>(tensorType.getElementType()))
       return tensorType.clone(convertQuantizedType(quantizedType));
     return tensorType;
   }
 
-  static Value materializeConversion(OpBuilder &builder, Type type,
-                                     ValueRange inputs, Location loc) {
+  static Value materializeConversion(OpBuilder &builder, Location loc,
+                                     Type type, ValueRange inputs,
+                                     Type originalType) {
     assert(inputs.size() == 1);
     return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
   }
 
 public:
-
   explicit QuantizedTypeConverter() {
     addConversion([](Type type) { return type; });
     addConversion(convertQuantizedType);
@@ -63,7 +64,8 @@ class QuantizedTypeConverter : public TypeConverter {
 };
 
 // Conversion pass
-class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
+class StripFuncQuantTypes
+    : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
 
   // Return whether a type is considered legal when occurring in the header of
   // a function or as an operand to a 'return' op.
@@ -74,11 +76,10 @@ class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantT
   }
 
 public:
-
   void runOnOperation() override {
-    
+
     auto moduleOp = cast<ModuleOp>(getOperation());
-    auto* context = &getContext();
+    auto *context = &getContext();
 
     QuantizedTypeConverter typeConverter;
     ConversionTarget target(*context);
@@ -111,4 +112,3 @@ class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantT
 
 } // namespace quant
 } // namespace mlir
-
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 93a78056db1944..f44d006a959d0c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -82,7 +82,7 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
         // 1 : N type conversion.
         Type origType = op.getResultTypes()[i - 1];
         Value mat = typeConverter->materializeSourceConversion(
-            rewriter, op.getLoc(), origType, mappedValue);
+            rewriter, op.getLoc(), origType, mappedValue, origType);
         if (!mat) {
           return rewriter.notifyMatchFailure(
               op, "Failed to materialize 1:N type conversion");
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 2024a2e5279ffc..f7c751dfd4b04b 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -298,9 +298,10 @@ void LowerABIAttributesPass::runOnOperation() {
   SPIRVTypeConverter typeConverter(targetEnv);
 
   // Insert a bitcast in the case of a pointer type change.
-  typeConverter.addSourceMaterialization([](OpBuilder &builder,
+  typeConverter.addSourceMaterialization([](OpBuilder &builder, Location loc,
                                             spirv::PointerType type,
-                                            ValueRange inputs, Location loc) {
+                                            ValueRange inputs,
+                                            Type originalType) {
     if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
       return Value();
     return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 656090314d650e..949ee339d2b9f8 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1452,12 +1452,12 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
   });
 
   // Register some last line of defense casting logic.
-  addSourceMaterialization(
-      [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
-        return castToSourceType(this->targetEnv, builder, type, inputs, loc);
-      });
-  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
-                              Location loc) {
+  addSourceMaterialization([this](OpBuilder &builder, Location loc, Type type,
+                                  ValueRange inputs, Type originalType) {
+    return castToSourceType(this->targetEnv, builder, type, inputs, loc);
+  });
+  addTargetMaterialization([](OpBuilder &builder, Location loc, Type type,
+                              ValueRange inputs, Type originalType) {
     auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
     return std::optional<Value>(cast.getResult(0));
   });
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 04466d198b5b67..1bcc51877481b9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -424,9 +424,9 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
   addConversion(convertIteratorType);
   addConversion(convertIterSpaceType);
 
-  addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
-                              ValueRange inputs,
-                              Location loc) -> std::optional<Value> {
+  addSourceMaterialization([](OpBuilder &builder, Location loc,
+                              IterSpaceType spTp, ValueRange inputs,
+                              Type originalType) -> std::optional<Value> {
     return builder
         .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
         .getResult(0);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index 6ac26ad550f9f3..d4ced449894a99 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -59,9 +59,9 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
   addConversion(convertSparseTensorType);
 
   // Required by scf.for 1:N type conversion.
-  addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp,
-                              ValueRange inputs,
-                              Location loc) -> std::optional<Value> {
+  addSourceMaterialization([](OpBuilder &builder, Location loc,
+                              RankedTensorType tp, ValueRange inputs,
+                              Type originalType) -> std::optional<Value> {
     if (!getSparseTensorEncoding(tp))
       // Not a sparse tensor.
       return std::nullopt;
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index f911619d71227d..e272c23a0e0535 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -150,26 +150,28 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
 void transform::TypeConversionCastShapeDynamicDimsOp::
     populateTypeMaterializations(TypeConverter &converter) {
   bool ignoreDynamicInfo = getIgnoreDynamicInfo();
-  converter.addSourceMaterialization([ignoreDynamicInfo](
-                                         OpBuilder &builder, Type resultType,
-                                         ValueRange inputs,
-                                         Location loc) -> std::optional<Value> {
-    if (inputs.size() != 1) {
-      return std::nullopt;
-    }
-    Value input = inputs[0];
-    if (!ignoreDynamicInfo &&
-        !tensor::preservesStaticInformation(resultType, input.getType())) {
-      return std::nullopt;
-    }
-    if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
-      return std::nullopt;
-    }
-    return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
-  });
-  converter.addTargetMaterialization([](OpBuilder &builder, Type resultType,
-                                        ValueRange inputs,
-                                        Location loc) -> std::optional<Value> {
+  converter.addSourceMaterialization(
+      [ignoreDynamicInfo](OpBuilder &builder, Location loc, Type resultType,
+                          ValueRange inputs,
+                          Type originalType) -> std::optional<Value> {
+        if (inputs.size() != 1) {
+          return std::nullopt;
+        }
+        Value input = inputs[0];
+        if (!ignoreDynamicInfo &&
+            !tensor::preservesStaticInformation(resultType, input.getType())) {
+          return std::nullopt;
+        }
+        if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
+          return std::nullopt;
+        }
+        return builder.create<tensor::CastOp>(loc, resultType, input)
+            .getResult();
+      });
+  converter.addTargetMaterialization([](OpBuilder &builder, Location loc,
+                                        Type resultType, ValueRange inputs,
+                                        Type originalType)
+                                         -> std::optional<Value> {
     if (inputs.size() != 1) {
       return std::nullopt;
     }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp
index d2650de8cd7f02..5de891548b4363 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp
@@ -31,22 +31,24 @@ void mlir::tosa::populateTosaTypeConversion(TypeConverter &converter) {
       return {};
     return type.clone(converted);
   });
-  converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
-                                         ValueRange inputs,
-                                         Location loc) -> std::optional<Value> {
-    if (inputs.size() != 1)
-      return std::nullopt;
+  converter.addSourceMaterialization(
+      [&](OpBuilder &builder, Location loc, Type resultType, ValueRange inputs,
+          Type originalType) -> std::optional<Value> {
+        if (inputs.size() != 1)
+          return std::nullopt;
 
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
-        .getResult(0);
-  });
-  converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                                         ValueRange inputs,
-                                         Location loc) -> std::optional<Value> {
-    if (inputs.size() != 1)
-      return std::nullopt;
+        return builder
+            .create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+            .getResult(0);
+      });
+  converter.addTargetMaterialization(
+      [&](OpBuilder &builder, Location loc, Type resultType, ValueRange inputs,
+          Type originalType) -> std::optional<Value> {
+        if (inputs.size() != 1)
+          return std::nullopt;
 
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
-        .getResult(0);
-  });
+        return builder
+            .create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+            .getResult(0);
+      });
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224f..86c6696217084c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -473,8 +473,8 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
                            type.isScalable());
   });
 
-  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
-                            Location loc) -> Value {
+  auto materializeCast = [](OpBuilder &builder, Location loc, Type type,
+                            ValueRange inputs, Type originalType) -> Value {
     if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
         !isa<VectorType>(type))
       return nullptr;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 97dd3ab1f48293..497b020a26f70e 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -683,10 +683,10 @@ enum MaterializationKind {
 /// conversion.
 class UnresolvedMaterializationRewrite : public OperationRewrite {
 public:
-  UnresolvedMaterializationRewrite(
-      ConversionPatternRewriterImpl &rewriterImpl,
-      UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
-      MaterializationKind kind = MaterializationKind::Target);
+  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                                   UnrealizedConversionCastOp op,
+                                   const TypeConverter *converter,
+                                   MaterializationKind kind, Type originalType);
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -708,11 +708,15 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
     return converterAndKind.getInt();
   }
 
+  Type getOriginalType() const { return originalType; }
+
 private:
   /// The corresponding type converter to use when resolving this
   /// materialization, and the kind of this materialization.
   llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
       converterAndKind;
+
+  Type originalType;
 };
 } // namespace
 
@@ -808,6 +812,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   Value buildUnresolvedMaterialization(MaterializationKind kind,
                                        OpBuilder::InsertPoint ip, Location loc,
                                        ValueRange inputs, Type outputType,
+                                       Type originalType,
                                        const TypeConverter *converter);
 
   //===--------------------------------------------------------------------===//
@@ -1034,9 +1039,9 @@ void CreateOperationRewrite::rollback() {
 
 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
     ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
-    const TypeConverter *converter, MaterializationKind kind)
+    const TypeConverter *converter, MaterializationKind kind, Type originalType)
     : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-      converterAndKind(converter, kind) {
+      converterAndKind(converter, kind), originalType(originalType) {
   rewriterImpl.unresolvedMaterializations[op] = this;
 }
 
@@ -1139,7 +1144,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       Value castValue = buildUnresolvedMaterialization(
           MaterializationKind::Target, computeInsertPoint(newOperand),
           operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
-          currentTypeConverter);
+          /*originalType=*/origType, currentTypeConverter);
       mapping.map(newOperand, castValue);
       newOperand = castValue;
     }
@@ -1255,7 +1260,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
           MaterializationKind::Source,
           OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
           /*inputs=*/ValueRange(),
-          /*outputType=*/origArgType, converter);
+          /*outputType=*/origArgType, /*originalType=*/origArgType, converter);
       mapping.map(origArg, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
       continue;
@@ -1280,7 +1285,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     Value argMat = buildUnresolvedMaterialization(
         MaterializationKind::Argument,
         OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-        /*inputs=*/replArgs, origArgType, converter);
+        /*inputs=*/replArgs, /*outputType=*/origArgType,
+        /*originalType=*/origArgType, converter);
     mapping.map(origArg, argMat);
 
     Type legalOutputType;
@@ -1299,7 +1305,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     if (legalOutputType && legalOutputType != origArgType) {
       Value targetMat = buildUnresolvedMaterialization(
           MaterializationKind::Target, computeInsertPoint(argMat),
-          origArg.getLoc(), argMat, legalOutputType, converter);
+          origArg.getLoc(), /*inputs=*/argMat, /*outputType=*/legalOutputType,
+          /*originalType=*/origArgType, converter);
       mapping.map(argMat, targetMat);
     }
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1322,7 +1329,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// of input operands.
 Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-    ValueRange inputs, Type outputType, const TypeConverter *converter) {
+    ValueRange inputs, Type outputType, Type originalType,
+    const TypeConverter *converter) {
   // Avoid materializing an unnecessary cast.
   if (inputs.size() == 1 && inputs.front().getType() == outputType)
     return inputs.front();
@@ -1333,7 +1341,8 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
-  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
+  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+                                                  originalType);
   return convertOp.getResult(0);
 }
 
@@ -1381,7 +1390,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
       newValue = buildUnresolvedMaterialization(
           MaterializationKind::Source, computeInsertPoint(result),
           result.getLoc(), /*inputs=*/ValueRange(),
-          /*outputType=*/result.getType(), currentTypeConverter);
+          /*outputType=*/result.getType(), /*originalType=*/result.getType(),
+          currentTypeConverter);
     }
 
     // Remap, and check for any result type changes.
@@ -2400,7 +2410,8 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
     case MaterializationKind::Argument:
       // Try to materialize an argument conversion.
       newMaterialization = converter->materializeArgumentConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
+          rewriter, op->getLoc(), outputType, inputOperands,
+          rewrite->getOriginalType());
       if (newMaterialization)
         break;
       // If an argument materialization failed, fallback to trying a target
@@ -2408,11 +2419,13 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
       [[fallthrough]];
     case MaterializationKind::Target:
       newMaterialization = converter->materializeTargetConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
+          rewriter, op->getLoc(), outputType, inputOperands,
+          rewrite->getOriginalType());
       break;
     case MaterializationKind::Source:
       newMaterialization = converter->materializeSourceConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
+          rewriter, op->getLoc(), outputType, inputOperands,
+          rewrite->getOriginalType());
       break;
     }
     if (newMaterialization) {
@@ -2565,7 +2578,7 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
           MaterializationKind::Source, computeInsertPoint(newValue),
           originalValue.getLoc(),
           /*inputs=*/newValue, /*outputType=*/originalValue.getType(),
-          converter);
+          /*originalType=*/originalValue.getType(), converter);
       rewriterImpl.mapping.map(originalValue, castValue);
       inverseMapping[castValue].push_back(originalValue);
       llvm::erase(inverseMapping[newValue], originalValue);
@@ -2789,9 +2802,10 @@ TypeConverter::convertSignatureArgs(TypeRange types,
 
 Value TypeConverter::materializeConversion(
     ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
-    Location loc, Type resultType, ValueRange inputs) const {
+    Location loc, Type resultType, ValueRange inputs, Type originalType) const {
   for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
-    if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
+    if (std::optional<Value> result =
+            fn(builder, loc, resultType, inputs, originalType))
       return *result;
   return nullptr;
 }
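The updated `TypeConverter::materializeConversion` loop tries the registered callbacks newest-first. It stops at the first callback that yields a result, and it now passes `originalType` as the last argument, after `loc` was moved up next to the builder. Below is a minimal, self-contained sketch of that dispatch order.

It rests on these assumptions:
- `Type`, `Value`, `Location` and `Builder` are hypothetical stand-ins, not the MLIR classes.
- `std::nullopt` means "try the next callback", and a fully exhausted list also yields `std::nullopt`. In MLIR, a callback that returns a null `Value` inside the optional also ends the search, and the caller then sees `nullptr`.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for mlir::Type / mlir::Value / mlir::Location /
// mlir::OpBuilder. An empty Type name models a null mlir::Type.
struct Type { std::string name; };
struct Value { Type type; std::string producer; };
struct Location { std::string file; };
struct Builder {};

// Callback shape after this change: builder, loc, result type, inputs,
// original type.
using MaterializationCallbackFn = std::function<std::optional<Value>(
    Builder &, Location, Type, const std::vector<Value> &, Type)>;

// Mirrors the loop in the diff: callbacks are visited in reverse
// registration order and the first one that returns a value wins.
std::optional<Value> materializeConversion(
    const std::vector<MaterializationCallbackFn> &materializations,
    Builder &builder, Location loc, Type resultType,
    const std::vector<Value> &inputs, Type originalType) {
  for (auto it = materializations.rbegin(); it != materializations.rend();
       ++it)
    if (std::optional<Value> result =
            (*it)(builder, loc, resultType, inputs, originalType))
      return result;
  return std::nullopt;  // no callback handled this materialization
}
```

The test below replays the example from the PR description: outputType `t3`, input `v2` of type `t2`, originalType `t1`. Only the most recently registered callback needs `t1`. It declines when the original type is null, and the older callback then handles the request.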
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 19e29d48623e04..dab440a60f2896 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -21,11 +21,12 @@ std::optional<SmallVector<Value>>
 OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
                                                  Location loc,
                                                  TypeRange resultTypes,
-                                                 Value input) const {
+                                                 Value input,
+                                                 Type originalType) const {
   for (const OneToNMaterializationCallbackFn &fn :
        llvm::reverse(oneToNTargetMaterializations)) {
     if (std::optional<SmallVector<Value>> result =
-            fn(builder, resultTypes, input, loc))
+            fn(builder, loc, resultTypes, input, originalType))
       return *result;
   }
   return std::nullopt;
@@ -370,9 +371,12 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
       // Target materialization.
       assert(!areOperandTypesLegal && areResultsTypesLegal &&
              operands.size() == 1 && "found unexpected target cast");
+      // Note: The original type is unknown. We currently do not keep track of
+      // that information.
       std::optional<SmallVector<Value>> maybeResults =
           typeConverter.materializeTargetConversion(
-              rewriter, castOp->getLoc(), resultTypes, operands.front());
+              rewriter, castOp->getLoc(), resultTypes, operands.front(),
+              /*originalType=*/Type());
       if (!maybeResults) {
         emitError(castOp->getLoc())
             << "failed to create target materialization";
@@ -388,7 +392,7 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
         // Source materialization.
         maybeResult = typeConverter.materializeSourceConversion(
             rewriter, castOp->getLoc(), resultTypes.front(),
-            castOp.getOperands());
+            castOp.getOperands(), resultTypes.front());
       } else {
         // Argument materialization.
         assert(castKind == getCastKindName(CastKind::Argument) &&
@@ -396,7 +400,7 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
         assert(llvm::all_of(operands, llvm::IsaPred<BlockArgument>));
         maybeResult = typeConverter.materializeArgumentConversion(
             rewriter, castOp->getLoc(), resultTypes.front(),
-            castOp.getOperands());
+            castOp.getOperands(), resultTypes.front());
       }
       if (!maybeResult.has_value() || !maybeResult.value()) {
         emitError(castOp->getLoc())
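The PR description motivates `originalType` with the MemRef-to-LLVM lowering under the bare pointer calling convention. There, the target materialization receives only an `!llvm.ptr` but must produce a complete MemRef descriptor. The static sizes and strides exist only in the original MemRef type.

Below is a rough, LLVM-free model of that step. It assumes a statically shaped memref with an identity (row-major) layout and a zero offset. `MemRefTypeInfo`, `Descriptor` and `materializeDescriptor` are hypothetical names, not MLIR API. When the original type is missing, as with the `/*originalType=*/Type()` placeholder in the 1:N driver above, the sketch fails rather than guessing.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of a statically shaped memref type (the original type).
struct MemRefTypeInfo {
  std::vector<int64_t> shape;  // static sizes, outermost dimension first
};

// Hypothetical model of the descriptor the target materialization must build.
struct Descriptor {
  const void *allocated;
  const void *aligned;
  int64_t offset;
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
};

// Builds a descriptor from a bare pointer. The pointer alone carries no
// shape information, so sizes/strides/offset come from originalType.
std::optional<Descriptor>
materializeDescriptor(const void *barePtr,
                      const std::optional<MemRefTypeInfo> &originalType) {
  if (!originalType)
    return std::nullopt;  // not recoverable from the pointer alone
  // Assumption: allocated and aligned pointers are both the bare pointer,
  // and the offset is 0 (identity layout).
  Descriptor d{barePtr, barePtr, 0, originalType->shape, {}};
  // Row-major strides: innermost stride is 1, each outer stride is the
  // product of the inner sizes.
  d.strides.assign(d.sizes.size(), 1);
  for (int i = static_cast<int>(d.sizes.size()) - 2; i >= 0; --i)
    d.strides[i] = d.strides[i + 1] * d.sizes[i + 1];
  return d;
}
```

For a `4x8` memref this yields sizes `{4, 8}`, strides `{8, 1}` and offset `0`. None of these values can be derived from the pointer by itself.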
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 1ea65109bf79db..83345133eb4b28 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -148,8 +148,8 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
 /// This function has been copied (with small adaptions) from
 /// TestDecomposeCallGraphTypes.cpp.
 static std::optional<SmallVector<Value>>
-buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
-                        Location loc) {
+buildGetTupleElementOps(OpBuilder &builder, Location loc, TypeRange resultTypes,
+                        Value input, Type originalType) {
   TupleType inputType = dyn_cast<TupleType>(input.getType());
   if (!inputType)
     return {};
@@ -163,7 +163,8 @@ buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
       SmallVector<Type> flatRecursiveTypes;
       nestedTupleType.getFlattenedTypes(flatRecursiveTypes);
       std::optional<SmallVector<Value>> resursiveValues =
-          buildGetTupleElementOps(builder, flatRecursiveTypes, element, loc);
+          buildGetTupleElementOps(builder, loc, flatRecursiveTypes, element,
+                                  /*originalType=*/Type());
       if (!resursiveValues.has_value())
         return {};
       values.append(resursiveValues.value());
@@ -180,9 +181,10 @@ buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
 ///
 /// This function has been copied (with small adaptions) from
 /// TestDecomposeCallGraphTypes.cpp.
-static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
+static std::optional<Value> buildMakeTupleOp(OpBuilder &builder, Location loc,
                                              TupleType resultType,
-                                             ValueRange inputs, Location loc) {
+                                             ValueRange inputs,
+                                             Type originalType) {
   // Build one value for each element at this nesting level.
   SmallVector<Value> elements;
   elements.reserve(resultType.getTypes().size());
@@ -201,8 +203,9 @@ static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
       inputIt += numNestedFlattenedTypes;
 
       // Recurse on the values for the nested TupleType.
-      std::optional<Value> res = buildMakeTupleOp(builder, nestedTupleType,
-                                                  nestedFlattenedelements, loc);
+      std::optional<Value> res =
+          buildMakeTupleOp(builder, loc, nestedTupleType,
+                           nestedFlattenedelements, /*originalType=*/Type());
       if (!res.has_value())
         return {};
 
diff --git a/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp b/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp
index a6678995fc6f67..3c0adbaf7ddd53 100644
--- a/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp
+++ b/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp
@@ -58,8 +58,9 @@ struct TestEmulateWideIntPass
     // function argument and return types of the processed function.
     // TODO: Consider extending `arith.bitcast` to support scalar-to-1D-vector
     // casts (and vice versa) and using it insted of `llvm.bitcast`.
-    auto addBitcast = [](OpBuilder &builder, Type type, ValueRange inputs,
-                         Location loc) -> std::optional<Value> {
+    auto addBitcast = [](OpBuilder &builder, Location loc, Type type,
+                         ValueRange inputs,
+                         Type originalType) -> std::optional<Value> {
       auto cast = builder.create<LLVM::BitcastOp>(loc, type, inputs);
       return cast->getResult(0);
     };
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 0d7dce2240f4cb..05840485ddd805 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -43,9 +43,10 @@ static LogicalResult buildDecomposeTuple(OpBuilder &builder, Location loc,
 /// Creates a `test.make_tuple` op out of the given inputs building a tuple of
 /// type `resultType`. If that type is nested, each nested tuple is built
 /// recursively with another `test.make_tuple` op.
-static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
+static std::optional<Value> buildMakeTupleOp(OpBuilder &builder, Location loc,
                                              TupleType resultType,
-                                             ValueRange inputs, Location loc) {
+                                             ValueRange inputs,
+                                             Type originalType) {
   // Build one value for each element at this nesting level.
   SmallVector<Value> elements;
   elements.reserve(resultType.getTypes().size());
@@ -64,8 +65,9 @@ static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
       inputIt += numNestedFlattenedTypes;
 
       // Recurse on the values for the nested TupleType.
-      std::optional<Value> res = buildMakeTupleOp(builder, nestedTupleType,
-                                                  nestedFlattenedelements, loc);
+      std::optional<Value> res =
+          buildMakeTupleOp(builder, loc, nestedTupleType,
+                           nestedFlattenedelements, /*originalType=*/Type());
       if (!res.has_value())
         return {};
 
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 3cbc307835afd7..679cd3db84a7be 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1140,9 +1140,10 @@ struct TestTypeConverter : public TypeConverter {
 
   /// Hook for materializing a conversion. This is necessary because we generate
   /// 1->N type mappings.
-  static std::optional<Value> materializeCast(OpBuilder &builder,
+  static std::optional<Value> materializeCast(OpBuilder &builder, Location loc,
                                               Type resultType,
-                                              ValueRange inputs, Location loc) {
+                                              ValueRange inputs,
+                                              Type originalType) {
     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
   }
 };
@@ -1691,9 +1692,9 @@ struct TestTypeConversionDriver
         });
 
     /// Add the legal set of type materializations.
-    converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
-                                          ValueRange inputs,
-                                          Location loc) -> Value {
+    converter.addSourceMaterialization([](OpBuilder &builder, Location loc,
+                                          Type resultType, ValueRange inputs,
+                                          Type originalType) -> Value {
       // Allow casting from F64 back to F32.
       if (!resultType.isF16() && inputs.size() == 1 &&
           inputs[0].getType().isF64())
@@ -1786,10 +1787,11 @@ struct TestTargetMaterializationWithNoUses
         return IntegerType::get(intTy.getContext(), 64);
       return intTy;
     });
-    converter.addTargetMaterialization(
-        [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
-          return builder.create<TestCastOp>(loc, type, inputs).getResult();
-        });
+    converter.addTargetMaterialization([](OpBuilder &builder, Location loc,
+                                          Type type, ValueRange inputs,
+                                          Type originalType) {
+      return builder.create<TestCastOp>(loc, type, inputs).getResult();
+    });
 
     ConversionTarget target(getContext());
     target.addIllegalOp<TestTypeChangerOp>();
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index a0a7afce66d9a1..9430ae4f149940 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -869,9 +869,9 @@ class TestTypeConverter : public TypeConverter {
     addConversion([](RankedTensorType type) -> Type {
       return MemRefType::get(type.getShape(), type.getElementType());
     });
-    auto unrealizedCastConverter = [&](OpBuilder &builder, Type resultType,
-                                       ValueRange inputs,
-                                       Location loc) -> std::optional<Value> {
+    auto unrealizedCastConverter =
+        [&](OpBuilder &builder, Location loc, Type resultType,
+            ValueRange inputs, Type originalType) -> std::optional<Value> {
       if (inputs.size() != 1)
         return std::nullopt;
       return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp
index 97fe78c35e833d..5968665e2ff910 100644
--- a/mlir/test/lib/Transforms/TestDialectConversion.cpp
+++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp
@@ -44,9 +44,10 @@ struct PDLLTypeConverter : public TypeConverter {
     return success();
   }
   /// Hook for materializing a conversion.
-  static std::optional<Value> materializeCast(OpBuilder &builder,
+  static std::optional<Value> materializeCast(OpBuilder &builder, Location loc,
                                               Type resultType,
-                                              ValueRange inputs, Location loc) {
+                                              ValueRange inputs,
+                                              Type originalType) {
     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
         .getResult(0);
   }

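The MemRef-to-LLVM motivation from the description can be shown with a small, self-contained C++ model. This is a sketch and not MLIR code. `MemRefType`, `MemRefDescriptor`, and `materializeFromBarePtr` are hypothetical stand-ins invented for illustration. The sketch also assumes a statically shaped memref with an identity (row-major) layout and a zero offset. It demonstrates the core point: a bare-pointer input alone cannot populate the descriptor, because its sizes and strides must come from the original memref type.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a statically shaped memref type
// (NOT the MLIR API). Only the shape matters for this sketch.
struct MemRefType {
  std::vector<int64_t> shape;
};

// Hypothetical model of the LLVM memref descriptor (!llvm.struct<...>).
struct MemRefDescriptor {
  std::string alignedPtr;
  int64_t offset = 0;
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
};

// Hypothetical target materialization for the bare-pointer case.
// `barePtr` plays the role of `inputs` (a single !llvm.ptr value) and
// `originalType` the role of the new parameter.
std::optional<MemRefDescriptor>
materializeFromBarePtr(const std::string &barePtr,
                       const std::optional<MemRefType> &originalType) {
  // Without the original type, the shape cannot be recovered from a pointer.
  if (!originalType)
    return std::nullopt;
  MemRefDescriptor desc;
  desc.alignedPtr = barePtr;
  desc.offset = 0; // Assumption: zero offset.
  desc.sizes = originalType->shape;
  // Assumption: identity layout, so strides are row-major suffix products.
  desc.strides.assign(desc.sizes.size(), 1);
  for (std::size_t i = desc.sizes.size(); i > 1; --i)
    desc.strides[i - 2] = desc.strides[i - 1] * desc.sizes[i - 1];
  return desc;
}
```

In the real lowering, the callback would presumably inspect `originalType` as an MLIR memref type and emit LLVM dialect ops through the `OpBuilder`. This model only mirrors which information is available to the callback.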

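Most callbacks touched by this diff do not need the new parameter. They move `Location` to the second position and accept `originalType` without using it. The sketch below models that migration with a generic adapter. The adapter wraps a callback written for the old parameter order `(builder, resultType, inputs, loc)` so that it can be called with the new order `(builder, loc, resultType, inputs, originalType)`. All types here (`Builder`, `Location`, `Type`, `Values`, `Value`) and the `adaptLegacy` helper are hypothetical stand-ins and are not part of the MLIR API.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for OpBuilder, Location, Type, Value, ValueRange.
struct Builder {};
struct Location {
  std::string name;
};
using Type = std::string;
using Value = std::string;
using Values = std::vector<Value>;

// Old callback order: (builder, resultType, inputs, loc).
using LegacyFn = std::function<std::optional<Value>(Builder &, Type,
                                                    const Values &, Location)>;
// New callback order: (builder, loc, resultType, inputs, originalType).
using NewFn = std::function<std::optional<Value>(
    Builder &, Location, Type, const Values &, Type)>;

// Reorders the arguments and ignores `originalType`, which is what the
// cast-building callbacks in this diff effectively do.
NewFn adaptLegacy(LegacyFn legacy) {
  return [legacy = std::move(legacy)](
             Builder &builder, Location loc, Type resultType,
             const Values &inputs,
             Type /*originalType*/) -> std::optional<Value> {
    return legacy(builder, std::move(resultType), inputs, std::move(loc));
  };
}
```

Ignoring `originalType` is appropriate for callbacks like the `TestCastOp` and `UnrealizedConversionCastOp` builders above, which need only the result type and the inputs.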
