[flang-commits] [flang] [mlir] [mlir][Transforms] Dialect conversion: add `originalType` param to materialization (PR #112128)

via flang-commits flang-commits at lists.llvm.org
Sun Oct 13 03:21:51 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-sparse

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit adds an `originalType` parameter to all materialization functions. Without this parameter, target materializations are underspecified.

Note: `originalType` is only needed for target materializations. For source and argument materializations, `originalType` always matches `outputType`. It is nevertheless passed to all three kinds of materialization so that a single `MaterializationCallbackFn` type can be reused for all of them, which keeps the code base simple.

`originalType` is the original type of an SSA value. For argument materializations, it matches the original argument type (which is also the output type). For source materializations, it also matches the output type.

For target materializations, consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the originalType parameter exists.
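The example above can be replayed with a small self-contained sketch. The `Type`, `Value`, and `TargetMaterializationFn` definitions below are hypothetical stand-ins, not MLIR's classes. Only the parameter order (output type, inputs, original type) mirrors the new callback form, and the builder and location parameters are left out of the model.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for mlir::Type and mlir::Value (not the real classes).
struct Type {
  std::string name;
};
struct Value {
  std::string name;
  Type type;
};

// Mirrors the tail of the new callback form: output type, inputs, original
// type. The OpBuilder and Location parameters are omitted from this model.
using TargetMaterializationFn = std::function<std::optional<Value>(
    Type outputType, const std::vector<Value> &inputs, Type originalType)>;

// Returns a callback that records every original type it is invoked with.
// `seen` must outlive the returned callback.
inline TargetMaterializationFn makeRecordingCallback(std::vector<Type> &seen) {
  return [&seen](Type outputType, const std::vector<Value> &inputs,
                 Type originalType) -> std::optional<Value> {
    assert(inputs.size() == 1 && "this model handles a single input only");
    seen.push_back(originalType);
    // Stand-in for building a cast op that produces a value of `outputType`.
    return Value{"cast_of_" + inputs.front().name, outputType};
  };
}

// Replays the example: "v1" (type "t1") was replaced by "v2" (type "t2"), and
// the legalized type of "t1" is "t3".
inline std::optional<Value> replayExample(const TargetMaterializationFn &fn) {
  Value v2{"v2", Type{"t2"}};
  return fn(/*outputType=*/Type{"t3"}, /*inputs=*/{v2},
            /*originalType=*/Type{"t1"});
}
```

In this model the callback sees "t3" and a value of type "t2"; "t1" reaches it only through the third argument.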

This commit also puts the `Location` parameter right after the `OpBuilder` parameter to be consistent with MLIR conventions.

This change is in preparation for merging the 1:1 and 1:N dialect conversion drivers. As part of that merge, argument materializations will be removed: they are no longer needed and were only a workaround for the missing 1:N support in the dialect conversion.

The new `originalType` parameter is needed when lowering MemRef to LLVM. During that lowering, MemRef function block arguments are replaced with the elements that make up a MemRef descriptor. The type converter is set up in such a way that the legalized type of a MemRef type is an `!llvm.struct` that represents the MemRef descriptor. When the bare pointer calling convention is enabled, the function block arguments consist of just an LLVM pointer. In that case, a target materialization will be invoked to construct a MemRef descriptor (output type = `!llvm.struct<...>`) from just the bare pointer (inputs = `!llvm.ptr`). The original MemRef type is required to construct the MemRef descriptor, because the static sizes, strides, and offset cannot be inferred from the bare pointer alone.
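A minimal sketch of the bare-pointer case shows why the pointer alone is not enough. `MemRefTypeModel`, `DescriptorModel`, and `descriptorFromBarePtr` are hypothetical stand-ins, not the LLVM type converter's code. The sketch assumes a fully static shape, an identity (row-major) layout, and a zero offset, and it uses the bare pointer for both pointer fields of the descriptor.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of a statically shaped MemRef type (not mlir::MemRefType).
struct MemRefTypeModel {
  std::vector<int64_t> shape; // Static sizes, outermost dimension first.
};

// Hypothetical model of the descriptor fields: pointers, offset, sizes,
// strides.
struct DescriptorModel {
  const void *allocatedPtr;
  const void *alignedPtr;
  int64_t offset;
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
};

// Builds a descriptor from a bare pointer. Sizes and strides are derived from
// `originalType`; the pointer itself carries no shape information.
// Assumption: identity (row-major) layout and zero offset.
inline DescriptorModel
descriptorFromBarePtr(const void *barePtr,
                      const MemRefTypeModel &originalType) {
  DescriptorModel desc{barePtr, barePtr, /*offset=*/0, originalType.shape, {}};
  desc.strides.assign(originalType.shape.size(), 1);
  // Row-major strides: the innermost stride is 1 and each outer stride is the
  // product of all inner sizes.
  int64_t running = 1;
  for (std::size_t i = originalType.shape.size(); i-- > 0;) {
    assert(originalType.shape[i] >= 0 && "dynamic sizes are not modeled");
    desc.strides[i] = running;
    running *= originalType.shape[i];
  }
  return desc;
}
```

For a model type with shape 4x8, this yields sizes [4, 8] and strides [8, 1]. Two buffers with different shapes would produce different descriptors from pointers of the same type, which is the information the `originalType` parameter supplies.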

Note for LLVM integration: For all argument/source/target materialization functions, move the `Location` parameter to the second position and add a `Type originalType` parameter to the lambda. No changes are needed to the body of the lambda. When an argument/source materialization is called in your code base, pass the output type as original type. When a target materialization is called, try to pass the original type of the SSA value, which may match `inputs.front().getType()`. If the original type cannot be recovered (which is unlikely), pass `Type()`.

---

Patch is 56.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112128.diff


31 Files Affected:

- (modified) flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp (+2-2) 
- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+40-16) 
- (modified) mlir/include/mlir/Transforms/OneToNTypeConversion.h (+4-3) 
- (modified) mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp (+2-2) 
- (modified) mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp (+1-1) 
- (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+11-11) 
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+6-6) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+6-4) 
- (modified) mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp (+2-2) 
- (modified) mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp (+3-2) 
- (modified) mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp (+2-1) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp (+5-4) 
- (modified) mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp (+7-6) 
- (modified) mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp (+10-10) 
- (modified) mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp (+1-1) 
- (modified) mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp (+3-2) 
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+6-6) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+3-3) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp (+3-3) 
- (modified) mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp (+22-20) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp (+18-16) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+2-2) 
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+33-19) 
- (modified) mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp (+9-5) 
- (modified) mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp (+10-7) 
- (modified) mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp (+3-2) 
- (modified) mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp (+6-4) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+11-9) 
- (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp (+3-3) 
- (modified) mlir/test/lib/Transforms/TestDialectConversion.cpp (+3-2) 


``````````diff
diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index c536fd19fcc69a..f1b057eedb2340 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -173,9 +173,9 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
   }
 
   static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
-                                          BoxProcType type,
+                                          mlir::Location loc, BoxProcType type,
                                           mlir::ValueRange inputs,
-                                          mlir::Location loc) {
+                                          mlir::Type originalType) {
     assert(inputs.size() == 1);
     return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
                                      inputs[0]);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 65e279e046e886..f22599c4d4aabf 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -170,7 +170,7 @@ class TypeConverter {
 
   /// All of the following materializations require function objects that are
   /// convertible to the following form:
-  ///   `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
+  ///   `std::optional<Value>(OpBuilder &, Location, T, ValueRange, Type)`,
   /// where `T` is any subclass of `Type`. This function is responsible for
   /// creating an operation, using the OpBuilder and Location provided, that
   /// "casts" a range of values into a single value of the given type `T`. It
@@ -178,13 +178,19 @@ class TypeConverter {
   /// it failed but other materialization can be attempted, and `nullptr` on
   /// unrecoverable failure. Materialization functions must be provided when a
   /// type conversion may persist after the conversion has finished.
+  ///
+  /// The type that is provided as the 5-th argument is the original type of
+  /// value. For more details, see the documentation below.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
   /// a signature conversion of a single block argument, to a single SSA value
   /// with the old block argument type.
+  ///
+  /// Note: The original type matches the result type `T` for argument
+  /// materializations.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addArgumentMaterialization(FnT &&callback) {
     argumentMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -194,8 +200,11 @@ class TypeConverter {
   /// converting a legal replacement value back to an illegal source type.
   /// This is used when some uses of the original, illegal value must persist
   /// beyond the main conversion.
+  ///
+  /// Note: The original type matches the result type `T` for source
+  /// materializations.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addSourceMaterialization(FnT &&callback) {
     sourceMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -203,8 +212,19 @@ class TypeConverter {
 
   /// This method registers a materialization that will be called when
   /// converting an illegal (source) value to a legal (target) type.
+  ///
+  /// Note: For target materializations, the original type can be
+  /// different from the type of the input. For example, let's assume that a
+  /// conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2"
+  /// (type "t2"). Then a different conversion pattern "P2" matches an op that
+  /// has "v1" as an operand. Let's furthermore assume that "P2" determines
+  /// that the legalized type of "t1" is "t3", which may be different from
+  /// "t2". In this example, the target materialization callback will be
+  /// invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note
+  /// that the original type "t1" cannot be recovered from just "t3" and "v2";
+  /// that's why the originalType parameter exists.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addTargetMaterialization(FnT &&callback) {
     targetMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -303,20 +323,22 @@ class TypeConverter {
   /// `add*Materialization` for more information on the context for these
   /// methods.
   Value materializeArgumentConversion(OpBuilder &builder, Location loc,
-                                      Type resultType,
-                                      ValueRange inputs) const {
+                                      Type resultType, ValueRange inputs,
+                                      Type originalType) const {
     return materializeConversion(argumentMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
   Value materializeSourceConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) const {
+                                    Type resultType, ValueRange inputs,
+                                    Type originalType) const {
     return materializeConversion(sourceMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) const {
+                                    Type resultType, ValueRange inputs,
+                                    Type originalType) const {
     return materializeConversion(targetMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
 
   /// Convert an attribute present `attr` from within the type `type` using
@@ -334,8 +356,10 @@ class TypeConverter {
       Type, SmallVectorImpl<Type> &)>;
 
   /// The signature of the callback used to materialize a conversion.
+  ///
+  /// Arguments: builder, location, result type, inputs, original type
   using MaterializationCallbackFn = std::function<std::optional<Value>(
-      OpBuilder &, Type, ValueRange, Location)>;
+      OpBuilder &, Location, Type, ValueRange, Type)>;
 
   /// The signature of the callback used to convert a type attribute.
   using TypeAttributeConversionCallbackFn =
@@ -346,7 +370,7 @@ class TypeConverter {
   Value
   materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
                         OpBuilder &builder, Location loc, Type resultType,
-                        ValueRange inputs) const;
+                        ValueRange inputs, Type originalType) const;
 
   /// Generate a wrapper for the given callback. This allows for accepting
   /// different callback forms, that all compose into a single version.
@@ -394,10 +418,10 @@ class TypeConverter {
   template <typename T, typename FnT>
   MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               OpBuilder &builder, Type resultType, ValueRange inputs,
-               Location loc) -> std::optional<Value> {
+               OpBuilder &builder, Location loc, Type resultType,
+               ValueRange inputs, Type originalType) -> std::optional<Value> {
       if (T derivedType = dyn_cast<T>(resultType))
-        return callback(builder, derivedType, inputs, loc);
+        return callback(builder, loc, derivedType, inputs, originalType);
       return std::nullopt;
     };
   }
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..90f796fce576a9 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -44,8 +44,8 @@ class OneToNTypeConverter : public TypeConverter {
   /// materializations for 1:N type conversions, which materialize one value in
   /// a source type as N values in target types.
   using OneToNMaterializationCallbackFn =
-      std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
-                                                      Value, Location)>;
+      std::function<std::optional<SmallVector<Value>>(OpBuilder &, Location,
+                                                      TypeRange, Value, Type)>;
 
   /// Creates the mapping of the given range of original types to target types
   /// of the conversion and stores that mapping in the given (signature)
@@ -63,7 +63,8 @@ class OneToNTypeConverter : public TypeConverter {
   /// returns `std::nullopt`.
   std::optional<SmallVector<Value>>
   materializeTargetConversion(OpBuilder &builder, Location loc,
-                              TypeRange resultTypes, Value input) const;
+                              TypeRange resultTypes, Value input,
+                              Type originalType) const;
 
   /// Adds a 1:N target materialization to the converter. Such materializations
   /// build IR that converts N values with target types into 1 value of the
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 77603739137614..5b067647fff726 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -281,8 +281,8 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
 
     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
     // in patterns for other dialects.
-    auto addUnrealizedCast = [](OpBuilder &builder, Type type,
-                                ValueRange inputs, Location loc) {
+    auto addUnrealizedCast = [](OpBuilder &builder, Location loc, Type type,
+                                ValueRange inputs, Type originalType) {
       auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
       return std::optional<Value>(cast.getResult(0));
     };
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index d8f3e995109538..b0fc27e59f7501 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -46,7 +46,7 @@ static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
     Location loc = arg.getLoc();
     Value newArg = block.insertArgument(argNum, newTy, loc);
     Value convertedValue = converter.materializeSourceConversion(
-        builder, op->getLoc(), ty, newArg);
+        builder, op->getLoc(), ty, newArg, ty);
     if (!convertedValue) {
       return rewriter.notifyMatchFailure(
           op, llvm::formatv("failed to cast new argument {0} to type {1})",
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 5a92fa839e9847..66a0ce74f5841c 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -159,8 +159,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   // insert a target materialization from the original block argument type to
   // a legal type.
   addArgumentMaterialization(
-      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
-          Location loc) -> std::optional<Value> {
+      [&](OpBuilder &builder, Location loc, UnrankedMemRefType resultType,
+          ValueRange inputs, Type originalType) -> std::optional<Value> {
         if (inputs.size() == 1) {
           // Bare pointers are not supported for unranked memrefs because a
           // memref descriptor cannot be built just from a bare pointer.
@@ -174,9 +174,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
         return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
             .getResult(0);
       });
-  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
-                                 ValueRange inputs,
-                                 Location loc) -> std::optional<Value> {
+  addArgumentMaterialization([&](OpBuilder &builder, Location loc,
+                                 MemRefType resultType, ValueRange inputs,
+                                 Type originalType) -> std::optional<Value> {
     Value desc;
     if (inputs.size() == 1) {
       // This is a bare pointer. We allow bare pointers only for function entry
@@ -201,18 +201,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   });
   // Add generic source and target materializations to handle cases where
   // non-LLVM types persist after an LLVM conversion.
-  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs,
-                               Location loc) -> std::optional<Value> {
+  addSourceMaterialization([&](OpBuilder &builder, Location loc,
+                               Type resultType, ValueRange inputs,
+                               Type originalType) -> std::optional<Value> {
     if (inputs.size() != 1)
       return std::nullopt;
 
     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
         .getResult(0);
   });
-  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs,
-                               Location loc) -> std::optional<Value> {
+  addTargetMaterialization([&](OpBuilder &builder, Location loc,
+                               Type resultType, ValueRange inputs,
+                               Type originalType) -> std::optional<Value> {
     if (inputs.size() != 1)
       return std::nullopt;
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4bfa536cc8a44a..46acfdab96e648 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1185,7 +1185,7 @@ struct MemRefReshapeOpLowering
           Type indexType = getIndexType();
           if (dimSize.getType() != indexType)
             dimSize = typeConverter->materializeTargetConversion(
-                rewriter, loc, indexType, dimSize);
+                rewriter, loc, indexType, dimSize, dimSize.getType());
           assert(dimSize && "Invalid memref element type");
         }
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 836ebb65e7d17b..d57960169de217 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -97,12 +97,12 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
     // All other types legal
     return type;
   });
-  converter.addTargetMaterialization(
-      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
-        extFOp.setFastmath(arith::FastMathFlags::contract);
-        return extFOp;
-      });
+  converter.addTargetMaterialization([](OpBuilder &b, Location loc, Type target,
+                                        ValueRange input, Type originalType) {
+    auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+    extFOp.setFastmath(arith::FastMathFlags::contract);
+    return extFOp;
+  });
 }
 
 void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 875d8c40e92cc1..3378fe3ee6680d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -42,8 +42,9 @@ using namespace mlir::bufferization;
 // BufferizeTypeConverter
 //===----------------------------------------------------------------------===//
 
-static Value materializeToTensor(OpBuilder &builder, TensorType type,
-                                 ValueRange inputs, Location loc) {
+static Value materializeToTensor(OpBuilder &builder, Location loc,
+                                 TensorType type, ValueRange inputs,
+                                 Type originalType) {
   assert(inputs.size() == 1);
   assert(isa<BaseMemRefType>(inputs[0].getType()));
   return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
@@ -63,8 +64,9 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
   });
   addArgumentMaterialization(materializeToTensor);
   addSourceMaterialization(materializeToTensor);
-  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
-                              ValueRange inputs, Location loc) -> Value {
+  addTargetMaterialization([](OpBuilder &builder, Location loc,
+                              BaseMemRefType type, ValueRange inputs,
+                              Type originalType) -> Value {
     assert(inputs.size() == 1 && "expected exactly one input");
 
     if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 83de9b37974f67..1315805caa675f 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -17,9 +17,9 @@ using namespace mlir;
 namespace {
 
 std::optional<Value> materializeAsUnrealizedCast(OpBuilder &builder,
-                                                 Type resultType,
+                                                 Location loc, Type resultType,
                                                  ValueRange inputs,
-                                                 Location loc) {
+                                                 Type originalType) {
   if (inputs.size() != 1)
     return std::nullopt;
 
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 2728936bf33fd3..3b472293ef88b6 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -161,7 +161,7 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
        llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
     if (input.getType() != type) {
       Value newInput = converter.materializeSourceConversion(
-          rewriter, input.getLoc(), type, input);
+          rewriter, input.getLoc(), type, input, type);
       if (!newInput) {
         return emitDefiniteFailure() << "Failed to materialize conversion of "
                                      << input << " to type " << type;
@@ -180,7 +180,8 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
     Value convertedOutput = newOutput;
     if (output.getType() != newOutput.getType()) {
       convertedOutput = converter.materializeTargetConversion(
-          rewriter, output.getLoc(), output.getType(), newOutput);
+          rewriter, output.getLoc(), output.getType(), newOutput,
+          output.getType());
       if (!convertedOutput) {
         return emitDefiniteFailure()
                << "Failed to materialize conversion of " << newOutput
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index 357f993710a26a..557ef265c5b30c 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeC...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/112128

