[Mlir-commits] [mlir] [mlir][Transforms] Dialect Conversion: No target mat. for 1:N replacement (PR #117513)

Matthias Springer llvmlistbot at llvm.org
Mon Dec 23 04:10:47 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/117513

>From 759edf6f0dcc6bc96da75e93f8f75f628bbee35a Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 25 Nov 2024 04:13:24 +0100
Subject: [PATCH] [mlir][Transforms] Dialect Conversion: Do not build target
 mat. during 1:N replacement

fix test

experiment
---
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 130 ++++++++++++------
 .../Transforms/Utils/DialectConversion.cpp    |  46 ++-----
 mlir/test/Transforms/test-legalizer.mlir      |   8 +-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  47 +++----
 4 files changed, 128 insertions(+), 103 deletions(-)

diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 59b0f5c9b09bcd..e2ab0ed6f66cc5 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,70 +153,112 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                        type.isVarArg());
   });
 
+  // Add generic source and target materializations to handle cases where
+  // non-LLVM types persist after an LLVM conversion.
+  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
+                               ValueRange inputs, Location loc) {
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  });
+  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
+                               ValueRange inputs, Location loc) {
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  });
+
   // Helper function that checks if the given value range is a bare pointer.
   auto isBarePointer = [](ValueRange values) {
     return values.size() == 1 &&
            isa<LLVM::LLVMPointerType>(values.front().getType());
   };
 
-  // Argument materializations convert from the new block argument types
-  // (multiple SSA values that make up a memref descriptor) back to the
-  // original block argument type. The dialect conversion framework will then
-  // insert a target materialization from the original block argument type to
-  // a legal type.
-  addArgumentMaterialization([&](OpBuilder &builder,
-                                 UnrankedMemRefType resultType,
-                                 ValueRange inputs, Location loc) {
+  // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
+  // must be passed explicitly.
+  auto packUnrankedMemRefDesc =
+      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
+          Location loc, LLVMTypeConverter &converter) -> Value {
     // Note: Bare pointers are not supported for unranked memrefs because a
     // memref descriptor cannot be built just from a bare pointer.
-    if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
+    if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
       return Value();
-    Value desc =
-        UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+    return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
+                                          inputs);
+  };
+
+  // MemRef descriptor elements -> UnrankedMemRefType
+  auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
+                                          UnrankedMemRefType resultType,
+                                          ValueRange inputs, Location loc) {
     // An argument materialization must return a value of type
     // `resultType`, so insert a cast from the memref descriptor type
     // (!llvm.struct) to the original memref type.
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
-        .getResult(0);
-  });
-  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
-                                 ValueRange inputs, Location loc) {
-    Value desc;
-    if (isBarePointer(inputs)) {
-      desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
-                                               inputs[0]);
-    } else if (TypeRange(inputs) ==
-               getMemRefDescriptorFields(resultType,
-                                         /*unpackAggregates=*/true)) {
-      desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
-    } else {
-      // The inputs are neither a bare pointer nor an unpacked memref
-      // descriptor. This materialization function cannot be used.
+    Value packed =
+        packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
+    if (!packed)
       return Value();
-    }
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+        .getResult(0);
+  };
+
+  // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
+  // must be passed explicitly.
+  auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
+                                  ValueRange inputs, Location loc,
+                                  LLVMTypeConverter &converter) -> Value {
+    assert(resultType && "expected non-null result type");
+    if (isBarePointer(inputs))
+      return MemRefDescriptor::fromStaticShape(builder, loc, converter,
+                                               resultType, inputs[0]);
+    if (TypeRange(inputs) ==
+        converter.getMemRefDescriptorFields(resultType,
+                                            /*unpackAggregates=*/true))
+      return MemRefDescriptor::pack(builder, loc, converter, resultType,
+                                    inputs);
+    // The inputs are neither a bare pointer nor an unpacked memref descriptor.
+    // This materialization function cannot be used.
+    return Value();
+  };
+
+  // MemRef descriptor elements -> MemRefType
+  auto rankedMemRefMaterialization = [&](OpBuilder &builder,
+                                         MemRefType resultType,
+                                         ValueRange inputs, Location loc) {
     // An argument materialization must return a value of type `resultType`,
     // so insert a cast from the memref descriptor type (!llvm.struct) to the
     // original memref type.
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
-        .getResult(0);
-  });
-  // Add generic source and target materializations to handle cases where
-  // non-LLVM types persist after an LLVM conversion.
-  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs, Location loc) {
-    if (inputs.size() != 1)
+    Value packed =
+        packRankedMemRefDesc(builder, resultType, inputs, loc, *this);
+    if (!packed)
       return Value();
-
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
         .getResult(0);
-  });
+  };
+
+  // Argument materializations convert from the new block argument types
+  // (multiple SSA values that make up a memref descriptor) back to the
+  // original block argument type.
+  addArgumentMaterialization(unrakedMemRefMaterialization);
+  addArgumentMaterialization(rankedMemRefMaterialization);
+  addSourceMaterialization(unrakedMemRefMaterialization);
+  addSourceMaterialization(rankedMemRefMaterialization);
+
+  // Bare pointer -> Packed MemRef descriptor
   addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs, Location loc) {
-    if (inputs.size() != 1)
+                               ValueRange inputs, Location loc,
+                               Type originalType) -> Value {
+    // The original MemRef type is required to build a MemRef descriptor
+    // because the sizes/strides of the MemRef cannot be inferred from just the
+    // bare pointer.
+    if (!originalType)
       return Value();
-
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
-        .getResult(0);
+    if (resultType != convertType(originalType))
+      return Value();
+    if (auto memrefType = dyn_cast<MemRefType>(originalType))
+      return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
+    if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
+      return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
+                                    *this);
+    return Value();
   });
 
   // Integer memory spaces map to themselves.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1607740a1ee076..51686646a0a2fc 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -849,8 +849,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// function will be deleted when full 1:N support has been added.
   ///
   /// This function inserts an argument materialization back to the original
-  /// type, followed by a target materialization to the legalized type (if
-  /// applicable).
+  /// type.
   void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
                                  ValueRange replacements, Value originalValue,
                                  const TypeConverter *converter);
@@ -1376,9 +1375,13 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     // used as a replacement.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    insertNTo1Materialization(
-        OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-        /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+    if (replArgs.size() == 1) {
+      mapping.map(origArg, replArgs.front());
+    } else {
+      insertNTo1Materialization(
+          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
+          /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+    }
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }
 
@@ -1437,36 +1440,12 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
   // Insert argument materialization back to the original type.
   Type originalType = originalValue.getType();
   UnrealizedConversionCastOp argCastOp;
-  Value argMat = buildUnresolvedMaterialization(
+  buildUnresolvedMaterialization(
       MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
-      /*inputs=*/replacements, originalType, /*originalType=*/Type(), converter,
-      &argCastOp);
+      /*inputs=*/replacements, originalType,
+      /*originalType=*/Type(), converter, &argCastOp);
   if (argCastOp)
     nTo1TempMaterializations.insert(argCastOp);
-
-  // Insert target materialization to the legalized type.
-  Type legalOutputType;
-  if (converter) {
-    legalOutputType = converter->convertType(originalType);
-  } else if (replacements.size() == 1) {
-    // When there is no type converter, assume that the replacement value
-    // types are legal. This is reasonable to assume because they were
-    // specified by the user.
-    // FIXME: This won't work for 1->N conversions because multiple output
-    // types are not supported in parts of the dialect conversion. In such a
-    // case, we currently use the original value type.
-    legalOutputType = replacements[0].getType();
-  }
-  if (legalOutputType && legalOutputType != originalType) {
-    UnrealizedConversionCastOp targetCastOp;
-    buildUnresolvedMaterialization(
-        MaterializationKind::Target, computeInsertPoint(argMat), loc,
-        /*valueToMap=*/argMat, /*inputs=*/argMat,
-        /*outputType=*/legalOutputType, /*originalType=*/originalType,
-        converter, &targetCastOp);
-    if (targetCastOp)
-      nTo1TempMaterializations.insert(targetCastOp);
-  }
 }
 
 Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
@@ -2864,6 +2843,9 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
 
 LogicalResult TypeConverter::convertType(Type t,
                                          SmallVectorImpl<Type> &results) const {
+  assert(this && "expected non-null type converter");
+  assert(t && "expected non-null type");
+
   {
     std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
                                                          std::defer_lock);
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d98a6a036e6b1f..2ca5f49637523f 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -124,10 +124,10 @@ func.func @no_remap_nested() {
   // CHECK-NEXT: "foo.region"
   // expected-remark at +1 {{op 'foo.region' is not legalizable}}
   "foo.region"() ({
-    // CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
-    ^bb0(%i0: i64, %unused: i16, %i1: i64):
-      // CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
-      "test.invalid"(%i0, %i1) : (i64, i64) -> ()
+    // CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64):
+    ^bb0(%i0: f64, %unused: i16, %i1: f64):
+      // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
+      "test.invalid"(%i0, %i1) : (f64, f64) -> ()
   }) : () -> ()
   // expected-remark at +1 {{op 'func.return' is not legalizable}}
   return
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index ce2820b80a945d..a470497fdbb560 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -985,8 +985,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
 };
 /// This pattern simply updates the operands of the given operation.
 struct TestPassthroughInvalidOp : public ConversionPattern {
-  TestPassthroughInvalidOp(MLIRContext *ctx)
-      : ConversionPattern("test.invalid", 1, ctx) {}
+  TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.invalid", 1, ctx) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -1307,19 +1307,19 @@ struct TestLegalizePatternDriver
     TestTypeConverter converter;
     mlir::RewritePatternSet patterns(&getContext());
     populateWithGenerated(patterns);
-    patterns.add<
-        TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
-        TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
-        TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
-        TestSplitReturnType, TestChangeProducerTypeI32ToF32,
-        TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
-        TestUpdateConsumerType, TestNonRootReplacement,
-        TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
-        TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
-        TestUndoPropertiesModification, TestEraseOp,
-        TestRepetitive1ToNConsumer>(&getContext());
-    patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
-        &getContext(), converter);
+    patterns
+        .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
+             TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
+             TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
+             TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+             TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+             TestNonRootReplacement, TestBoundedRecursiveRewrite,
+             TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
+             TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+             TestUndoPropertiesModification, TestEraseOp,
+             TestRepetitive1ToNConsumer>(&getContext());
+    patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
+                 TestPassthroughInvalidOp>(&getContext(), converter);
     patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1755,8 +1755,9 @@ struct TestTypeConversionAnotherProducer
 };
 
 struct TestReplaceWithLegalOp : public ConversionPattern {
-  TestReplaceWithLegalOp(MLIRContext *ctx)
-      : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
+  TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
+      : ConversionPattern(converter, "test.replace_with_legal_op",
+                          /*benefit=*/1, ctx) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -1878,12 +1879,12 @@ struct TestTypeConversionDriver
 
     // Initialize the set of rewrite patterns.
     RewritePatternSet patterns(&getContext());
-    patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
-                 TestSignatureConversionUndo,
-                 TestTestSignatureConversionNoConverter>(converter,
-                                                         &getContext());
-    patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
-        &getContext());
+    patterns
+        .add<TestTypeConsumerForward, TestTypeConversionProducer,
+             TestSignatureConversionUndo,
+             TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
+            converter, &getContext());
+    patterns.add<TestTypeConversionAnotherProducer>(&getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
 



More information about the Mlir-commits mailing list