[Mlir-commits] [flang] [mlir] [mlir][Transforms] Dialect conversion: Simplify handling of dropped arguments (PR #97213)

Matthias Springer llvmlistbot at llvm.org
Sat Jul 13 08:37:12 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/97213

>From d4963ded9d778f8a2ba90f86d7a6be08fee61f78 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 13 Jul 2024 14:32:06 +0200
Subject: [PATCH 1/3] [flang] Remove materialization workaround in type
 converter

This change is in preparation for #97903, which adds extra checks for materializations: it is now enforced that they produce an SSA value of the correct type, so the current workaround no longer works.

For `fir.has_value` the fix is simple: no target materializations on the operands are performed if the lowering pattern is initialized without a type converter. For `cg::XEmboxOp`, the existing workaround that skips `unrealized_conversion_cast` ops can be generalized. (This is still a workaround.)

Also remove the lowering pattern for `unrealized_conversion_cast`. This pattern has no effect because `unrealized_conversion_cast` ops that are inserted by the dialect conversion framework are never matched by the pattern driver.
---
 flang/include/flang/Tools/CLOptions.inc       |  5 ++
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       | 67 +++++++------------
 flang/lib/Optimizer/CodeGen/TypeConverter.cpp | 34 ----------
 flang/test/Fir/basic-program.fir              |  1 +
 4 files changed, 30 insertions(+), 77 deletions(-)

diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 7f2910c5cfd3c..7df5044949463 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -9,6 +9,7 @@
 /// This file defines some shared command-line options that can be used when
 /// debugging the test tools. This file must be included into the tool.
 
+#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Pass/PassManager.h"
@@ -223,6 +224,10 @@ inline void addFIRToLLVMPass(
   options.forceUnifiedTBAATree = useOldAliasTags;
   addPassConditionally(pm, disableFirToLlvmIr,
       [&]() { return fir::createFIRToLLVMPass(options); });
+  // The dialect conversion framework may leave dead unrealized_conversion_cast
+  // ops behind, so run reconcile-unrealized-casts to clean them up.
+  addPassConditionally(pm, disableFirToLlvmIr,
+      [&]() { return mlir::createReconcileUnrealizedCastsPass(); });
 }
 
 inline void addLLVMDialectToLLVMPass(
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 7483acfcd1ca7..11535f073c807 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -35,7 +35,6 @@
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
-#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
@@ -2042,13 +2041,13 @@ struct ExtractValueOpConversion
 /// InsertValue is the generalized instruction for the composition of new
 /// aggregate type values.
 struct InsertValueOpConversion
-    : public fir::FIROpAndTypeConversion<fir::InsertValueOp>,
+    : public mlir::OpConversionPattern<fir::InsertValueOp>,
       public ValueOpCommon {
-  using FIROpAndTypeConversion::FIROpAndTypeConversion;
+  using OpConversionPattern::OpConversionPattern;
 
   llvm::LogicalResult
-  doRewrite(fir::InsertValueOp insertVal, mlir::Type ty, OpAdaptor adaptor,
-            mlir::ConversionPatternRewriter &rewriter) const override {
+  matchAndRewrite(fir::InsertValueOp insertVal, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::ValueRange operands = adaptor.getOperands();
     auto indices = collectIndices(rewriter, insertVal.getCoor());
     toRowMajor(indices, operands[0].getType());
@@ -2669,8 +2668,9 @@ struct TypeDescOpConversion : public fir::FIROpConversion<fir::TypeDescOp> {
 };
 
 /// Lower `fir.has_value` operation to `llvm.return` operation.
-struct HasValueOpConversion : public fir::FIROpConversion<fir::HasValueOp> {
-  using FIROpConversion::FIROpConversion;
+struct HasValueOpConversion
+    : public mlir::OpConversionPattern<fir::HasValueOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   llvm::LogicalResult
   matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor,
@@ -3515,29 +3515,6 @@ struct MustBeDeadConversion : public fir::FIROpConversion<FromOp> {
   }
 };
 
-struct UnrealizedConversionCastOpConversion
-    : public fir::FIROpConversion<mlir::UnrealizedConversionCastOp> {
-  using FIROpConversion::FIROpConversion;
-
-  llvm::LogicalResult
-  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OpAdaptor adaptor,
-                  mlir::ConversionPatternRewriter &rewriter) const override {
-    assert(op.getOutputs().getTypes().size() == 1 && "expect a single type");
-    mlir::Type convertedType = convertType(op.getOutputs().getTypes()[0]);
-    if (convertedType == adaptor.getInputs().getTypes()[0]) {
-      rewriter.replaceOp(op, adaptor.getInputs());
-      return mlir::success();
-    }
-
-    convertedType = adaptor.getInputs().getTypes()[0];
-    if (convertedType == op.getOutputs().getType()[0]) {
-      rewriter.replaceOp(op, adaptor.getInputs());
-      return mlir::success();
-    }
-    return mlir::failure();
-  }
-};
-
 struct ShapeOpConversion : public MustBeDeadConversion<fir::ShapeOp> {
   using MustBeDeadConversion::MustBeDeadConversion;
 };
@@ -3714,7 +3691,8 @@ class FIRToLLVMLowering
       signalPassFailure();
     }
 
-    // Run pass to add comdats to functions that have weak linkage on relevant platforms
+    // Run pass to add comdats to functions that have weak linkage on relevant
+    // platforms
     if (fir::getTargetTriple(mod).supportsCOMDAT()) {
       mlir::OpPassManager comdatPM("builtin.module");
       comdatPM.addPass(mlir::LLVM::createLLVMAddComdats());
@@ -3789,16 +3767,19 @@ void fir::populateFIRToLLVMConversionPatterns(
       DivcOpConversion, EmboxOpConversion, EmboxCharOpConversion,
       EmboxProcOpConversion, ExtractValueOpConversion, FieldIndexOpConversion,
       FirEndOpConversion, FreeMemOpConversion, GlobalLenOpConversion,
-      GlobalOpConversion, HasValueOpConversion, InsertOnRangeOpConversion,
-      InsertValueOpConversion, IsPresentOpConversion, LenParamIndexOpConversion,
-      LoadOpConversion, MulcOpConversion, NegcOpConversion,
-      NoReassocOpConversion, SelectCaseOpConversion, SelectOpConversion,
-      SelectRankOpConversion, SelectTypeOpConversion, ShapeOpConversion,
-      ShapeShiftOpConversion, ShiftOpConversion, SliceOpConversion,
-      StoreOpConversion, StringLitOpConversion, SubcOpConversion,
-      TypeDescOpConversion, TypeInfoOpConversion, UnboxCharOpConversion,
-      UnboxProcOpConversion, UndefOpConversion, UnreachableOpConversion,
-      UnrealizedConversionCastOpConversion, XArrayCoorOpConversion,
-      XEmboxOpConversion, XReboxOpConversion, ZeroOpConversion>(converter,
-                                                                options);
+      GlobalOpConversion, InsertOnRangeOpConversion, IsPresentOpConversion,
+      LenParamIndexOpConversion, LoadOpConversion, MulcOpConversion,
+      NegcOpConversion, NoReassocOpConversion, SelectCaseOpConversion,
+      SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
+      ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion,
+      SliceOpConversion, StoreOpConversion, StringLitOpConversion,
+      SubcOpConversion, TypeDescOpConversion, TypeInfoOpConversion,
+      UnboxCharOpConversion, UnboxProcOpConversion, UndefOpConversion,
+      UnreachableOpConversion, XArrayCoorOpConversion, XEmboxOpConversion,
+      XReboxOpConversion, ZeroOpConversion>(converter, options);
+
+  // Patterns that are populated without a type converter do not trigger
+  // target materializations for the operands of the root op.
+  patterns.insert<HasValueOpConversion, InsertValueOpConversion>(
+      patterns.getContext());
 }
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index ce86c625e082f..7b46a1a92142b 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -122,40 +122,6 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
     // Convert it here to i1 just in case it survives.
     return mlir::IntegerType::get(&getContext(), 1);
   });
-  // FIXME: https://reviews.llvm.org/D82831 introduced an automatic
-  // materialization of conversion around function calls that is not working
-  // well with fir lowering to llvm (incorrect llvm.mlir.cast are inserted).
-  // Workaround until better analysis: register a handler that does not insert
-  // any conversions.
-  addSourceMaterialization(
-      [&](mlir::OpBuilder &builder, mlir::Type resultType,
-          mlir::ValueRange inputs,
-          mlir::Location loc) -> std::optional<mlir::Value> {
-        if (inputs.size() != 1)
-          return std::nullopt;
-        return inputs[0];
-      });
-  // Similar FIXME workaround here (needed for compare.fir/select-type.fir
-  // as well as rebox-global.fir tests). This is needed to cope with the
-  // the fact that codegen does not lower some operation results to the LLVM
-  // type produced by this LLVMTypeConverter. For instance, inside FIR
-  // globals, fir.box are lowered to llvm.struct, while the fir.box type
-  // conversion translates it into an llvm.ptr<llvm.struct<>> because
-  // descriptors are manipulated in memory outside of global initializers
-  // where this is not possible. Hence, MLIR inserts
-  // builtin.unrealized_conversion_cast after the translation of operations
-  // producing fir.box in fir.global codegen. addSourceMaterialization and
-  // addTargetMaterialization allow ignoring these ops and removing them
-  // after codegen assuming the type discrepencies are intended (like for
-  // fir.box inside globals).
-  addTargetMaterialization(
-      [&](mlir::OpBuilder &builder, mlir::Type resultType,
-          mlir::ValueRange inputs,
-          mlir::Location loc) -> std::optional<mlir::Value> {
-        if (inputs.size() != 1)
-          return std::nullopt;
-        return inputs[0];
-      });
 }
 
 // i32 is used here because LLVM wants i32 constants when indexing into struct
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 7bbfd709b0aaf..dda4f32872fef 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -119,4 +119,5 @@ func.func @_QQmain() {
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations eliminated
 // PASSES-NEXT: TargetRewrite
 // PASSES-NEXT: FIRToLLVMLowering
+// PASSES-NEXT: ReconcileUnrealizedCasts
 // PASSES-NEXT: LLVMIRLoweringPass

>From d8a0ebeda89b3fda310447616449db1cffd017f4 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 6 Jul 2024 14:28:41 +0200
Subject: [PATCH 2/3] fix test

---
 mlir/docs/DialectConversion.md                |  3 +-
 .../SCF/TransformOps/SCFTransformOps.td       | 11 +++
 .../mlir/Transforms/DialectConversion.h       | 10 +--
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 28 ++++--
 .../Dialect/SCF/TransformOps/CMakeLists.txt   |  1 +
 .../SCF/TransformOps/SCFTransformOps.cpp      | 13 ++-
 .../Transforms/Utils/DialectConversion.cpp    | 86 +++++++++----------
 .../FuncToLLVM/func-memref-return.mlir        |  4 +-
 .../Transforms/test-block-legalization.mlir   | 44 ++++++++++
 9 files changed, 140 insertions(+), 60 deletions(-)
 create mode 100644 mlir/test/Transforms/test-block-legalization.mlir

diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index db26e6477d5fc..23e74470a835f 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -352,7 +352,8 @@ class TypeConverter {
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
-  /// a signature conversion of a single block argument, to a single SSA value.
+  /// a signature conversion of a single block argument, to a single SSA value
+  /// with the old argument type.
   template <typename FnT,
             typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
   void addArgumentMaterialization(FnT &&callback) {
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 7bf914f6456ce..20880d94a83ca 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -38,6 +38,17 @@ def ApplySCFStructuralConversionPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplySCFToControlFlowPatternsOp : Op<Transform_Dialect,
+    "apply_conversion_patterns.scf.scf_to_control_flow",
+    [DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns that lower structured control flow ops to unstructured
+    control flow.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;
 
 def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index a22f198bdf252..a51b00271f0ae 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -174,15 +174,15 @@ class TypeConverter {
   /// where `T` is any subclass of `Type`. This function is responsible for
   /// creating an operation, using the OpBuilder and Location provided, that
   /// "casts" a range of values into a single value of the given type `T`. It
-  /// must return a Value of the converted type on success, an `std::nullopt` if
+  /// must return a Value of the type `T` on success, an `std::nullopt` if
   /// it failed but other materialization can be attempted, and `nullptr` on
-  /// unrecoverable failure. It will only be called for (sub)types of `T`.
-  /// Materialization functions must be provided when a type conversion may
-  /// persist after the conversion has finished.
+  /// unrecoverable failure. Materialization functions must be provided when a
+  /// type conversion may persist after the conversion has finished.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
-  /// a signature conversion of a single block argument, to a single SSA value.
+  /// a signature conversion of a single block argument, to a single SSA value
+  /// with the old block argument type.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addArgumentMaterialization(FnT &&callback) {
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index f5620a6a7cd91..32d02d5e438bd 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,9 +153,11 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                        type.isVarArg());
   });
 
-  // Materialization for memrefs creates descriptor structs from individual
-  // values constituting them, when descriptors are used, i.e. more than one
-  // value represents a memref.
+  // Argument materializations convert from the new block argument types
+  // (multiple SSA values that make up a memref descriptor) back to the
+  // original block argument type. The dialect conversion framework will then
+  // insert a target materialization from the original block argument type to
+  // a legal type.
   addArgumentMaterialization(
       [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
           Location loc) -> std::optional<Value> {
@@ -164,12 +166,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
           // memref descriptor cannot be built just from a bare pointer.
           return std::nullopt;
         }
-        return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
-                                              inputs);
+        Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
+                                                    resultType, inputs);
+        // An argument materialization must return a value of type
+        // `resultType`, so insert a cast from the memref descriptor type
+        // (!llvm.struct) to the original memref type.
+        return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
+            .getResult(0);
       });
   addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                  ValueRange inputs,
                                  Location loc) -> std::optional<Value> {
+    Value desc;
     if (inputs.size() == 1) {
       // This is a bare pointer. We allow bare pointers only for function entry
       // blocks.
@@ -180,10 +188,16 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
       if (!block->isEntryBlock() ||
           !isa<FunctionOpInterface>(block->getParentOp()))
         return std::nullopt;
-      return MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
+      desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
                                                inputs[0]);
+    } else {
+      desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
     }
-    return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+    // An argument materialization must return a value of type `resultType`,
+    // so insert a cast from the memref descriptor type (!llvm.struct) to the
+    // original memref type.
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
+        .getResult(0);
   });
   // Add generic source and target materializations to handle cases where
   // non-LLVM types persist after an LLVM conversion.
diff --git a/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt
index 1d6f9ebd153f0..06bccab80e7d8 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransformOps
   MLIRIR
   MLIRLoopLikeInterface
   MLIRSCFDialect
+  MLIRSCFToControlFlow
   MLIRSCFTransforms
   MLIRSCFUtils
   MLIRTransformDialect
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 56ff2709a589e..c4a55c302d0a3 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -49,6 +51,11 @@ void transform::ApplySCFStructuralConversionPatternsOp::
                                                  conversionTarget);
 }
 
+void transform::ApplySCFToControlFlowPatternsOp::populatePatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  populateSCFToControlFlowConversionPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // ForallToForOp
 //===----------------------------------------------------------------------===//
@@ -261,8 +268,10 @@ loopScheduling(scf::ForOp forOp,
     return 1;
   };
 
-  std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
-  std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> ubConstant =
+      getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> lbConstant =
+      getConstantIntValue(forOp.getLowerBound());
   DenseMap<Operation *, unsigned> opCycles;
   std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
   for (Operation &op : forOp.getBody()->getOperations()) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e6c0ee2ab2949..4a3c11f398e7d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -707,10 +707,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   UnresolvedMaterializationRewrite(
       ConversionPatternRewriterImpl &rewriterImpl,
       UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
-      MaterializationKind kind = MaterializationKind::Target,
-      Type origOutputType = nullptr)
+      MaterializationKind kind = MaterializationKind::Target)
       : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-        converterAndKind(converter, kind), origOutputType(origOutputType) {}
+        converterAndKind(converter, kind) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -734,17 +733,11 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
     return converterAndKind.getInt();
   }
 
-  /// Return the original illegal output type of the input values.
-  Type getOrigOutputType() const { return origOutputType; }
-
 private:
   /// The corresponding type converter to use when resolving this
   /// materialization, and the kind of this materialization.
   llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
       converterAndKind;
-
-  /// The original output type. This is only used for argument conversions.
-  Type origOutputType;
 };
 } // namespace
 
@@ -860,12 +853,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                        Block *insertBlock,
                                        Block::iterator insertPt, Location loc,
                                        ValueRange inputs, Type outputType,
-                                       Type origOutputType,
                                        const TypeConverter *converter);
 
   Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
                                                ValueRange inputs,
-                                               Type origOutputType,
                                                Type outputType,
                                                const TypeConverter *converter);
 
@@ -1388,20 +1379,27 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     if (replArgs.size() == 1 &&
         (!converter || replArgs[0].getType() == origArg.getType())) {
       newArg = replArgs.front();
+      mapping.map(origArg, newArg);
     } else {
-      Type origOutputType = origArg.getType();
-
-      // Legalize the argument output type.
-      Type outputType = origOutputType;
-      if (Type legalOutputType = converter->convertType(outputType))
-        outputType = legalOutputType;
-
-      newArg = buildUnresolvedArgumentMaterialization(
-          newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
-          converter);
+      // Build argument materialization: new block arguments -> old block
+      // argument type.
+      Value argMat = buildUnresolvedArgumentMaterialization(
+          newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
+      mapping.map(origArg, argMat);
+
+      // Build target materialization: old block argument type -> legal type.
+      // Note: This function returns an "empty" type if no valid conversion to
+      // a legal type exists. In that case, we continue the conversion with the
+      // original block argument type.
+      if (Type legalOutputType = converter->convertType(origArg.getType())) {
+        newArg = buildUnresolvedTargetMaterialization(
+            origArg.getLoc(), argMat, legalOutputType, converter);
+        mapping.map(argMat, newArg);
+      } else {
+        newArg = argMat;
+      }
     }
 
-    mapping.map(origArg, newArg);
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
     argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
@@ -1424,7 +1422,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// of input operands.
 Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
-    Location loc, ValueRange inputs, Type outputType, Type origOutputType,
+    Location loc, ValueRange inputs, Type outputType,
     const TypeConverter *converter) {
   // Avoid materializing an unnecessary cast.
   if (inputs.size() == 1 && inputs.front().getType() == outputType)
@@ -1435,16 +1433,15 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   OpBuilder builder(insertBlock, insertPt);
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
-  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
-                                                  origOutputType);
+  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
   return convertOp.getResult(0);
 }
 Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
-    Block *block, Location loc, ValueRange inputs, Type origOutputType,
-    Type outputType, const TypeConverter *converter) {
+    Block *block, Location loc, ValueRange inputs, Type outputType,
+    const TypeConverter *converter) {
   return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
                                         block->begin(), loc, inputs, outputType,
-                                        origOutputType, converter);
+                                        converter);
 }
 Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
     Location loc, Value input, Type outputType,
@@ -1456,7 +1453,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
 
   return buildUnresolvedMaterialization(MaterializationKind::Target,
                                         insertBlock, insertPt, loc, input,
-                                        outputType, outputType, converter);
+                                        outputType, converter);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2672,6 +2669,9 @@ static void computeNecessaryMaterializations(
     ConversionPatternRewriterImpl &rewriterImpl,
     DenseMap<Value, SmallVector<Value>> &inverseMapping,
     SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
+  // Helper function to check if the given value or a not yet materialized
+  // replacement of the given value is live.
+  // Note: `inverseMapping` maps from replaced values to original values.
   auto isLive = [&](Value value) {
     auto findFn = [&](Operation *user) {
       auto matIt = materializationOps.find(user);
@@ -2679,12 +2679,18 @@ static void computeNecessaryMaterializations(
         return !necessaryMaterializations.count(matIt->second);
       return rewriterImpl.isOpIgnored(user);
     };
-    // This value may be replacing another value that has a live user.
-    for (Value inv : inverseMapping.lookup(value))
-      if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
+    // A worklist is needed because a value may have gone through a chain of
+    // replacements and each of the replaced values may have live users.
+    SmallVector<Value> worklist;
+    worklist.push_back(value);
+    while (!worklist.empty()) {
+      Value next = worklist.pop_back_val();
+      if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
         return true;
-    // Or have live users itself.
-    return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
+      // This value may be replacing another value that has a live user.
+      llvm::append_range(worklist, inverseMapping.lookup(next));
+    }
+    return false;
   };
 
   llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
@@ -2844,18 +2850,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
     switch (mat.getMaterializationKind()) {
     case MaterializationKind::Argument:
       // Try to materialize an argument conversion.
-      // FIXME: The current argument materialization hook expects the original
-      // output type, even though it doesn't use that as the actual output type
-      // of the generated IR. The output type is just used as an indicator of
-      // the type of materialization to do. This behavior is really awkward in
-      // that it diverges from the behavior of the other hooks, and can be
-      // easily misunderstood. We should clean up the argument hooks to better
-      // represent the desired invariants we actually care about.
       newMaterialization = converter->materializeArgumentConversion(
-          rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
+          rewriter, op->getLoc(), outputType, inputOperands);
       if (newMaterialization)
         break;
-
       // If an argument materialization failed, fallback to trying a target
       // materialization.
       [[fallthrough]];
@@ -2865,6 +2863,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
       break;
     }
     if (newMaterialization) {
+      assert(newMaterialization.getType() == outputType &&
+             "materialization callback produced value of incorrect type");
       replaceMaterialization(rewriterImpl, opResult, newMaterialization,
                              inverseMapping);
       return success();
diff --git a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
index 91ef571cb3bf7..6b9df32fe02dd 100644
--- a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
@@ -1,8 +1,8 @@
 // RUN: mlir-opt -convert-func-to-llvm -reconcile-unrealized-casts %s | FileCheck %s
 
-// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1'  %s | FileCheck %s --check-prefix=BAREPTR
+// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1' -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR
 
-// RUN: mlir-opt -transform-interpreter %s | FileCheck %s --check-prefix=BAREPTR
+// RUN: mlir-opt -transform-interpreter -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR
 
 // These tests were separated from func-memref.mlir because applying
 // -reconcile-unrealized-casts resulted in `llvm.extractvalue` ops getting
diff --git a/mlir/test/Transforms/test-block-legalization.mlir b/mlir/test/Transforms/test-block-legalization.mlir
new file mode 100644
index 0000000000000..d739f95a56947
--- /dev/null
+++ b/mlir/test/Transforms/test-block-legalization.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: func @complex_block_signature_conversion(
+//       CHECK:   %[[cst:.*]] = complex.constant
+//       CHECK:   %[[complex_llvm:.*]] = builtin.unrealized_conversion_cast %[[cst]] : complex<f64> to !llvm.struct<(f64, f64)>
+// Note: Some blocks are omitted.
+//       CHECK:   llvm.br ^[[block1:.*]](%[[complex_llvm]]
+//       CHECK: ^[[block1]](%[[arg:.*]]: !llvm.struct<(f64, f64)>):
+//       CHECK:   %[[cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : !llvm.struct<(f64, f64)> to complex<f64>
+//       CHECK:   llvm.br ^[[block2:.*]]
+//       CHECK: ^[[block2]]:
+//       CHECK:   "test.consumer_of_complex"(%[[cast]]) : (complex<f64>) -> ()
+func.func @complex_block_signature_conversion() {
+  %cst = complex.constant [0.000000e+00, 0.000000e+00] : complex<f64>
+  %true = arith.constant true
+  %0 = scf.if %true -> complex<f64> {
+    scf.yield %cst : complex<f64>
+  } else {
+    scf.yield %cst : complex<f64>
+  }
+
+  // Regression test to ensure that a source materialization is inserted.
+  // The operand of "test.consumer_of_complex" must not change.
+  "test.consumer_of_complex"(%0) : (complex<f64>) -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %toplevel_module
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_conversion_patterns to %func {
+      transform.apply_conversion_patterns.dialect_to_llvm "cf"
+      transform.apply_conversion_patterns.func.func_to_llvm
+      transform.apply_conversion_patterns.scf.scf_to_control_flow
+    } with type_converter {
+      transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
+    } {
+      legal_dialects = ["llvm"], 
+      partial_conversion
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From c254f78e2636698a499a6c6baf4eb656b8e15323 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 13 Jul 2024 17:36:37 +0200
Subject: [PATCH 3/3] [mlir][Transforms] Dialect conversion: Simplify handling
 of dropped arguments

This commit simplifies the handling of dropped arguments and updates some dialect conversion documentation that is outdated.

When converting a block signature, a BlockTypeConversionRewrite object and potentially multiple ReplaceBlockArgRewrite are created. During the "commit" phase, uses of the old block arguments are replaced with the new block arguments, but the old implementation was written in an inconsistent way: some block arguments were replaced in BlockTypeConversionRewrite::commit and some were replaced in ReplaceBlockArgRewrite::commit. The new BlockTypeConversionRewrite::commit implementation is much simpler and no longer modifies any IR; that is done only in ReplaceBlockArgRewrite now. The ConvertedArgInfo data structure is no longer needed.

To that end, materializations of dropped arguments are now built in applySignatureConversion instead of materializeLiveConversions; the latter function no longer has to deal with dropped arguments.

Other minor improvements:

- Improve variable name: origOutputType -> origArgType. Add an assertion to check that this field is only used for argument materializations.
- Add more comments to applySignatureConversion.

Note: Error messages around failed materializations for dropped basic block arguments changed slightly. That is because those materializations are now built in legalizeUnresolvedMaterialization instead of legalizeConvertedArgumentTypes.

This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion.

This is a re-upload of #96207.
---
 .../Transforms/Utils/DialectConversion.cpp    | 191 +++++++-----------
 .../test-legalize-type-conversion.mlir        |   6 +-
 2 files changed, 70 insertions(+), 127 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 4a3c11f398e7d..50e8e7cf670ed 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -432,34 +432,14 @@ class MoveBlockRewrite : public BlockRewrite {
   Block *insertBeforeBlock;
 };
 
-/// This structure contains the information pertaining to an argument that has
-/// been converted.
-struct ConvertedArgInfo {
-  ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
-                   Value castValue = nullptr)
-      : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
-
-  /// The start index of in the new argument list that contains arguments that
-  /// replace the original.
-  unsigned newArgIdx;
-
-  /// The number of arguments that replaced the original argument.
-  unsigned newArgSize;
-
-  /// The cast value that was created to cast from the new arguments to the
-  /// old. This only used if 'newArgSize' > 1.
-  Value castValue;
-};
-
 /// Block type conversion. This rewrite is partially reflected in the IR.
 class BlockTypeConversionRewrite : public BlockRewrite {
 public:
-  BlockTypeConversionRewrite(
-      ConversionPatternRewriterImpl &rewriterImpl, Block *block,
-      Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
-      const TypeConverter *converter)
+  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                             Block *block, Block *origBlock,
+                             const TypeConverter *converter)
       : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
-        origBlock(origBlock), argInfo(argInfo), converter(converter) {}
+        origBlock(origBlock), converter(converter) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::BlockTypeConversion;
@@ -479,10 +459,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   /// The original block that was requested to have its signature converted.
   Block *origBlock;
 
-  /// The conversion information for each of the arguments. The information is
-  /// std::nullopt if the argument was dropped during conversion.
-  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-
   /// The type converter used to convert the arguments.
   const TypeConverter *converter;
 };
@@ -691,12 +667,16 @@ class CreateOperationRewrite : public OperationRewrite {
 /// The type of materialization.
 enum MaterializationKind {
   /// This materialization materializes a conversion for an illegal block
-  /// argument type, to a legal one.
+  /// argument type, to the original one.
   Argument,
 
   /// This materialization materializes a conversion from an illegal type to a
   /// legal one.
-  Target
+  Target,
+
+  /// This materialization materializes a conversion from a legal type back to
+  /// an illegal one.
+  Source
 };
 
 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
@@ -736,7 +716,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
 private:
   /// The corresponding type converter to use when resolving this
   /// materialization, and the kind of this materialization.
-  llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
+  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
       converterAndKind;
 };
 } // namespace
@@ -855,11 +835,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                        ValueRange inputs, Type outputType,
                                        const TypeConverter *converter);
 
-  Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
-                                               ValueRange inputs,
-                                               Type outputType,
-                                               const TypeConverter *converter);
-
   Value buildUnresolvedTargetMaterialization(Location loc, Value input,
                                              Type outputType,
                                              const TypeConverter *converter);
@@ -989,28 +964,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
           dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
     for (Operation *op : block->getUsers())
       listener->notifyOperationModified(op);
-
-  // Process the remapping for each of the original arguments.
-  for (auto [origArg, info] :
-       llvm::zip_equal(origBlock->getArguments(), argInfo)) {
-    // Handle the case of a 1->0 value mapping.
-    if (!info) {
-      if (Value newArg =
-              rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
-        rewriter.replaceAllUsesWith(origArg, newArg);
-      continue;
-    }
-
-    // Otherwise this is a 1->1+ value mapping.
-    Value castValue = info->castValue;
-    assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
-
-    // If the argument is still used, replace it with the generated cast.
-    if (!origArg.use_empty()) {
-      rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
-                                               castValue, origArg.getType()));
-    }
-  }
 }
 
 void BlockTypeConversionRewrite::rollback() {
@@ -1035,14 +988,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
       continue;
 
     Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
-    bool isDroppedArg = replacementValue == origArg;
-    if (!isDroppedArg)
-      builder.setInsertionPointAfterValue(replacementValue);
+    assert(replacementValue && "replacement value not found");
     Value newArg;
     if (converter) {
+      builder.setInsertionPointAfterValue(replacementValue);
       newArg = converter->materializeSourceConversion(
-          builder, origArg.getLoc(), origArg.getType(),
-          isDroppedArg ? ValueRange() : ValueRange(replacementValue));
+          builder, origArg.getLoc(), origArg.getType(), replacementValue);
       assert((!newArg || newArg.getType() == origArg.getType()) &&
              "materialization hook did not provide a value of the expected "
              "type");
@@ -1053,8 +1004,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
           << "failed to materialize conversion for block argument #"
           << it.index() << " that remained live after conversion, type was "
           << origArg.getType();
-      if (!isDroppedArg)
-        diag << ", with target type " << replacementValue.getType();
       diag.attachNote(liveUser->getLoc())
           << "see existing live user here: " << *liveUser;
       return failure();
@@ -1340,72 +1289,71 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
   // Replace all uses of the old block with the new block.
   block->replaceAllUsesWith(newBlock);
 
-  // Remap each of the original arguments as determined by the signature
-  // conversion.
-  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-  argInfo.resize(origArgCount);
-
   for (unsigned i = 0; i != origArgCount; ++i) {
-    auto inputMap = signatureConversion.getInputMapping(i);
-    if (!inputMap)
-      continue;
     BlockArgument origArg = block->getArgument(i);
+    Type origArgType = origArg.getType();
+
+    // Helper function that tries to legalize the given type. Returns the given
+    // type if it could not be legalized.
+    // FIXME: We simply pass through the replacement argument if there wasn't a
+    // converter, which isn't great as it allows implicit type conversions to
+    // appear. We should properly restructure this code to handle cases where a
+    // converter isn't provided and also to properly handle the case where an
+    // argument materialization is actually a temporary source materialization
+    // (e.g. in the case of 1->N).
+    auto tryLegalizeType = [&](Type type) {
+      if (converter)
+        if (Type t = converter->convertType(type))
+          return t;
+      return type;
+    };
+
+    std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
+        signatureConversion.getInputMapping(i);
+    if (!inputMap) {
+      // This block argument was dropped and no replacement value was provided.
+      // Materialize a replacement value "out of thin air".
+      Value repl = buildUnresolvedMaterialization(
+          MaterializationKind::Source, newBlock, newBlock->begin(),
+          origArg.getLoc(), /*inputs=*/ValueRange(),
+          /*outputType=*/origArgType, converter);
+      mapping.map(origArg, repl);
+      appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+      continue;
+    }
 
-    // If inputMap->replacementValue is not nullptr, then the argument is
-    // dropped and a replacement value is provided to be the remappedValue.
-    if (inputMap->replacementValue) {
+    if (Value repl = inputMap->replacementValue) {
+      // This block argument was dropped and a replacement value was provided.
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      mapping.map(origArg, inputMap->replacementValue);
+      mapping.map(origArg, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
       continue;
     }
 
-    // Otherwise, this is a 1->1+ mapping.
+    // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
+    // dialect conversion. Therefore, we need an argument materialization to
+    // turn the replacement block arguments into a single SSA value that can be
+    // used as a replacement.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    Value newArg;
+    Value argMat = buildUnresolvedMaterialization(
+        MaterializationKind::Argument, newBlock, newBlock->begin(),
+        origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
+    mapping.map(origArg, argMat);
+    appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
 
-    // If this is a 1->1 mapping and the types of new and replacement arguments
-    // match (i.e. it's an identity map), then the argument is mapped to its
-    // original type.
-    // FIXME: We simply pass through the replacement argument if there wasn't a
-    // converter, which isn't great as it allows implicit type conversions to
-    // appear. We should properly restructure this code to handle cases where a
-    // converter isn't provided and also to properly handle the case where an
-    // argument materialization is actually a temporary source materialization
-    // (e.g. in the case of 1->N).
-    if (replArgs.size() == 1 &&
-        (!converter || replArgs[0].getType() == origArg.getType())) {
-      newArg = replArgs.front();
-      mapping.map(origArg, newArg);
-    } else {
-      // Build argument materialization: new block arguments -> old block
-      // argument type.
-      Value argMat = buildUnresolvedArgumentMaterialization(
-          newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
-      mapping.map(origArg, argMat);
-
-      // Build target materialization: old block argument type -> legal type.
-      // Note: This function returns an "empty" type if no valid conversion to
-      // a legal type exists. In that case, we continue the conversion with the
-      // original block argument type.
-      if (Type legalOutputType = converter->convertType(origArg.getType())) {
-        newArg = buildUnresolvedTargetMaterialization(
-            origArg.getLoc(), argMat, legalOutputType, converter);
-        mapping.map(argMat, newArg);
-      } else {
-        newArg = argMat;
-      }
+    Type legalOutputType = tryLegalizeType(origArgType);
+    if (legalOutputType != origArgType) {
+      Value targetMat = buildUnresolvedTargetMaterialization(
+          origArg.getLoc(), argMat, legalOutputType, converter);
+      mapping.map(argMat, targetMat);
     }
-
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
-    argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
 
-  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
-                                            converter);
+  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
 
   // Erase the old block. (It is just unlinked for now and will be erased during
   // cleanup.)
@@ -1436,13 +1384,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
   return convertOp.getResult(0);
 }
-Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
-    Block *block, Location loc, ValueRange inputs, Type outputType,
-    const TypeConverter *converter) {
-  return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
-                                        block->begin(), loc, inputs, outputType,
-                                        converter);
-}
 Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
     Location loc, Value input, Type outputType,
     const TypeConverter *converter) {
@@ -2861,6 +2802,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
       newMaterialization = converter->materializeTargetConversion(
           rewriter, op->getLoc(), outputType, inputOperands);
       break;
+    case MaterializationKind::Source:
+      newMaterialization = converter->materializeSourceConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
     }
     if (newMaterialization) {
       assert(newMaterialization.getType() == outputType &&
@@ -2873,8 +2818,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
 
   InFlightDiagnostic diag = op->emitError()
                             << "failed to legalize unresolved materialization "
-                               "from "
-                            << inputOperands.getTypes() << " to " << outputType
+                               "from ("
+                            << inputOperands.getTypes() << ") to " << outputType
                             << " that remained live after conversion";
   if (Operation *liveUser = findLiveUser(op->getUsers())) {
     diag.attachNote(liveUser->getLoc())
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index b35cda8e724f6..8254be68912c8 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -2,9 +2,8 @@
 
 
 func.func @test_invalid_arg_materialization(
-  // expected-error at below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
+  // expected-error at below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
   %arg0: i16) {
-  // expected-note at below {{see existing live user here}}
   "foo.return"(%arg0) : (i16) -> ()
 }
 
@@ -104,9 +103,8 @@ func.func @test_block_argument_not_converted() {
 // Make sure argument type changes aren't implicitly forwarded.
 func.func @test_signature_conversion_no_converter() {
   "test.signature_conversion_no_converter"() ({
-  // expected-error at below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
+  // expected-error at below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
   ^bb0(%arg0: f32):
-    // expected-note at below {{see existing live user here}}
     "test.type_consumer"(%arg0) : (f32) -> ()
     "test.return"(%arg0) : (f32) -> ()
   }) : () -> ()



More information about the Mlir-commits mailing list