[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: Fix missing source materialization (PR #97903)

Matthias Springer llvmlistbot at llvm.org
Mon Jul 15 07:09:29 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/97903

>From cbbf7417717aff35e59d0403c1ec82aaa7fb8afc Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 6 Jul 2024 14:28:41 +0200
Subject: [PATCH] [mlir][Transforms] Dialect conversion: Fix missing source materialization

---
 mlir/docs/DialectConversion.md                |  3 +-
 .../SCF/TransformOps/SCFTransformOps.td       | 11 +++
 .../mlir/Transforms/DialectConversion.h       | 10 +--
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 28 ++++--
 .../Dialect/SCF/TransformOps/CMakeLists.txt   |  1 +
 .../SCF/TransformOps/SCFTransformOps.cpp      | 13 ++-
 .../Transforms/Utils/DialectConversion.cpp    | 87 ++++++++++---------
 .../FuncToLLVM/func-memref-return.mlir        |  4 +-
 .../Transforms/test-block-legalization.mlir   | 44 ++++++++++
 9 files changed, 141 insertions(+), 60 deletions(-)
 create mode 100644 mlir/test/Transforms/test-block-legalization.mlir

diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index db26e6477d5fc..23e74470a835f 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -352,7 +352,8 @@ class TypeConverter {
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
-  /// a signature conversion of a single block argument, to a single SSA value.
+  /// a signature conversion of a single block argument, to a single SSA value
+  /// with the old block argument type.
   template <typename FnT,
             typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
   void addArgumentMaterialization(FnT &&callback) {
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 7bf914f6456ce..20880d94a83ca 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -38,6 +38,17 @@ def ApplySCFStructuralConversionPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplySCFToControlFlowPatternsOp : Op<Transform_Dialect,
+    "apply_conversion_patterns.scf.scf_to_control_flow",
+    [DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns that lower structured control flow ops to unstructured
+    control flow.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;
 
 def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index a22f198bdf252..a51b00271f0ae 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -174,15 +174,15 @@ class TypeConverter {
   /// where `T` is any subclass of `Type`. This function is responsible for
   /// creating an operation, using the OpBuilder and Location provided, that
   /// "casts" a range of values into a single value of the given type `T`. It
-  /// must return a Value of the converted type on success, an `std::nullopt` if
+  /// must return a Value of the type `T` on success, an `std::nullopt` if
   /// it failed but other materialization can be attempted, and `nullptr` on
-  /// unrecoverable failure. It will only be called for (sub)types of `T`.
-  /// Materialization functions must be provided when a type conversion may
-  /// persist after the conversion has finished.
+  /// unrecoverable failure. Materialization functions must be provided when a
+  /// type conversion may persist after the conversion has finished.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
-  /// a signature conversion of a single block argument, to a single SSA value.
+  /// a signature conversion of a single block argument, to a single SSA value
+  /// with the old block argument type.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addArgumentMaterialization(FnT &&callback) {
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index f5620a6a7cd91..32d02d5e438bd 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,9 +153,11 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                        type.isVarArg());
   });
 
-  // Materialization for memrefs creates descriptor structs from individual
-  // values constituting them, when descriptors are used, i.e. more than one
-  // value represents a memref.
+  // Argument materializations convert from the new block argument types
+  // (the SSA values that make up a memref descriptor, or a bare pointer) back
+  // to the original block argument type. The dialect conversion framework will
+  // then insert a target materialization from the original block argument type
+  // to a legal type.
   addArgumentMaterialization(
       [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
           Location loc) -> std::optional<Value> {
@@ -164,12 +166,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
           // memref descriptor cannot be built just from a bare pointer.
           return std::nullopt;
         }
-        return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
-                                              inputs);
+        Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
+                                                    resultType, inputs);
+        // An argument materialization must return a value of type
+        // `resultType`, so insert a cast from the memref descriptor type
+        // (!llvm.struct) to the original memref type.
+        return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
+            .getResult(0);
       });
   addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                  ValueRange inputs,
                                  Location loc) -> std::optional<Value> {
+    Value desc;
     if (inputs.size() == 1) {
       // This is a bare pointer. We allow bare pointers only for function entry
       // blocks.
@@ -180,10 +188,16 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
       if (!block->isEntryBlock() ||
           !isa<FunctionOpInterface>(block->getParentOp()))
         return std::nullopt;
-      return MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
+      desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
                                                inputs[0]);
+    } else {
+      desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
     }
-    return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+    // An argument materialization must return a value of type `resultType`,
+    // so insert a cast from the memref descriptor type (!llvm.struct) to the
+    // original memref type.
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
+        .getResult(0);
   });
   // Add generic source and target materializations to handle cases where
   // non-LLVM types persist after an LLVM conversion.
diff --git a/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt
index 1d6f9ebd153f0..06bccab80e7d8 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransformOps
   MLIRIR
   MLIRLoopLikeInterface
   MLIRSCFDialect
+  MLIRSCFToControlFlow
   MLIRSCFTransforms
   MLIRSCFUtils
   MLIRTransformDialect
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 56ff2709a589e..c4a55c302d0a3 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -49,6 +51,11 @@ void transform::ApplySCFStructuralConversionPatternsOp::
                                                  conversionTarget);
 }
 
+void transform::ApplySCFToControlFlowPatternsOp::populatePatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  populateSCFToControlFlowConversionPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // ForallToForOp
 //===----------------------------------------------------------------------===//
@@ -261,8 +268,10 @@ loopScheduling(scf::ForOp forOp,
     return 1;
   };
 
-  std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
-  std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> ubConstant =
+      getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> lbConstant =
+      getConstantIntValue(forOp.getLowerBound());
   DenseMap<Operation *, unsigned> opCycles;
   std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
   for (Operation &op : forOp.getBody()->getOperations()) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e6c0ee2ab2949..1e0afee2373a9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -707,10 +707,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   UnresolvedMaterializationRewrite(
       ConversionPatternRewriterImpl &rewriterImpl,
       UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
-      MaterializationKind kind = MaterializationKind::Target,
-      Type origOutputType = nullptr)
+      MaterializationKind kind = MaterializationKind::Target)
       : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-        converterAndKind(converter, kind), origOutputType(origOutputType) {}
+        converterAndKind(converter, kind) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -734,17 +733,11 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
     return converterAndKind.getInt();
   }
 
-  /// Return the original illegal output type of the input values.
-  Type getOrigOutputType() const { return origOutputType; }
-
 private:
   /// The corresponding type converter to use when resolving this
   /// materialization, and the kind of this materialization.
   llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
       converterAndKind;
-
-  /// The original output type. This is only used for argument conversions.
-  Type origOutputType;
 };
 } // namespace
 
@@ -860,12 +853,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                        Block *insertBlock,
                                        Block::iterator insertPt, Location loc,
                                        ValueRange inputs, Type outputType,
-                                       Type origOutputType,
                                        const TypeConverter *converter);
 
   Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
                                                ValueRange inputs,
-                                               Type origOutputType,
                                                Type outputType,
                                                const TypeConverter *converter);
 
@@ -1388,20 +1379,28 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     if (replArgs.size() == 1 &&
         (!converter || replArgs[0].getType() == origArg.getType())) {
       newArg = replArgs.front();
+      mapping.map(origArg, newArg);
     } else {
-      Type origOutputType = origArg.getType();
-
-      // Legalize the argument output type.
-      Type outputType = origOutputType;
-      if (Type legalOutputType = converter->convertType(outputType))
-        outputType = legalOutputType;
-
-      newArg = buildUnresolvedArgumentMaterialization(
-          newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
-          converter);
+      // Build argument materialization: new block arguments -> old block
+      // argument type.
+      Value argMat = buildUnresolvedArgumentMaterialization(
+          newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
+      mapping.map(origArg, argMat);
+
+      // Build target materialization: old block argument type -> legal type.
+      // Note: This function returns an "empty" type if no valid conversion to
+      // a legal type exists. In that case, we continue the conversion with the
+      // original block argument type.
+      Type legalOutputType = converter->convertType(origArg.getType());
+      if (legalOutputType && legalOutputType != origArg.getType()) {
+        newArg = buildUnresolvedTargetMaterialization(
+            origArg.getLoc(), argMat, legalOutputType, converter);
+        mapping.map(argMat, newArg);
+      } else {
+        newArg = argMat;
+      }
     }
 
-    mapping.map(origArg, newArg);
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
     argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
@@ -1424,7 +1423,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// of input operands.
 Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
-    Location loc, ValueRange inputs, Type outputType, Type origOutputType,
+    Location loc, ValueRange inputs, Type outputType,
     const TypeConverter *converter) {
   // Avoid materializing an unnecessary cast.
   if (inputs.size() == 1 && inputs.front().getType() == outputType)
@@ -1435,16 +1434,15 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   OpBuilder builder(insertBlock, insertPt);
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
-  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
-                                                  origOutputType);
+  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
   return convertOp.getResult(0);
 }
 Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
-    Block *block, Location loc, ValueRange inputs, Type origOutputType,
-    Type outputType, const TypeConverter *converter) {
+    Block *block, Location loc, ValueRange inputs, Type outputType,
+    const TypeConverter *converter) {
   return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
                                         block->begin(), loc, inputs, outputType,
-                                        origOutputType, converter);
+                                        converter);
 }
 Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
     Location loc, Value input, Type outputType,
@@ -1456,7 +1454,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
 
   return buildUnresolvedMaterialization(MaterializationKind::Target,
                                         insertBlock, insertPt, loc, input,
-                                        outputType, outputType, converter);
+                                        outputType, converter);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2672,6 +2670,9 @@ static void computeNecessaryMaterializations(
     ConversionPatternRewriterImpl &rewriterImpl,
     DenseMap<Value, SmallVector<Value>> &inverseMapping,
     SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
+  // Helper function to check if the given value or a not yet materialized
+  // replacement of the given value is live.
+  // Note: `inverseMapping` maps from replaced values to original values.
   auto isLive = [&](Value value) {
     auto findFn = [&](Operation *user) {
       auto matIt = materializationOps.find(user);
@@ -2679,12 +2680,18 @@ static void computeNecessaryMaterializations(
         return !necessaryMaterializations.count(matIt->second);
       return rewriterImpl.isOpIgnored(user);
     };
-    // This value may be replacing another value that has a live user.
-    for (Value inv : inverseMapping.lookup(value))
-      if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
+    // A worklist is needed because a value may have gone through a chain of
+    // replacements and each of the replaced values may have live users.
+    SmallVector<Value> worklist;
+    worklist.push_back(value);
+    while (!worklist.empty()) {
+      Value next = worklist.pop_back_val();
+      if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
         return true;
-    // Or have live users itself.
-    return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
+      // This value may be replacing another value that has a live user.
+      llvm::append_range(worklist, inverseMapping.lookup(next));
+    }
+    return false;
   };
 
   llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
@@ -2844,18 +2851,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
     switch (mat.getMaterializationKind()) {
     case MaterializationKind::Argument:
       // Try to materialize an argument conversion.
-      // FIXME: The current argument materialization hook expects the original
-      // output type, even though it doesn't use that as the actual output type
-      // of the generated IR. The output type is just used as an indicator of
-      // the type of materialization to do. This behavior is really awkward in
-      // that it diverges from the behavior of the other hooks, and can be
-      // easily misunderstood. We should clean up the argument hooks to better
-      // represent the desired invariants we actually care about.
       newMaterialization = converter->materializeArgumentConversion(
-          rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
+          rewriter, op->getLoc(), outputType, inputOperands);
       if (newMaterialization)
         break;
-
       // If an argument materialization failed, fallback to trying a target
       // materialization.
       [[fallthrough]];
@@ -2865,6 +2864,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
       break;
     }
     if (newMaterialization) {
+      assert(newMaterialization.getType() == outputType &&
+             "materialization callback produced value of incorrect type");
       replaceMaterialization(rewriterImpl, opResult, newMaterialization,
                              inverseMapping);
       return success();
diff --git a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
index 91ef571cb3bf7..6b9df32fe02dd 100644
--- a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
@@ -1,8 +1,8 @@
 // RUN: mlir-opt -convert-func-to-llvm -reconcile-unrealized-casts %s | FileCheck %s
 
-// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1'  %s | FileCheck %s --check-prefix=BAREPTR
+// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1' -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR
 
-// RUN: mlir-opt -transform-interpreter %s | FileCheck %s --check-prefix=BAREPTR
+// RUN: mlir-opt -transform-interpreter -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR
 
 // These tests were separated from func-memref.mlir because applying
 // -reconcile-unrealized-casts resulted in `llvm.extractvalue` ops getting
diff --git a/mlir/test/Transforms/test-block-legalization.mlir b/mlir/test/Transforms/test-block-legalization.mlir
new file mode 100644
index 0000000000000..d739f95a56947
--- /dev/null
+++ b/mlir/test/Transforms/test-block-legalization.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: func @complex_block_signature_conversion(
+//       CHECK:   %[[cst:.*]] = complex.constant
+//       CHECK:   %[[complex_llvm:.*]] = builtin.unrealized_conversion_cast %[[cst]] : complex<f64> to !llvm.struct<(f64, f64)>
+// Note: Some blocks are omitted.
+//       CHECK:   llvm.br ^[[block1:.*]](%[[complex_llvm]]
+//       CHECK: ^[[block1]](%[[arg:.*]]: !llvm.struct<(f64, f64)>):
+//       CHECK:   %[[cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : !llvm.struct<(f64, f64)> to complex<f64>
+//       CHECK:   llvm.br ^[[block2:.*]]
+//       CHECK: ^[[block2]]:
+//       CHECK:   "test.consumer_of_complex"(%[[cast]]) : (complex<f64>) -> ()
+func.func @complex_block_signature_conversion() {
+  %cst = complex.constant [0.000000e+00, 0.000000e+00] : complex<f64>
+  %true = arith.constant true
+  %0 = scf.if %true -> complex<f64> {
+    scf.yield %cst : complex<f64>
+  } else {
+    scf.yield %cst : complex<f64>
+  }
+
+  // Regression test to ensure that a source materialization is inserted.
+  // The operand of "test.consumer_of_complex" must not change.
+  "test.consumer_of_complex"(%0) : (complex<f64>) -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %toplevel_module
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_conversion_patterns to %func {
+      transform.apply_conversion_patterns.dialect_to_llvm "cf"
+      transform.apply_conversion_patterns.func.func_to_llvm
+      transform.apply_conversion_patterns.scf.scf_to_control_flow
+    } with type_converter {
+      transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
+    } {
+      legal_dialects = ["llvm"], 
+      partial_conversion
+    } : !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list