[Mlir-commits] [mlir] bcfdb3e - [mlir][transform] Add apply_conversion_patterns op
Matthias Springer
llvmlistbot at llvm.org
Sun Aug 6 23:57:53 PDT 2023
Author: Matthias Springer
Date: 2023-08-07T08:49:55+02:00
New Revision: bcfdb3e4bc819c50c32c61070c5a1a86df808e49
URL: https://github.com/llvm/llvm-project/commit/bcfdb3e4bc819c50c32c61070c5a1a86df808e49
DIFF: https://github.com/llvm/llvm-project/commit/bcfdb3e4bc819c50c32c61070c5a1a86df808e49.diff
LOG: [mlir][transform] Add apply_conversion_patterns op
This transform op applies a dialect conversion to the targeted ops. Its design is similar to `apply_patterns`.
Patterns are specified in the first region of `apply_conversion_patterns`. They must implement the `ConversionPatternDescriptorOpInterface`. Regular rewrite patterns and dialect conversion patterns should not be mixed, so the interface is separate from the `PatternDescriptorOpInterface`.
The type converter is specified as the single op of the second region. It is optional; if no type converter is specified, it is expected that pattern descriptors provide their own type converters. If both the pattern descriptors and the `apply_conversion_patterns` op specify a type converter, the type converter of the pattern descriptor is used.
Differential Revision: https://reviews.llvm.org/D157109
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-pattern-application.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 472f642cfa05cf..114d79555dcef5 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -16,6 +16,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace transform {
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index eaab05766fc455..5ce21be223bcb2 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -240,9 +240,13 @@ def FindPayloadReplacementOpInterface
+// Pattern descriptors for the greedy `transform.apply_patterns` driver.
+// Descriptors of dialect conversion patterns implement
+// `ConversionPatternDescriptorOpInterface` instead.
def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
let description = [{
- This interface should be implemented by ops that select patterns of a
- `transform.apply_patterns` op. It provides a method to populate a rewrite
+ This interface should be implemented by ops that select rewrite patterns of
+ a `transform.apply_patterns` op. It provides a method to populate a rewrite
pattern set with patterns.
+
+ Note: Conversion patterns are rewrite patterns in MLIR, but they should not
+ be populated with `PatternDescriptorOpInterface` because they cannot be
+ used in a greedy pattern rewrite.
}];
let cppNamespace = "::mlir::transform";
@@ -250,11 +254,73 @@ def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
let methods = [
InterfaceMethod<
/*desc=*/[{
- Populate patterns into the given pattern set.
+ Populate rewrite patterns into the given pattern set.
}],
/*returnType=*/"void",
/*name=*/"populatePatterns",
- /*arguments=*/(ins "RewritePatternSet &":$patterns)
+ // Fully qualified, consistent with the other interface methods in this file.
+ /*arguments=*/(ins "::mlir::RewritePatternSet &":$patterns)
+ >,
+ ];
+}
+
+def ConversionPatternDescriptorOpInterface
+    : OpInterface<"ConversionPatternDescriptorOpInterface"> {
+  let description = [{
+    This interface should be implemented by ops that select conversion patterns
+    of a `transform.apply_conversion_patterns` op. It provides a method to
+    populate a rewrite pattern set with conversion patterns.
+
+    Note: Non-conversion rewrite patterns should not be populated with
+    `ConversionPatternDescriptorOpInterface` because it is not generally safe
+    to use non-conversion rewrite patterns as part of a dialect conversion.
+  }];
+
+  let cppNamespace = "::mlir::transform";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate conversion patterns into the given pattern set with the
+        given type converter. Conversion patterns keep a reference to the
+        type converter, so the type converter must outlive the pattern set.
+      }],
+      /*returnType=*/"void",
+      /*name=*/"populatePatterns",
+      /*arguments=*/(ins "::mlir::TypeConverter &":$typeConverter,
+                         "::mlir::RewritePatternSet &":$patterns)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the type converter to be used with this pattern set. If this
+        method returns `nullptr` (the default implementation), the default
+        type converter of the enclosing `transform.apply_conversion_patterns`
+        op is used.
+      }],
+      /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
+      /*name=*/"getTypeConverter",
+      /*arguments=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/"return nullptr;"
+    >,
+  ];
+}
+
+def TypeConverterBuilderOpInterface
+    : OpInterface<"TypeConverterBuilderOpInterface"> {
+  let description = [{
+    This interface should be implemented by ops that specify a type converter
+    for a dialect conversion. Such ops can be used as the default type
+    converter of `transform.apply_conversion_patterns`.
+  }];
+
+  let cppNamespace = "::mlir::transform";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the type converter to be used with a dialect conversion.
+        Ownership of the returned type converter is transferred to the caller.
+        The result must not be `nullptr`.
+      }],
+      /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
+      /*name=*/"getTypeConverter",
+      /*arguments=*/(ins)
>,
];
}
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 8a30205ee17680..c41c6d7768ac7f 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -156,6 +156,84 @@ def ApplyCommonSubexpressionEliminationOp : TransformDialectOp<"apply_cse",
}];
}
+def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]
+        # GraphRegionNoTerminator.traits> {
+  let summary = "Applies conversion patterns to the body of the targeted op";
+  let description = [{
+    This transform applies the specified conversion patterns to the targeted op
+    and all nested ops. By default, this transform applies a "full" dialect
+    conversion. If the `partial_conversion` unit attribute is present, this
+    transform applies a partial dialect conversion.
+
+    The patterns that should be applied are specified in the first graph region
+    of this op. They must implement the
+    `ConversionPatternDescriptorOpInterface`. The order in which patterns are
+    applied is unspecified; i.e., the ordering of ops in the region of this op
+    is irrelevant.
+
+    The second, optional graph region contains exactly one op that specifies
+    the default type converter that should be used with this dialect
+    conversion. If provided, this op must implement the
+    `TypeConverterBuilderOpInterface`. Type converters are a property of
+    conversion patterns: each conversion pattern stores the type converter that
+    should be used in its C++ class. Each conversion pattern descriptor can
+    optionally specify a type converter in its `getTypeConverter` interface
+    method. If no type converter is specified in this method, the default type
+    converter of the dialect conversion is used. Default type converters are
+    useful if the same type converter should be used for multiple sets of
+    conversion patterns. (Patterns that should not use this default type
+    converter specify their own type converter.) If a pattern descriptor
+    relies on the default type converter but none is provided, this transform
+    produces a definite failure.
+
+    The `legal_ops`, `illegal_ops`, `legal_dialects`, `illegal_dialects`
+    attributes specify the conversion target. At least one of those four
+    attributes must be specified.
+
+    This transform consumes the `target` handle and modifies the payload. It
+    does not produce any handles.
+
+    This transform produces a silenceable failure if the dialect conversion was
+    unsuccessful.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       OptionalAttr<StrArrayAttr>:$legal_ops,
+                       OptionalAttr<StrArrayAttr>:$illegal_ops,
+                       OptionalAttr<StrArrayAttr>:$legal_dialects,
+                       OptionalAttr<StrArrayAttr>:$illegal_dialects,
+                       // snake_case like the other attributes and as named in
+                       // the description; the C++ accessor is still
+                       // `getPartialConversion()`.
+                       UnitAttr:$partial_conversion);
+  let results = (outs);
+  let regions = (region VariadicRegion<MaxSizedRegion<1>>:$regions);
+
+  let assemblyFormat = [{
+    `to` $target $regions attr-dict `:` type($target)
+  }];
+  let hasVerifier = 1;
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins
+        "Value":$target,
+        CArg<"function_ref<void(OpBuilder &, Location)>", "nullptr">:
+            $patternsBodyBuilder,
+        CArg<"function_ref<void(OpBuilder &, Location)>", "nullptr">:
+            $typeConverterBodyBuilder)>,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return the region that holds the conversion pattern descriptors.
+    ::mlir::Region &getPatterns() {
+      return getRegion(0);
+    }
+
+    /// Return the op of the optional default type converter region, or a null
+    /// interface if there is no such region or the region has no op.
+    ::mlir::transform::TypeConverterBuilderOpInterface getDefaultTypeConverter() {
+      if (getNumRegions() < 2)
+        return {};
+      ::mlir::Region &region = getRegion(1);
+      if (region.empty() || region.front().empty())
+        return {};
+      return ::llvm::cast<::mlir::transform::TypeConverterBuilderOpInterface>(
+          &region.front().front());
+    }
+  }];
+}
+
def ApplyDeadCodeEliminationOp : TransformDialectOp<"apply_dce",
[TransformOpInterface, TransformEachOpTrait,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index d7205ec02690e8..32c56e903268f7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -28,10 +28,15 @@ void transform::detail::checkImplementsTransformOpInterface(
*RegisteredOperationName::lookup(name, context);
assert((opName.hasInterface<TransformOpInterface>() ||
opName.hasInterface<PatternDescriptorOpInterface>() ||
+ opName.hasInterface<ConversionPatternDescriptorOpInterface>() ||
+ opName.hasInterface<TypeConverterBuilderOpInterface>() ||
opName.hasTrait<OpTrait::IsTerminator>()) &&
"non-terminator ops injected into the transform dialect must "
- "implement TransformOpInterface or PatternDescriptorOpInterface");
- if (!opName.hasInterface<PatternDescriptorOpInterface>()) {
+ "implement TransformOpInterface or PatternDescriptorOpInterface or "
+ "ConversionPatternDescriptorOpInterface");
+ if (!opName.hasInterface<PatternDescriptorOpInterface>() &&
+ !opName.hasInterface<ConversionPatternDescriptorOpInterface>() &&
+ !opName.hasInterface<TypeConverterBuilderOpInterface>()) {
assert(opName.hasInterface<MemoryEffectOpInterface>() &&
"ops injected into the transform dialect must implement "
"MemoryEffectsOpInterface");
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 5327a5f7f2524d..e6baf470199837 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -23,6 +23,7 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/CSE.h"
+#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "llvm/ADT/STLExtras.h"
@@ -478,6 +479,159 @@ void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
op.getCanonicalizationPatterns(patterns, ctx);
}
+//===----------------------------------------------------------------------===//
+// ApplyConversionPatternsOp
+//===----------------------------------------------------------------------===//
+
+/// Applies the conversion patterns described in the first region to every
+/// payload op associated with `target`. The conversion target is assembled
+/// from the legal/illegal op and dialect attributes. Pattern descriptors that
+/// do not provide their own type converter use the default type converter
+/// built from the second region; a missing default type converter is a
+/// definite failure. A failed dialect conversion is a silenceable failure.
+DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &results, transform::TransformState &state) {
+  MLIRContext *ctx = getContext();
+
+  // Default type converter is built on demand.
+  std::unique_ptr<TypeConverter> defaultTypeConverter;
+
+  // Configure conversion target.
+  ConversionTarget conversionTarget(*ctx);
+  if (getLegalOps())
+    for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
+      conversionTarget.addLegalOp(
+          OperationName(cast<StringAttr>(attr).getValue(), ctx));
+  if (getIllegalOps())
+    for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
+      conversionTarget.addIllegalOp(
+          OperationName(cast<StringAttr>(attr).getValue(), ctx));
+  if (getLegalDialects())
+    for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
+      conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
+  if (getIllegalDialects())
+    for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
+      conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
+
+  // Gather all specified patterns.
+  RewritePatternSet patterns(ctx);
+  // Conversion patterns keep a reference to the type converter they were
+  // created with. Type converters returned by pattern descriptors must
+  // therefore stay alive until the conversion below has been applied;
+  // destroying them at the end of each loop iteration would leave the
+  // populated patterns referring to freed objects.
+  SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
+  if (!getPatterns().empty()) {
+    for (Operation &op : getPatterns().front()) {
+      auto descriptor =
+          cast<transform::ConversionPatternDescriptorOpInterface>(&op);
+
+      // Check if this pattern set specifies a type converter.
+      std::unique_ptr<TypeConverter> typeConverter =
+          descriptor.getTypeConverter();
+      TypeConverter *converter = nullptr;
+      if (typeConverter) {
+        keepAliveConverters.push_back(std::move(typeConverter));
+        converter = keepAliveConverters.back().get();
+      } else {
+        // No type converter specified: Use the default type converter.
+        if (!defaultTypeConverter) {
+          // Instantiate the default type converter.
+          transform::TypeConverterBuilderOpInterface typeConverterBuilder =
+              getDefaultTypeConverter();
+          if (!typeConverterBuilder) {
+            auto diag = emitDefiniteFailure()
+                        << "pattern descriptor does not specify type "
+                           "converter and apply_conversion_patterns op has "
+                           "no default type converter";
+            diag.attachNote(op.getLoc()) << "pattern descriptor op";
+            return diag;
+          }
+          defaultTypeConverter = typeConverterBuilder.getTypeConverter();
+          assert(defaultTypeConverter && "expected type converter");
+        }
+        converter = defaultTypeConverter.get();
+      }
+      descriptor.populatePatterns(*converter, patterns);
+    }
+  }
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    // Make sure that this transform is not applied to itself. Modifying the
+    // transform IR while it is being interpreted is generally dangerous.
+    DiagnosedSilenceableFailure payloadCheck =
+        ensurePayloadIsSeparateFromTransform(*this, target);
+    if (!payloadCheck.succeeded())
+      return payloadCheck;
+
+    LogicalResult status = failure();
+    if (getPartialConversion()) {
+      status = applyPartialConversion(target, conversionTarget, frozenPatterns);
+    } else {
+      status = applyFullConversion(target, conversionTarget, frozenPatterns);
+    }
+
+    if (failed(status)) {
+      auto diag = emitSilenceableError() << "dialect conversion failed";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Verifies the region structure and the conversion target:
+///  - there are 1 or 2 regions;
+///  - every op in the patterns region implements
+///    ConversionPatternDescriptorOpInterface;
+///  - the optional second region holds exactly one op implementing
+///    TypeConverterBuilderOpInterface;
+///  - at least one legal/illegal op/dialect attribute is set.
+LogicalResult transform::ApplyConversionPatternsOp::verify() {
+  if (getNumRegions() != 1 && getNumRegions() != 2)
+    return emitOpError() << "expected 1 or 2 regions";
+  if (!getPatterns().empty()) {
+    for (Operation &op : getPatterns().front()) {
+      if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
+        InFlightDiagnostic diag =
+            emitOpError() << "expected pattern children ops to implement "
+                             "ConversionPatternDescriptorOpInterface";
+        diag.attachNote(op.getLoc()) << "op without interface";
+        return diag;
+      }
+    }
+  }
+  if (getNumRegions() == 2) {
+    Region &typeConverterRegion = getRegion(1);
+    // A region without any block (e.g., one created programmatically) has no
+    // entry block; reject it instead of dereferencing `front()`.
+    if (typeConverterRegion.empty() ||
+        !llvm::hasSingleElement(typeConverterRegion.front()))
+      return emitOpError()
+             << "expected exactly one op in default type converter region";
+    Operation *typeConverterOp = &typeConverterRegion.front().front();
+    if (!isa<transform::TypeConverterBuilderOpInterface>(typeConverterOp)) {
+      InFlightDiagnostic diag = emitOpError()
+                                << "expected default converter child op to "
+                                   "implement TypeConverterBuilderOpInterface";
+      diag.attachNote(typeConverterOp->getLoc()) << "op without interface";
+      return diag;
+    }
+  }
+  if (!getLegalOps() && !getIllegalOps() && !getLegalDialects() &&
+      !getIllegalDialects())
+    return emitOpError() << "conversion target is not specified";
+  return success();
+}
+
+/// Side effects of this op: the `target` handle is consumed and the payload IR
+/// is modified. No new handles are produced.
+void transform::ApplyConversionPatternsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // Both helpers live in the transform namespace, which encloses this class.
+  consumesHandle(getTarget(), effects);
+  modifiesPayload(effects);
+}
+
+/// Builds an `apply_conversion_patterns` op on `target`.
+///
+/// The first region (pattern descriptors) is always created with an entry
+/// block and is populated by `patternsBodyBuilder` if one is provided. The
+/// second region (default type converter) is optional. It is only created
+/// when `typeConverterBodyBuilder` is provided, because a default type
+/// converter region without exactly one op does not verify.
+void transform::ApplyConversionPatternsOp::build(
+    OpBuilder &builder, OperationState &result, Value target,
+    function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
+    function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
+  result.addOperands(target);
+
+  {
+    OpBuilder::InsertionGuard g(builder);
+    Region *region1 = result.addRegion();
+    builder.createBlock(region1);
+    if (patternsBodyBuilder)
+      patternsBodyBuilder(builder, result.location);
+  }
+  if (typeConverterBodyBuilder) {
+    OpBuilder::InsertionGuard g(builder);
+    Region *region2 = result.addRegion();
+    builder.createBlock(region2);
+    typeConverterBodyBuilder(builder, result.location);
+  }
+}
+
//===----------------------------------------------------------------------===//
// ApplyLoopInvariantCodeMotionOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index a9a5e43cc06774..8ac6d4ef3b9778 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -231,3 +231,51 @@ transform.sequence failures(propagate) {
transform.apply_patterns.canonicalization
} {apply_cse} : !transform.any_op
}
+
+// -----
+
+// Full conversion: test.foo is illegal and is replaced by test.new_op, whose
+// tensor result type is converted to a memref by the default type converter.
+// The pattern only looks for the replace_with_new_op attribute; its value is
+// ignored. The source materialization casts the memref back for func.return.
+// CHECK-LABEL: func @full_dialect_conversion
+//  CHECK-NEXT:   %[[m:.*]] = "test.new_op"() : () -> memref<5xf32>
+//  CHECK-NEXT:   %[[cast:.*]] = builtin.unrealized_conversion_cast %[[m]] : memref<5xf32> to tensor<5xf32>
+//  CHECK-NEXT:   return %[[cast]]
+func.func @full_dialect_conversion() -> tensor<5xf32> {
+  %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (tensor<5xf32>)
+  return %0 : tensor<5xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_conversion_patterns to %0 {
+    transform.apply_conversion_patterns.transform.test_conversion_patterns
+  }, {
+    transform.apply_conversion_patterns.transform.test_type_converter
+  } {illegal_ops = ["test.foo"],
+     legal_ops = ["func.func", "func.return", "test.new_op"]}
+      : !transform.any_op
+}
+
+// -----
+
+// Verifier test: at least one of legal_ops, illegal_ops, legal_dialects or
+// illegal_dialects must be set on apply_conversion_patterns.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below{{conversion target is not specified}}
+ transform.apply_conversion_patterns to %0 {
+ transform.apply_conversion_patterns.transform.test_conversion_patterns
+ }, {
+ transform.apply_conversion_patterns.transform.test_type_converter
+ } : !transform.any_op
+}
+
+// -----
+
+// The test pattern descriptor does not provide its own type converter, so it
+// relies on the default type converter region. That region is omitted here,
+// which makes the transform fail definitely when it is applied.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below{{pattern descriptor does not specify type converter and apply_conversion_patterns op has no default type converter}}
+ transform.apply_conversion_patterns to %0 {
+ // expected-note @below{{pattern descriptor op}}
+ transform.apply_conversion_patterns.transform.test_conversion_patterns
+ } {illegal_ops = ["test.foo"]} : !transform.any_op
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index a30f92e9f56532..ac7e186843e24e 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -915,6 +915,65 @@ void mlir::test::TestProduceInvalidIR::getEffects(
transform::modifiesPayload(effects);
}
+namespace {
+/// Conversion pattern used by the tests. Any op carrying a
+/// "replace_with_new_op" attribute (the attribute's value is not inspected) is
+/// replaced by a "test.new_op" op. The new op takes the already-converted
+/// operands, and its result types come from the pattern's type converter.
+class ReplaceWithNewOpConversion : public ConversionPattern {
+public:
+  ReplaceWithNewOpConversion(TypeConverter &typeConverter, MLIRContext *context)
+      : ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(),
+                          /*benefit=*/1, context) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only ops explicitly tagged for replacement are rewritten.
+    if (!op->hasAttr("replace_with_new_op"))
+      return failure();
+
+    // Bail out if any result type cannot be converted.
+    auto *converter = getTypeConverter();
+    SmallVector<Type> convertedTypes;
+    if (failed(converter->convertTypes(op->getResultTypes(), convertedTypes)))
+      return failure();
+
+    StringAttr newOpName =
+        OperationName("test.new_op", op->getContext()).getIdentifier();
+    Operation *replacement =
+        rewriter.create(op->getLoc(), newOpName, operands, convertedTypes);
+    rewriter.replaceOp(op, replacement->getResults());
+    return success();
+  }
+};
+} // namespace
+
+/// Adds the test conversion pattern to `patterns`, bound to `typeConverter`.
+void mlir::test::ApplyTestConversionPatternsOp::populatePatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ReplaceWithNewOpConversion>(typeConverter, context);
+}
+
+namespace {
+/// Test type converter that converts ranked tensor types to memref types with
+/// the same shape and element type; all other types are kept unchanged.
+/// Converted values that are still needed with their original type are
+/// materialized with `builtin.unrealized_conversion_cast` ops.
+class TestTypeConverter : public TypeConverter {
+public:
+  TestTypeConverter() {
+    // Fallback: leave every other type as is. The most recently added
+    // conversion is tried first, so this must be registered before the
+    // tensor-to-memref conversion.
+    addConversion([](Type type) { return type; });
+    addConversion([](RankedTensorType type) -> Type {
+      return MemRefType::get(type.getShape(), type.getElementType());
+    });
+    addSourceMaterialization([](OpBuilder &builder, Type resultType,
+                                ValueRange inputs,
+                                Location loc) -> std::optional<Value> {
+      if (inputs.size() != 1)
+        return std::nullopt;
+      return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+          .getResult(0);
+    });
+  }
+};
+} // namespace
+
+/// Returns a freshly allocated TestTypeConverter; ownership passes to the
+/// caller.
+std::unique_ptr<::mlir::TypeConverter>
+mlir::test::TestTypeConverterOp::getTypeConverter() {
+  std::unique_ptr<::mlir::TypeConverter> converter =
+      std::make_unique<TestTypeConverter>();
+  return converter;
+}
+
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index b11c59209a9732..41f318db68405b 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -554,6 +554,24 @@ def ApplyTestPatternsOp
let cppNamespace = "::mlir::test";
}
+// Test pattern descriptor for `transform.apply_conversion_patterns`. Its
+// populatePatterns implementation (in TestTransformDialectExtension.cpp) adds
+// a conversion pattern that rewrites ops carrying a "replace_with_new_op"
+// attribute. It does not override getTypeConverter, so the default type
+// converter of the enclosing op is used.
+def ApplyTestConversionPatternsOp
+ : Op<Transform_Dialect, "apply_conversion_patterns.transform.test_conversion_patterns",
+ [DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface>]> {
+ let arguments = (ins);
+ let results = (outs);
+ let assemblyFormat = "attr-dict";
+ let cppNamespace = "::mlir::test";
+}
+
+// Test op for the default type converter region of
+// `transform.apply_conversion_patterns`. Its getTypeConverter implementation
+// (in TestTransformDialectExtension.cpp) returns a converter that maps ranked
+// tensor types to memref types.
+def TestTypeConverterOp
+ : Op<Transform_Dialect, "apply_conversion_patterns.transform.test_type_converter",
+ [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface>]> {
+ let arguments = (ins);
+ let results = (outs);
+ let assemblyFormat = "attr-dict";
+ let cppNamespace = "::mlir::test";
+}
+
def TestReEnterRegionOp
: Op<Transform_Dialect, "test_re_enter_region",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
More information about the Mlir-commits
mailing list