[Mlir-commits] [mlir] 3cc311a - [mlir][Transforms] Dialect Conversion: No target mat. for 1:N replacement (#117513)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 23 04:27:43 PST 2024


Author: Matthias Springer
Date: 2024-12-23T13:27:39+01:00
New Revision: 3cc311ab8674eab6b9101cdf3823b55ea23d6535

URL: https://github.com/llvm/llvm-project/commit/3cc311ab8674eab6b9101cdf3823b55ea23d6535
DIFF: https://github.com/llvm/llvm-project/commit/3cc311ab8674eab6b9101cdf3823b55ea23d6535.diff

LOG: [mlir][Transforms] Dialect Conversion: No target mat. for 1:N replacement (#117513)

During a 1:N replacement (`applySignatureConversion` or
`replaceOpWithMultiple`), the dialect conversion driver used to insert
two materializations:

* Argument materialization: convert N replacement values to 1 SSA value
of the original type `S`.
* Target materialization: convert that SSA value of the original type `S`
to the legalized type `T`.

The target materialization is unnecessary. Subsequent patterns receive
the replacement values via their adaptors. These patterns have their own
type converter. When they see a replacement value of type `S`, they will
automatically insert a target materialization to type `T`. There is no
reason to do this eagerly during the 1:N replacement. (The functionality
used to be duplicated in `remapValues` and `insertNTo1Materialization`.)

Special case: If a subsequent pattern does not have a type converter, it
does *not* insert any target materializations. That's because the
absence of a type converter indicates that the pattern does not care
about type legality. Therefore, it is correct to pass an SSA value of
type `S` (or any other type) to the pattern.

Note: Most patterns in `TestPatterns.cpp` run without a type converter.
To make sure that the tests still behave the same, some of these
patterns now have a type converter.

This commit is in preparation for adding 1:N support to the conversion
value mapping. Before making any further changes to the mapping
infrastructure, I'd like to make sure that the surrounding code base (the
code that uses the mapping) is robust.

Added: 
    

Modified: 
    mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/Transforms/test-legalizer.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 59b0f5c9b09bcd..e2ab0ed6f66cc5 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,70 +153,112 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                        type.isVarArg());
   });
 
+  // Add generic source and target materializations to handle cases where
+  // non-LLVM types persist after an LLVM conversion.
+  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
+                               ValueRange inputs, Location loc) {
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  });
+  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
+                               ValueRange inputs, Location loc) {
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  });
+
   // Helper function that checks if the given value range is a bare pointer.
   auto isBarePointer = [](ValueRange values) {
     return values.size() == 1 &&
            isa<LLVM::LLVMPointerType>(values.front().getType());
   };
 
-  // Argument materializations convert from the new block argument types
-  // (multiple SSA values that make up a memref descriptor) back to the
-  // original block argument type. The dialect conversion framework will then
-  // insert a target materialization from the original block argument type to
-  // a legal type.
-  addArgumentMaterialization([&](OpBuilder &builder,
-                                 UnrankedMemRefType resultType,
-                                 ValueRange inputs, Location loc) {
+  // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
+  // must be passed explicitly.
+  auto packUnrankedMemRefDesc =
+      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
+          Location loc, LLVMTypeConverter &converter) -> Value {
     // Note: Bare pointers are not supported for unranked memrefs because a
     // memref descriptor cannot be built just from a bare pointer.
-    if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
+    if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
       return Value();
-    Value desc =
-        UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+    return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
+                                          inputs);
+  };
+
+  // MemRef descriptor elements -> UnrankedMemRefType
+  auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
+                                          UnrankedMemRefType resultType,
+                                          ValueRange inputs, Location loc) {
     // An argument materialization must return a value of type
     // `resultType`, so insert a cast from the memref descriptor type
     // (!llvm.struct) to the original memref type.
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
-        .getResult(0);
-  });
-  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
-                                 ValueRange inputs, Location loc) {
-    Value desc;
-    if (isBarePointer(inputs)) {
-      desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
-                                               inputs[0]);
-    } else if (TypeRange(inputs) ==
-               getMemRefDescriptorFields(resultType,
-                                         /*unpackAggregates=*/true)) {
-      desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
-    } else {
-      // The inputs are neither a bare pointer nor an unpacked memref
-      // descriptor. This materialization function cannot be used.
+    Value packed =
+        packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
+    if (!packed)
       return Value();
-    }
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+        .getResult(0);
+  };
+
+  // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
+  // must be passed explicitly.
+  auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
+                                  ValueRange inputs, Location loc,
+                                  LLVMTypeConverter &converter) -> Value {
+    assert(resultType && "expected non-null result type");
+    if (isBarePointer(inputs))
+      return MemRefDescriptor::fromStaticShape(builder, loc, converter,
+                                               resultType, inputs[0]);
+    if (TypeRange(inputs) ==
+        converter.getMemRefDescriptorFields(resultType,
+                                            /*unpackAggregates=*/true))
+      return MemRefDescriptor::pack(builder, loc, converter, resultType,
+                                    inputs);
+    // The inputs are neither a bare pointer nor an unpacked memref descriptor.
+    // This materialization function cannot be used.
+    return Value();
+  };
+
+  // MemRef descriptor elements -> MemRefType
+  auto rankedMemRefMaterialization = [&](OpBuilder &builder,
+                                         MemRefType resultType,
+                                         ValueRange inputs, Location loc) {
     // An argument materialization must return a value of type `resultType`,
     // so insert a cast from the memref descriptor type (!llvm.struct) to the
     // original memref type.
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
-        .getResult(0);
-  });
-  // Add generic source and target materializations to handle cases where
-  // non-LLVM types persist after an LLVM conversion.
-  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs, Location loc) {
-    if (inputs.size() != 1)
+    Value packed =
+        packRankedMemRefDesc(builder, resultType, inputs, loc, *this);
+    if (!packed)
       return Value();
-
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
         .getResult(0);
-  });
+  };
+
+  // Argument materializations convert from the new block argument types
+  // (multiple SSA values that make up a memref descriptor) back to the
+  // original block argument type.
+  addArgumentMaterialization(unrakedMemRefMaterialization);
+  addArgumentMaterialization(rankedMemRefMaterialization);
+  addSourceMaterialization(unrakedMemRefMaterialization);
+  addSourceMaterialization(rankedMemRefMaterialization);
+
+  // Bare pointer -> Packed MemRef descriptor
   addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs, Location loc) {
-    if (inputs.size() != 1)
+                               ValueRange inputs, Location loc,
+                               Type originalType) -> Value {
+    // The original MemRef type is required to build a MemRef descriptor
+    // because the sizes/strides of the MemRef cannot be inferred from just the
+    // bare pointer.
+    if (!originalType)
       return Value();
-
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
-        .getResult(0);
+    if (resultType != convertType(originalType))
+      return Value();
+    if (auto memrefType = dyn_cast<MemRefType>(originalType))
+      return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
+    if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
+      return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
+                                    *this);
+    return Value();
   });
 
   // Integer memory spaces map to themselves.

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1607740a1ee076..51686646a0a2fc 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -849,8 +849,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// function will be deleted when full 1:N support has been added.
   ///
   /// This function inserts an argument materialization back to the original
-  /// type, followed by a target materialization to the legalized type (if
-  /// applicable).
+  /// type.
   void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
                                  ValueRange replacements, Value originalValue,
                                  const TypeConverter *converter);
@@ -1376,9 +1375,13 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     // used as a replacement.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    insertNTo1Materialization(
-        OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-        /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+    if (replArgs.size() == 1) {
+      mapping.map(origArg, replArgs.front());
+    } else {
+      insertNTo1Materialization(
+          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
+          /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+    }
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }
 
@@ -1437,36 +1440,12 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
   // Insert argument materialization back to the original type.
   Type originalType = originalValue.getType();
   UnrealizedConversionCastOp argCastOp;
-  Value argMat = buildUnresolvedMaterialization(
+  buildUnresolvedMaterialization(
       MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
-      /*inputs=*/replacements, originalType, /*originalType=*/Type(), converter,
-      &argCastOp);
+      /*inputs=*/replacements, originalType,
+      /*originalType=*/Type(), converter, &argCastOp);
   if (argCastOp)
     nTo1TempMaterializations.insert(argCastOp);
-
-  // Insert target materialization to the legalized type.
-  Type legalOutputType;
-  if (converter) {
-    legalOutputType = converter->convertType(originalType);
-  } else if (replacements.size() == 1) {
-    // When there is no type converter, assume that the replacement value
-    // types are legal. This is reasonable to assume because they were
-    // specified by the user.
-    // FIXME: This won't work for 1->N conversions because multiple output
-    // types are not supported in parts of the dialect conversion. In such a
-    // case, we currently use the original value type.
-    legalOutputType = replacements[0].getType();
-  }
-  if (legalOutputType && legalOutputType != originalType) {
-    UnrealizedConversionCastOp targetCastOp;
-    buildUnresolvedMaterialization(
-        MaterializationKind::Target, computeInsertPoint(argMat), loc,
-        /*valueToMap=*/argMat, /*inputs=*/argMat,
-        /*outputType=*/legalOutputType, /*originalType=*/originalType,
-        converter, &targetCastOp);
-    if (targetCastOp)
-      nTo1TempMaterializations.insert(targetCastOp);
-  }
 }
 
 Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
@@ -2864,6 +2843,9 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
 
 LogicalResult TypeConverter::convertType(Type t,
                                          SmallVectorImpl<Type> &results) const {
+  assert(this && "expected non-null type converter");
+  assert(t && "expected non-null type");
+
   {
     std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
                                                          std::defer_lock);

diff  --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d98a6a036e6b1f..2ca5f49637523f 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -124,10 +124,10 @@ func.func @no_remap_nested() {
   // CHECK-NEXT: "foo.region"
   // expected-remark at +1 {{op 'foo.region' is not legalizable}}
   "foo.region"() ({
-    // CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
-    ^bb0(%i0: i64, %unused: i16, %i1: i64):
-      // CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
-      "test.invalid"(%i0, %i1) : (i64, i64) -> ()
+    // CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64):
+    ^bb0(%i0: f64, %unused: i16, %i1: f64):
+      // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
+      "test.invalid"(%i0, %i1) : (f64, f64) -> ()
   }) : () -> ()
   // expected-remark at +1 {{op 'func.return' is not legalizable}}
   return

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index ce2820b80a945d..a470497fdbb560 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -985,8 +985,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
 };
 /// This pattern simply updates the operands of the given operation.
 struct TestPassthroughInvalidOp : public ConversionPattern {
-  TestPassthroughInvalidOp(MLIRContext *ctx)
-      : ConversionPattern("test.invalid", 1, ctx) {}
+  TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.invalid", 1, ctx) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -1307,19 +1307,19 @@ struct TestLegalizePatternDriver
     TestTypeConverter converter;
     mlir::RewritePatternSet patterns(&getContext());
     populateWithGenerated(patterns);
-    patterns.add<
-        TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
-        TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
-        TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
-        TestSplitReturnType, TestChangeProducerTypeI32ToF32,
-        TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
-        TestUpdateConsumerType, TestNonRootReplacement,
-        TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
-        TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
-        TestUndoPropertiesModification, TestEraseOp,
-        TestRepetitive1ToNConsumer>(&getContext());
-    patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
-        &getContext(), converter);
+    patterns
+        .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
+             TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
+             TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
+             TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+             TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+             TestNonRootReplacement, TestBoundedRecursiveRewrite,
+             TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
+             TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+             TestUndoPropertiesModification, TestEraseOp,
+             TestRepetitive1ToNConsumer>(&getContext());
+    patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
+                 TestPassthroughInvalidOp>(&getContext(), converter);
     patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1755,8 +1755,9 @@ struct TestTypeConversionAnotherProducer
 };
 
 struct TestReplaceWithLegalOp : public ConversionPattern {
-  TestReplaceWithLegalOp(MLIRContext *ctx)
-      : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
+  TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
+      : ConversionPattern(converter, "test.replace_with_legal_op",
+                          /*benefit=*/1, ctx) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -1878,12 +1879,12 @@ struct TestTypeConversionDriver
 
     // Initialize the set of rewrite patterns.
     RewritePatternSet patterns(&getContext());
-    patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
-                 TestSignatureConversionUndo,
-                 TestTestSignatureConversionNoConverter>(converter,
-                                                         &getContext());
-    patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
-        &getContext());
+    patterns
+        .add<TestTypeConsumerForward, TestTypeConversionProducer,
+             TestSignatureConversionUndo,
+             TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
+            converter, &getContext());
+    patterns.add<TestTypeConversionAnotherProducer>(&getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
 


        


More information about the Mlir-commits mailing list