[Mlir-commits] [mlir] bcfdb3e - [mlir][transform] Add apply_conversion_patterns op

Matthias Springer llvmlistbot at llvm.org
Sun Aug 6 23:57:53 PDT 2023


Author: Matthias Springer
Date: 2023-08-07T08:49:55+02:00
New Revision: bcfdb3e4bc819c50c32c61070c5a1a86df808e49

URL: https://github.com/llvm/llvm-project/commit/bcfdb3e4bc819c50c32c61070c5a1a86df808e49
DIFF: https://github.com/llvm/llvm-project/commit/bcfdb3e4bc819c50c32c61070c5a1a86df808e49.diff

LOG: [mlir][transform] Add apply_conversion_patterns op

This transform op applies a dialect conversion to the targeted ops. Its design is similar to `apply_patterns`.

Patterns are specified in the first region of `apply_conversion_patterns`. They must implement the `ConversionPatternDescriptorOpInterface`. Regular rewrite patterns and dialect conversion patterns should not be mixed, so the interface is separate from the `PatternDescriptorOpInterface`.

The type converter is specified as the single op of the second region. It is optional; if no type converter is specified, it is expected that pattern descriptors provide their own type converters. If both the pattern descriptors and the `apply_conversion_patterns` op specify a type converter, the type converter of the pattern descriptor is used.

Differential Revision: https://reviews.llvm.org/D157109

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/test-pattern-application.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 472f642cfa05cf..114d79555dcef5 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -16,6 +16,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
 namespace transform {

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index eaab05766fc455..5ce21be223bcb2 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -240,9 +240,13 @@ def FindPayloadReplacementOpInterface
 
 def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
   let description = [{
-    This interface should be implemented by ops that select patterns of a
-    `transform.apply_patterns` op. It provides a method to populate a rewrite
+    This interface should be implemented by ops that select rewrite patterns of
+    a `transform.apply_patterns` op. It provides a method to populate a rewrite
     pattern set with patterns.
+
+    Note: Conversion patterns are rewrite patterns in MLIR, but they should not
+    be populated with `PatternDescriptorOpInterface` because they cannot be
+    used in a greedy pattern rewrite.
   }];
 
   let cppNamespace = "::mlir::transform";
@@ -250,11 +254,73 @@ def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
   let methods = [
     InterfaceMethod<
       /*desc=*/[{
-        Populate patterns into the given pattern set.
+        Populate rewrite patterns into the given pattern set.
       }],
       /*returnType=*/"void",
       /*name=*/"populatePatterns",
-      /*arguments=*/(ins "RewritePatternSet &":$patterns)
+      /*arguments=*/(ins "::mlir::RewritePatternSet &":$patterns)
+    >,
+  ];
+}
+
+def ConversionPatternDescriptorOpInterface
+    : OpInterface<"ConversionPatternDescriptorOpInterface"> {
+  let description = [{
+    This interface should be implemented by ops that select conversion patterns
+    of a `transform.apply_conversion_patterns` op. It provides a method to
+    populate a rewrite pattern set with conversion patterns.
+
+    Note: Non-conversion rewrite patterns should not be populated with
+    `ConversionPatternDescriptorOpInterface` because it is not generally safe
+    to use non-conversion rewrite patterns as part of a dialect conversion.
+  }];
+
+  let cppNamespace = "::mlir::transform";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate conversion patterns into the given pattern set with the
+        given type converter.
+      }],
+      /*returnType=*/"void",
+      /*name=*/"populatePatterns",
+      /*arguments=*/(ins "::mlir::TypeConverter &":$typeConverter,
+                         "::mlir::RewritePatternSet &":$patterns)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the type converter to be used with this pattern set. If this
+        method returns nullptr (default implementation), the default type
+        converter of the enclosing "apply_conversion_patterns" op is used.
+      }],
+      /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
+      /*name=*/"getTypeConverter",
+      /*arguments=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/"return nullptr;"
+    >,
+  ];
+}
+
+def TypeConverterBuilderOpInterface
+    : OpInterface<"TypeConverterBuilderOpInterface"> {
+  let description = [{
+    This interface should be implemented by ops that specify a type converter
+    for a dialect conversion. Such ops can be used with
+    "apply_conversion_patterns".
+  }];
+
+  let cppNamespace = "::mlir::transform";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the type converter to be used with a dialect conversion.
+      }],
+      /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
+      /*name=*/"getTypeConverter",
+      /*arguments=*/(ins)
     >,
   ];
 }

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 8a30205ee17680..c41c6d7768ac7f 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -156,6 +156,84 @@ def ApplyCommonSubexpressionEliminationOp : TransformDialectOp<"apply_cse",
   }];
 }
 
+def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]
+        # GraphRegionNoTerminator.traits> {
+  let summary = "Applies conversion patterns to the body of the targeted op";
+  let description = [{
+    This transform applies the specified conversion patterns to the targeted op
+    and all nested ops. By default, this transform applies a "full" dialect
+    conversion. If the `partial_conversion` unit attribute is present, this
+    transform applies a partial dialect conversion.
+
+    The patterns that should be applied are specified in the first graph region
+    of this op. They must implement the
+    `ConversionPatternDescriptorOpInterface`. The order in which patterns are
+    applied is unspecified; i.e., the ordering of ops in the region of this op
+    is irrelevant.
+
+    The second, optional graph region contains exactly one op that specifies
+    the default type converter for this dialect conversion. If provided, this
+    op must implement the `TypeConverterBuilderOpInterface`.
+    Type converters are a property of conversion patterns: each conversion
+    pattern stores the type converter that should be used in its C++ class. Each
+    conversion pattern descriptor can optionally specify a type converter in its
+    `getTypeConverter` interface method. If no type converter is specified in
+    this method, the default type converter of the dialect conversion is used.
+    Default type converters are useful if the same type converter should be used
+    for multiple sets of conversion patterns. (Patterns that should not use this
+    default type converter specify their own type converter.)
+
+    The `legal_ops`, `illegal_ops`, `legal_dialects`, `illegal_dialects`
+    attributes specify the conversion target. At least one of those four
+    attributes must be specified.
+
+    This transform consumes the `target` handle and modifies the payload. It
+    does not produce any handles.
+
+    This transform emits a silenceable failure if the dialect conversion fails.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       OptionalAttr<StrArrayAttr>:$legal_ops,
+                       OptionalAttr<StrArrayAttr>:$illegal_ops,
+                       OptionalAttr<StrArrayAttr>:$legal_dialects,
+                       OptionalAttr<StrArrayAttr>:$illegal_dialects,
+                       UnitAttr:$partial_conversion);
+  let results = (outs);
+  let regions = (region VariadicRegion<MaxSizedRegion<1>>:$regions);
+
+  let assemblyFormat = [{
+    `to` $target $regions attr-dict `:` type($target)
+  }];
+  let hasVerifier = 1;
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins
+        "Value":$target,
+        CArg<"function_ref<void(OpBuilder &, Location)>", "nullptr">:
+            $patternsBodyBuilder,
+        CArg<"function_ref<void(OpBuilder &, Location)>", "nullptr">:
+            $typeConverterBodyBuilder)>,
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::Region &getPatterns() {
+      return getRegion(0);
+    }
+
+    ::mlir::transform::TypeConverterBuilderOpInterface getDefaultTypeConverter() {
+      if (getNumRegions() < 2)
+        return {};
+      return ::llvm::cast<::mlir::transform::TypeConverterBuilderOpInterface>(
+          &getRegion(1).front().front());
+    }
+  }];
+}
+
 def ApplyDeadCodeEliminationOp : TransformDialectOp<"apply_dce",
     [TransformOpInterface, TransformEachOpTrait,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index d7205ec02690e8..32c56e903268f7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -28,10 +28,16 @@ void transform::detail::checkImplementsTransformOpInterface(
       *RegisteredOperationName::lookup(name, context);
   assert((opName.hasInterface<TransformOpInterface>() ||
           opName.hasInterface<PatternDescriptorOpInterface>() ||
+          opName.hasInterface<ConversionPatternDescriptorOpInterface>() ||
+          opName.hasInterface<TypeConverterBuilderOpInterface>() ||
           opName.hasTrait<OpTrait::IsTerminator>()) &&
          "non-terminator ops injected into the transform dialect must "
-         "implement TransformOpInterface or PatternDescriptorOpInterface");
-  if (!opName.hasInterface<PatternDescriptorOpInterface>()) {
+         "implement TransformOpInterface, PatternDescriptorOpInterface, "
+         "ConversionPatternDescriptorOpInterface or "
+         "TypeConverterBuilderOpInterface");
+  if (!opName.hasInterface<PatternDescriptorOpInterface>() &&
+      !opName.hasInterface<ConversionPatternDescriptorOpInterface>() &&
+      !opName.hasInterface<TypeConverterBuilderOpInterface>()) {
     assert(opName.hasInterface<MemoryEffectOpInterface>() &&
            "ops injected into the transform dialect must implement "
            "MemoryEffectsOpInterface");

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 5327a5f7f2524d..e6baf470199837 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Transforms/CSE.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
 #include "llvm/ADT/STLExtras.h"
@@ -478,6 +479,171 @@ void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
     op.getCanonicalizationPatterns(patterns, ctx);
 }
 
+//===----------------------------------------------------------------------===//
+// ApplyConversionPatternsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &results, transform::TransformState &state) {
+  MLIRContext *ctx = getContext();
+
+  // Default type converter is built on demand.
+  std::unique_ptr<TypeConverter> defaultTypeConverter;
+
+  // Conversion patterns only keep a pointer to their type converter. Type
+  // converters provided by pattern descriptors are therefore owned here, so
+  // that they outlive the dialect conversion below.
+  SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
+
+  // Configure conversion target.
+  ConversionTarget conversionTarget(*ctx);
+  if (getLegalOps())
+    for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
+      conversionTarget.addLegalOp(
+          OperationName(cast<StringAttr>(attr).getValue(), ctx));
+  if (getIllegalOps())
+    for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
+      conversionTarget.addIllegalOp(
+          OperationName(cast<StringAttr>(attr).getValue(), ctx));
+  if (getLegalDialects())
+    for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
+      conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
+  if (getIllegalDialects())
+    for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
+      conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
+
+  // Gather all specified patterns.
+  RewritePatternSet patterns(ctx);
+  if (!getPatterns().empty()) {
+    for (Operation &op : getPatterns().front()) {
+      auto descriptor =
+          cast<transform::ConversionPatternDescriptorOpInterface>(&op);
+
+      // Check if this pattern set specifies a type converter.
+      std::unique_ptr<TypeConverter> typeConverter =
+          descriptor.getTypeConverter();
+      TypeConverter *converter = nullptr;
+      if (typeConverter) {
+        // Transfer ownership: the patterns populated below reference this
+        // converter, so it must not be destroyed at the end of this iteration.
+        keepAliveConverters.push_back(std::move(typeConverter));
+        converter = keepAliveConverters.back().get();
+      } else {
+        // No type converter specified: Use the default type converter.
+        if (!defaultTypeConverter) {
+          // Instantiate the default type converter.
+          transform::TypeConverterBuilderOpInterface typeConverterBuilder =
+              getDefaultTypeConverter();
+          if (!typeConverterBuilder) {
+            auto diag = emitDefiniteFailure()
+                        << "pattern descriptor does not specify type "
+                           "converter and apply_conversion_patterns op has "
+                           "no default type converter";
+            diag.attachNote(op.getLoc()) << "pattern descriptor op";
+            return diag;
+          }
+          defaultTypeConverter = typeConverterBuilder.getTypeConverter();
+          assert(defaultTypeConverter && "expected type converter");
+        }
+        converter = defaultTypeConverter.get();
+      }
+      descriptor.populatePatterns(*converter, patterns);
+    }
+  }
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    // Make sure that this transform is not applied to itself. Modifying the
+    // transform IR while it is being interpreted is generally dangerous.
+    DiagnosedSilenceableFailure payloadCheck =
+        ensurePayloadIsSeparateFromTransform(*this, target);
+    if (!payloadCheck.succeeded())
+      return payloadCheck;
+
+    LogicalResult status = failure();
+    if (getPartialConversion()) {
+      status = applyPartialConversion(target, conversionTarget, frozenPatterns);
+    } else {
+      status = applyFullConversion(target, conversionTarget, frozenPatterns);
+    }
+
+    if (failed(status)) {
+      auto diag = emitSilenceableError() << "dialect conversion failed";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::ApplyConversionPatternsOp::verify() {
+  if (getNumRegions() != 1 && getNumRegions() != 2)
+    return emitOpError() << "expected 1 or 2 regions";
+  if (!getPatterns().empty()) {
+    for (Operation &op : getPatterns().front()) {
+      if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
+        InFlightDiagnostic diag =
+            emitOpError() << "expected pattern children ops to implement "
+                             "ConversionPatternDescriptorOpInterface";
+        diag.attachNote(op.getLoc()) << "op without interface";
+        return diag;
+      }
+    }
+  }
+  if (getNumRegions() == 2) {
+    Region &typeConverterRegion = getRegion(1);
+    // Guard against a region without a block before accessing its body.
+    if (typeConverterRegion.empty() ||
+        !llvm::hasSingleElement(typeConverterRegion.front()))
+      return emitOpError()
+             << "expected exactly one op in default type converter region";
+    Operation *typeConverterOp = &typeConverterRegion.front().front();
+    if (!isa<transform::TypeConverterBuilderOpInterface>(typeConverterOp)) {
+      InFlightDiagnostic diag = emitOpError()
+                                << "expected default converter child op to "
+                                   "implement TypeConverterBuilderOpInterface";
+      diag.attachNote(typeConverterOp->getLoc()) << "op without interface";
+      return diag;
+    }
+  }
+  if (!getLegalOps() && !getIllegalOps() && !getLegalDialects() &&
+      !getIllegalDialects())
+    return emitOpError() << "conversion target is not specified";
+  return success();
+}
+
+void transform::ApplyConversionPatternsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getTarget(), effects);
+  transform::modifiesPayload(effects);
+}
+
+void transform::ApplyConversionPatternsOp::build(
+    OpBuilder &builder, OperationState &result, Value target,
+    function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
+    function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
+  result.addOperands(target);
+
+  // The patterns region is always present.
+  {
+    OpBuilder::InsertionGuard g(builder);
+    Region *region1 = result.addRegion();
+    builder.createBlock(region1);
+    if (patternsBodyBuilder)
+      patternsBodyBuilder(builder, result.location);
+  }
+  // The default type converter region is optional. Only create it when a
+  // body builder is given: an empty second region would fail verification.
+  if (typeConverterBodyBuilder) {
+    OpBuilder::InsertionGuard g(builder);
+    Region *region2 = result.addRegion();
+    builder.createBlock(region2);
+    typeConverterBodyBuilder(builder, result.location);
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // ApplyLoopInvariantCodeMotionOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index a9a5e43cc06774..8ac6d4ef3b9778 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -231,3 +231,51 @@ transform.sequence failures(propagate) {
     transform.apply_patterns.canonicalization
   } {apply_cse} : !transform.any_op
 }
+
+// -----
+
+// CHECK-LABEL: func @full_dialect_conversion
+//  CHECK-NEXT:   %[[m:.*]] = "test.new_op"() : () -> memref<5xf32>
+//  CHECK-NEXT:   %[[cast:.*]] = builtin.unrealized_conversion_cast %[[m]] : memref<5xf32> to tensor<5xf32>
+//  CHECK-NEXT:   return %[[cast]]
+func.func @full_dialect_conversion() -> tensor<5xf32> {
+  %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (tensor<5xf32>)
+  return %0 : tensor<5xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_conversion_patterns to %0 {
+    transform.apply_conversion_patterns.transform.test_conversion_patterns
+  }, {
+    transform.apply_conversion_patterns.transform.test_type_converter
+  } {illegal_ops = ["test.foo"],
+     legal_ops = ["func.func", "func.return", "test.new_op"]}
+      : !transform.any_op
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  // expected-error @below{{conversion target is not specified}}
+  transform.apply_conversion_patterns to %0 {
+    transform.apply_conversion_patterns.transform.test_conversion_patterns
+  }, {
+    transform.apply_conversion_patterns.transform.test_type_converter
+  } : !transform.any_op
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  // expected-error @below{{pattern descriptor does not specify type converter and apply_conversion_patterns op has no default type converter}}
+  transform.apply_conversion_patterns to %0 {
+    // expected-note @below{{pattern descriptor op}}
+    transform.apply_conversion_patterns.transform.test_conversion_patterns
+  } {illegal_ops = ["test.foo"]} : !transform.any_op
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index a30f92e9f56532..ac7e186843e24e 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -915,6 +915,65 @@ void mlir::test::TestProduceInvalidIR::getEffects(
   transform::modifiesPayload(effects);
 }
 
+namespace {
+/// Test conversion pattern that replaces ops carrying a "replace_with_new_op"
+/// attribute (its value is ignored) with "test.new_op".
+class ReplaceWithNewOpConversion : public ConversionPattern {
+public:
+  ReplaceWithNewOpConversion(TypeConverter &typeConverter, MLIRContext *context)
+      : ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(),
+                          /*benefit=*/1, context) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op->hasAttr("replace_with_new_op"))
+      return failure();
+    SmallVector<Type> newResultTypes;
+    if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
+                                                newResultTypes)))
+      return failure();
+    Operation *newOp = rewriter.create(
+        op->getLoc(),
+        OperationName("test.new_op", op->getContext()).getIdentifier(),
+        operands, newResultTypes);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
+} // namespace
+
+void mlir::test::ApplyTestConversionPatternsOp::populatePatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  patterns.insert<ReplaceWithNewOpConversion>(typeConverter,
+                                              patterns.getContext());
+}
+
+namespace {
+/// Test type converter that converts ranked tensor types to memref types.
+class TestTypeConverter : public TypeConverter {
+public:
+  TestTypeConverter() {
+    addConversion([](RankedTensorType type) -> Type {
+      return MemRefType::get(type.getShape(), type.getElementType());
+    });
+    addSourceMaterialization([](OpBuilder &builder, Type resultType,
+                                ValueRange inputs,
+                                Location loc) -> std::optional<Value> {
+      if (inputs.size() != 1)
+        return std::nullopt;
+      return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+          .getResult(0);
+    });
+  }
+};
+} // namespace
+
+std::unique_ptr<::mlir::TypeConverter>
+mlir::test::TestTypeConverterOp::getTypeConverter() {
+  return std::make_unique<TestTypeConverter>();
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index b11c59209a9732..41f318db68405b 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -554,6 +554,24 @@ def ApplyTestPatternsOp
   let cppNamespace = "::mlir::test";
 }
 
+def ApplyTestConversionPatternsOp // Populates ReplaceWithNewOpConversion.
+  : Op<Transform_Dialect, "apply_conversion_patterns.transform.test_conversion_patterns",
+      [DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface>]> {
+  let arguments = (ins);
+  let results = (outs);
+  let assemblyFormat = "attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
+def TestTypeConverterOp // Builds TestTypeConverter (ranked tensor -> memref).
+  : Op<Transform_Dialect, "apply_conversion_patterns.transform.test_type_converter",
+      [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface>]> {
+  let arguments = (ins);
+  let results = (outs);
+  let assemblyFormat = "attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
 def TestReEnterRegionOp
   : Op<Transform_Dialect, "test_re_enter_region",
        [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,


        


More information about the Mlir-commits mailing list