[Mlir-commits] [mlir] 8a7a713 - [mlir] somewhat decompose TestDialect.cpp
Alex Zinenko
llvmlistbot at llvm.org
Thu Jul 27 04:36:31 PDT 2023
Author: Alex Zinenko
Date: 2023-07-27T11:36:24Z
New Revision: 8a7a7137ff60d268bc7e67845e54baa9a5878cb1
URL: https://github.com/llvm/llvm-project/commit/8a7a7137ff60d268bc7e67845e54baa9a5878cb1
DIFF: https://github.com/llvm/llvm-project/commit/8a7a7137ff60d268bc7e67845e54baa9a5878cb1.diff
LOG: [mlir] somewhat decompose TestDialect.cpp
TestDialect.cpp along with the ODS-generated files amounts to around
100k LoC and takes a significant amount of time to compile. Factor out
the test ops related to testing the syntax and assembly format, which
are a relatively large and well-delimited group, into a separate set of
files.
Also factor out dialect interfaces into a separate file.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D155947
Added:
mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
mlir/test/lib/Dialect/Test/TestOpsSyntax.h
mlir/test/lib/Dialect/Test/TestOpsSyntax.td
Modified:
mlir/test/lib/Dialect/Test/CMakeLists.txt
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestDialect.td
mlir/test/lib/Dialect/Test/TestOps.td
mlir/unittests/IR/AdaptorTest.cpp
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index 2d1a1df8ea7069..8d26fd4e1c545c 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -36,6 +36,11 @@ mlir_tablegen(TestOpsDialect.cpp.inc -gen-dialect-defs -dialect=test)
mlir_tablegen(TestPatterns.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestOpsIncGen)
+set(LLVM_TARGET_DEFINITIONS TestOpsSyntax.td)
+mlir_tablegen(TestOpsSyntax.h.inc -gen-op-decls)
+mlir_tablegen(TestOpsSyntax.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTestOpsSyntaxIncGen)
+
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestDialect
TestAttributes.cpp
@@ -44,6 +49,8 @@ add_mlir_library(MLIRTestDialect
TestPatterns.cpp
TestTraits.cpp
TestTypes.cpp
+ TestOpsSyntax.cpp
+ TestDialectInterfaces.cpp
EXCLUDE_FROM_LIBMLIR
@@ -53,6 +60,7 @@ add_mlir_library(MLIRTestDialect
MLIRTestInterfaceIncGen
MLIRTestTypeDefIncGen
MLIRTestOpsIncGen
+ MLIRTestOpsSyntaxIncGen
LINK_LIBS PUBLIC
MLIRControlFlowInterfaces
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 072f6ff4b84d33..62f8f6865181dc 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -28,7 +28,6 @@
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
-#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -120,357 +119,6 @@ void test::registerTestDialect(DialectRegistry &registry) {
registry.insert<TestDialect>();
}
-//===----------------------------------------------------------------------===//
-// TestDialect version utilities
-//===----------------------------------------------------------------------===//
-
-struct TestDialectVersion : public DialectVersion {
- uint32_t major = 2;
- uint32_t minor = 0;
-};
-
-//===----------------------------------------------------------------------===//
-// TestDialect Interfaces
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-/// Testing the correctness of some traits.
-static_assert(
- llvm::is_detected<OpTrait::has_implicit_terminator_t,
- SingleBlockImplicitTerminatorOp>::value,
- "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
-static_assert(OpTrait::hasSingleBlockImplicitTerminator<
- SingleBlockImplicitTerminatorOp>::value,
- "hasSingleBlockImplicitTerminator does not match "
- "SingleBlockImplicitTerminatorOp");
-
-struct TestResourceBlobManagerInterface
- : public ResourceBlobManagerDialectInterfaceBase<
- TestDialectResourceBlobHandle> {
- using ResourceBlobManagerDialectInterfaceBase<
- TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase;
-};
-
-namespace {
-enum test_encoding { k_attr_params = 0 };
-}
-
-// Test support for interacting with the Bytecode reader/writer.
-struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
- using BytecodeDialectInterface::BytecodeDialectInterface;
- TestBytecodeDialectInterface(Dialect *dialect)
- : BytecodeDialectInterface(dialect) {}
-
- LogicalResult writeAttribute(Attribute attr,
- DialectBytecodeWriter &writer) const final {
- if (auto concreteAttr = llvm::dyn_cast<TestAttrParamsAttr>(attr)) {
- writer.writeVarInt(test_encoding::k_attr_params);
- writer.writeVarInt(concreteAttr.getV0());
- writer.writeVarInt(concreteAttr.getV1());
- return success();
- }
- return failure();
- }
-
- Attribute readAttribute(DialectBytecodeReader &reader,
- const DialectVersion &version_) const final {
- const auto &version = static_cast<const TestDialectVersion &>(version_);
- if (version.major < 2)
- return readAttrOldEncoding(reader);
- if (version.major == 2 && version.minor == 0)
- return readAttrNewEncoding(reader);
- // Forbid reading future versions by returning nullptr.
- return Attribute();
- }
-
- // Emit a specific version of the dialect.
- void writeVersion(DialectBytecodeWriter &writer) const final {
- auto version = TestDialectVersion();
- writer.writeVarInt(version.major); // major
- writer.writeVarInt(version.minor); // minor
- }
-
- std::unique_ptr<DialectVersion>
- readVersion(DialectBytecodeReader &reader) const final {
- uint64_t major, minor;
- if (failed(reader.readVarInt(major)) || failed(reader.readVarInt(minor)))
- return nullptr;
- auto version = std::make_unique<TestDialectVersion>();
- version->major = major;
- version->minor = minor;
- return version;
- }
-
- LogicalResult upgradeFromVersion(Operation *topLevelOp,
- const DialectVersion &version_) const final {
- const auto &version = static_cast<const TestDialectVersion &>(version_);
- if ((version.major == 2) && (version.minor == 0))
- return success();
- if (version.major > 2 || (version.major == 2 && version.minor > 0)) {
- return topLevelOp->emitError()
- << "current test dialect version is 2.0, can't parse version: "
- << version.major << "." << version.minor;
- }
- // Prior version 2.0, the old op supported only a single attribute called
- // "dimensions". We can perform the upgrade.
- topLevelOp->walk([](TestVersionedOpA op) {
- if (auto dims = op->getAttr("dimensions")) {
- op->removeAttr("dimensions");
- op->setAttr("dims", dims);
- }
- op->setAttr("modifier", BoolAttr::get(op->getContext(), false));
- });
- return success();
- }
-
-private:
- Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const {
- uint64_t encoding;
- if (failed(reader.readVarInt(encoding)) ||
- encoding != test_encoding::k_attr_params)
- return Attribute();
- // The new encoding has v0 first, v1 second.
- uint64_t v0, v1;
- if (failed(reader.readVarInt(v0)) || failed(reader.readVarInt(v1)))
- return Attribute();
- return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
- static_cast<int>(v1));
- }
-
- Attribute readAttrOldEncoding(DialectBytecodeReader &reader) const {
- uint64_t encoding;
- if (failed(reader.readVarInt(encoding)) ||
- encoding != test_encoding::k_attr_params)
- return Attribute();
- // The old encoding has v1 first, v0 second.
- uint64_t v0, v1;
- if (failed(reader.readVarInt(v1)) || failed(reader.readVarInt(v0)))
- return Attribute();
- return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
- static_cast<int>(v1));
- }
-};
-
-// Test support for interacting with the AsmPrinter.
-struct TestOpAsmInterface : public OpAsmDialectInterface {
- using OpAsmDialectInterface::OpAsmDialectInterface;
- TestOpAsmInterface(Dialect *dialect, TestResourceBlobManagerInterface &mgr)
- : OpAsmDialectInterface(dialect), blobManager(mgr) {}
-
- //===------------------------------------------------------------------===//
- // Aliases
- //===------------------------------------------------------------------===//
-
- AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
- StringAttr strAttr = dyn_cast<StringAttr>(attr);
- if (!strAttr)
- return AliasResult::NoAlias;
-
- // Check the contents of the string attribute to see what the test alias
- // should be named.
- std::optional<StringRef> aliasName =
- StringSwitch<std::optional<StringRef>>(strAttr.getValue())
- .Case("alias_test:dot_in_name", StringRef("test.alias"))
- .Case("alias_test:trailing_digit", StringRef("test_alias0"))
- .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
- .Case("alias_test:sanitize_conflict_a",
- StringRef("test_alias_conflict0"))
- .Case("alias_test:sanitize_conflict_b",
- StringRef("test_alias_conflict0_"))
- .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
- .Default(std::nullopt);
- if (!aliasName)
- return AliasResult::NoAlias;
-
- os << *aliasName;
- return AliasResult::FinalAlias;
- }
-
- AliasResult getAlias(Type type, raw_ostream &os) const final {
- if (auto tupleType = dyn_cast<TupleType>(type)) {
- if (tupleType.size() > 0 &&
- llvm::all_of(tupleType.getTypes(), [](Type elemType) {
- return isa<SimpleAType>(elemType);
- })) {
- os << "test_tuple";
- return AliasResult::FinalAlias;
- }
- }
- if (auto intType = dyn_cast<TestIntegerType>(type)) {
- if (intType.getSignedness() ==
- TestIntegerType::SignednessSemantics::Unsigned &&
- intType.getWidth() == 8) {
- os << "test_ui8";
- return AliasResult::FinalAlias;
- }
- }
- if (auto recType = dyn_cast<TestRecursiveType>(type)) {
- if (recType.getName() == "type_to_alias") {
- // We only make alias for a specific recursive type.
- os << "testrec";
- return AliasResult::FinalAlias;
- }
- }
- return AliasResult::NoAlias;
- }
-
- //===------------------------------------------------------------------===//
- // Resources
- //===------------------------------------------------------------------===//
-
- std::string
- getResourceKey(const AsmDialectResourceHandle &handle) const override {
- return cast<TestDialectResourceBlobHandle>(handle).getKey().str();
- }
-
- FailureOr<AsmDialectResourceHandle>
- declareResource(StringRef key) const final {
- return blobManager.insert(key);
- }
-
- LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
- FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
- if (failed(blob))
- return failure();
-
- // Update the blob for this entry.
- blobManager.update(entry.getKey(), std::move(*blob));
- return success();
- }
-
- void
- buildResources(Operation *op,
- const SetVector<AsmDialectResourceHandle> &referencedResources,
- AsmResourceBuilder &provider) const final {
- blobManager.buildResources(provider, referencedResources.getArrayRef());
- }
-
-private:
- /// The blob manager for the dialect.
- TestResourceBlobManagerInterface &blobManager;
-};
-
-struct TestDialectFoldInterface : public DialectFoldInterface {
- using DialectFoldInterface::DialectFoldInterface;
-
- /// Registered hook to check if the given region, which is attached to an
- /// operation that is *not* isolated from above, should be used when
- /// materializing constants.
- bool shouldMaterializeInto(Region *region) const final {
- // If this is a one region operation, then insert into it.
- return isa<OneRegionOp>(region->getParentOp());
- }
-};
-
-/// This class defines the interface for handling inlining with standard
-/// operations.
-struct TestInlinerInterface : public DialectInlinerInterface {
- using DialectInlinerInterface::DialectInlinerInterface;
-
- //===--------------------------------------------------------------------===//
- // Analysis Hooks
- //===--------------------------------------------------------------------===//
-
- bool isLegalToInline(Operation *call, Operation *callable,
- bool wouldBeCloned) const final {
- // Don't allow inlining calls that are marked `noinline`.
- return !call->hasAttr("noinline");
- }
- bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
- // Inlining into test dialect regions is legal.
- return true;
- }
- bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
- return true;
- }
-
- bool shouldAnalyzeRecursively(Operation *op) const final {
- // Analyze recursively if this is not a functional region operation, it
- // froms a separate functional scope.
- return !isa<FunctionalRegionOp>(op);
- }
-
- //===--------------------------------------------------------------------===//
- // Transformation Hooks
- //===--------------------------------------------------------------------===//
-
- /// Handle the given inlined terminator by replacing it with a new operation
- /// as necessary.
- void handleTerminator(Operation *op,
- ArrayRef<Value> valuesToRepl) const final {
- // Only handle "test.return" here.
- auto returnOp = dyn_cast<TestReturnOp>(op);
- if (!returnOp)
- return;
-
- // Replace the values directly with the return operands.
- assert(returnOp.getNumOperands() == valuesToRepl.size());
- for (const auto &it : llvm::enumerate(returnOp.getOperands()))
- valuesToRepl[it.index()].replaceAllUsesWith(it.value());
- }
-
- /// Attempt to materialize a conversion for a type mismatch between a call
- /// from this dialect, and a callable region. This method should generate an
- /// operation that takes 'input' as the only operand, and produces a single
- /// result of 'resultType'. If a conversion can not be generated, nullptr
- /// should be returned.
- Operation *materializeCallConversion(OpBuilder &builder, Value input,
- Type resultType,
- Location conversionLoc) const final {
- // Only allow conversion for i16/i32 types.
- if (!(resultType.isSignlessInteger(16) ||
- resultType.isSignlessInteger(32)) ||
- !(input.getType().isSignlessInteger(16) ||
- input.getType().isSignlessInteger(32)))
- return nullptr;
- return builder.create<TestCastOp>(conversionLoc, resultType, input);
- }
-
- Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
- Value argument,
- DictionaryAttr argumentAttrs) const final {
- if (!argumentAttrs.contains("test.handle_argument"))
- return argument;
- return builder.create<TestTypeChangerOp>(call->getLoc(), argument.getType(),
- argument);
- }
-
- Value handleResult(OpBuilder &builder, Operation *call, Operation *callable,
- Value result, DictionaryAttr resultAttrs) const final {
- if (!resultAttrs.contains("test.handle_result"))
- return result;
- return builder.create<TestTypeChangerOp>(call->getLoc(), result.getType(),
- result);
- }
-
- void processInlinedCallBlocks(
- Operation *call,
- iterator_range<Region::iterator> inlinedBlocks) const final {
- if (!isa<ConversionCallOp>(call))
- return;
-
- // Set attributed on all ops in the inlined blocks.
- for (Block &block : inlinedBlocks) {
- block.walk([&](Operation *op) {
- op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
- });
- }
- }
-};
-
-struct TestReductionPatternInterface : public DialectReductionPatternInterface {
-public:
- TestReductionPatternInterface(Dialect *dialect)
- : DialectReductionPatternInterface(dialect) {}
-
- void populateReductionPatterns(RewritePatternSet &patterns) const final {
- populateTestReductionPatterns(patterns);
- }
-};
-
-} // namespace
-
//===----------------------------------------------------------------------===//
// Dynamic operations
//===----------------------------------------------------------------------===//
@@ -557,16 +205,12 @@ void TestDialect::initialize() {
#define GET_OP_LIST
#include "TestOps.cpp.inc"
>();
+ registerOpsSyntax();
addOperations<ManualCppOpWithFold>();
registerDynamicOp(getDynamicGenericOp(this));
registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
registerDynamicOp(getDynamicCustomParserPrinterOp(this));
-
- auto &blobInterface = addInterface<TestResourceBlobManagerInterface>();
- addInterface<TestOpAsmInterface>(blobInterface);
-
- addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
- TestReductionPatternInterface, TestBytecodeDialectInterface>();
+ registerInterfaces();
allowUnknownOperations();
// Instantiate our fallback op interface that we'll use on specific
@@ -583,15 +227,6 @@ Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
return builder.create<TestOpConstant>(loc, type, value);
}
-::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
- ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
- ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
- OpaqueProperties properties, ::mlir::RegionRange regions,
- ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
- inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
- return ::mlir::success();
-}
-
void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
OperationName opName) {
if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
@@ -785,224 +420,6 @@ void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<FoldToCallOpPattern>(context);
}
-//===----------------------------------------------------------------------===//
-// Test Format* operations
-//===----------------------------------------------------------------------===//
-
-//===----------------------------------------------------------------------===//
-// Parsing
-
-static ParseResult parseCustomOptionalOperand(
- OpAsmParser &parser,
- std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
- if (succeeded(parser.parseOptionalLParen())) {
- optOperand.emplace();
- if (parser.parseOperand(*optOperand) || parser.parseRParen())
- return failure();
- }
- return success();
-}
-
-static ParseResult parseCustomDirectiveOperands(
- OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
- std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
- if (parser.parseOperand(operand))
- return failure();
- if (succeeded(parser.parseOptionalComma())) {
- optOperand.emplace();
- if (parser.parseOperand(*optOperand))
- return failure();
- }
- if (parser.parseArrow() || parser.parseLParen() ||
- parser.parseOperandList(varOperands) || parser.parseRParen())
- return failure();
- return success();
-}
-static ParseResult
-parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
- Type &optOperandType,
- SmallVectorImpl<Type> &varOperandTypes) {
- if (parser.parseColon())
- return failure();
-
- if (parser.parseType(operandType))
- return failure();
- if (succeeded(parser.parseOptionalComma())) {
- if (parser.parseType(optOperandType))
- return failure();
- }
- if (parser.parseArrow() || parser.parseLParen() ||
- parser.parseTypeList(varOperandTypes) || parser.parseRParen())
- return failure();
- return success();
-}
-static ParseResult
-parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
- Type optOperandType,
- const SmallVectorImpl<Type> &varOperandTypes) {
- if (parser.parseKeyword("type_refs_capture"))
- return failure();
-
- Type operandType2, optOperandType2;
- SmallVector<Type, 1> varOperandTypes2;
- if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
- varOperandTypes2))
- return failure();
-
- if (operandType != operandType2 || optOperandType != optOperandType2 ||
- varOperandTypes != varOperandTypes2)
- return failure();
-
- return success();
-}
-static ParseResult parseCustomDirectiveOperandsAndTypes(
- OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
- std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
- Type &operandType, Type &optOperandType,
- SmallVectorImpl<Type> &varOperandTypes) {
- if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
- parseCustomDirectiveResults(parser, operandType, optOperandType,
- varOperandTypes))
- return failure();
- return success();
-}
-static ParseResult parseCustomDirectiveRegions(
- OpAsmParser &parser, Region &region,
- SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
- if (parser.parseRegion(region))
- return failure();
- if (failed(parser.parseOptionalComma()))
- return success();
- std::unique_ptr<Region> varRegion = std::make_unique<Region>();
- if (parser.parseRegion(*varRegion))
- return failure();
- varRegions.emplace_back(std::move(varRegion));
- return success();
-}
-static ParseResult
-parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
- SmallVectorImpl<Block *> &varSuccessors) {
- if (parser.parseSuccessor(successor))
- return failure();
- if (failed(parser.parseOptionalComma()))
- return success();
- Block *varSuccessor;
- if (parser.parseSuccessor(varSuccessor))
- return failure();
- varSuccessors.append(2, varSuccessor);
- return success();
-}
-static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
- IntegerAttr &attr,
- IntegerAttr &optAttr) {
- if (parser.parseAttribute(attr))
- return failure();
- if (succeeded(parser.parseOptionalComma())) {
- if (parser.parseAttribute(optAttr))
- return failure();
- }
- return success();
-}
-static ParseResult parseCustomDirectiveSpacing(OpAsmParser &parser,
- mlir::StringAttr &attr) {
- return parser.parseAttribute(attr);
-}
-static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
- NamedAttrList &attrs) {
- return parser.parseOptionalAttrDict(attrs);
-}
-static ParseResult parseCustomDirectiveOptionalOperandRef(
- OpAsmParser &parser,
- std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
- int64_t operandCount = 0;
- if (parser.parseInteger(operandCount))
- return failure();
- bool expectedOptionalOperand = operandCount == 0;
- return success(expectedOptionalOperand != optOperand.has_value());
-}
-
-//===----------------------------------------------------------------------===//
-// Printing
-
-static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
- Value optOperand) {
- if (optOperand)
- printer << "(" << optOperand << ") ";
-}
-
-static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
- Value operand, Value optOperand,
- OperandRange varOperands) {
- printer << operand;
- if (optOperand)
- printer << ", " << optOperand;
- printer << " -> (" << varOperands << ")";
-}
-static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
- Type operandType, Type optOperandType,
- TypeRange varOperandTypes) {
- printer << " : " << operandType;
- if (optOperandType)
- printer << ", " << optOperandType;
- printer << " -> (" << varOperandTypes << ")";
-}
-static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
- Operation *op, Type operandType,
- Type optOperandType,
- TypeRange varOperandTypes) {
- printer << " type_refs_capture ";
- printCustomDirectiveResults(printer, op, operandType, optOperandType,
- varOperandTypes);
-}
-static void printCustomDirectiveOperandsAndTypes(
- OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
- OperandRange varOperands, Type operandType, Type optOperandType,
- TypeRange varOperandTypes) {
- printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
- printCustomDirectiveResults(printer, op, operandType, optOperandType,
- varOperandTypes);
-}
-static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
- Region &region,
- MutableArrayRef<Region> varRegions) {
- printer.printRegion(region);
- if (!varRegions.empty()) {
- printer << ", ";
- for (Region &region : varRegions)
- printer.printRegion(region);
- }
-}
-static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
- Block *successor,
- SuccessorRange varSuccessors) {
- printer << successor;
- if (!varSuccessors.empty())
- printer << ", " << varSuccessors.front();
-}
-static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
- Attribute attribute,
- Attribute optAttribute) {
- printer << attribute;
- if (optAttribute)
- printer << ", " << optAttribute;
-}
-static void printCustomDirectiveSpacing(OpAsmPrinter &printer, Operation *op,
- Attribute attribute) {
- printer << attribute;
-}
-static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
- DictionaryAttr attrs) {
- printer.printOptionalAttrDict(attrs.getValue());
-}
-
-static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
- Operation *op,
- Value optOperand) {
- printer << (optOperand ? "1" : "0");
-}
-
//===----------------------------------------------------------------------===//
// Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===//
@@ -1060,249 +477,6 @@ void AffineScopeOp::print(OpAsmPrinter &p) {
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}
-//===----------------------------------------------------------------------===//
-// Test parser.
-//===----------------------------------------------------------------------===//
-
-ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
- OperationState &result) {
- if (parser.parseOptionalColon())
- return success();
- uint64_t numResults;
- if (parser.parseInteger(numResults))
- return failure();
-
- IndexType type = parser.getBuilder().getIndexType();
- for (unsigned i = 0; i < numResults; ++i)
- result.addTypes(type);
- return success();
-}
-
-void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
- if (unsigned numResults = getNumResults())
- p << " : " << numResults;
-}
-
-ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
- OperationState &result) {
- StringRef keyword;
- if (parser.parseKeyword(&keyword))
- return failure();
- result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
- return success();
-}
-
-void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
-
-ParseResult ParseB64BytesOp::parse(OpAsmParser &parser,
- OperationState &result) {
- std::vector<char> bytes;
- if (parser.parseBase64Bytes(&bytes))
- return failure();
- result.addAttribute("b64", parser.getBuilder().getStringAttr(
- StringRef(&bytes.front(), bytes.size())));
- return success();
-}
-
-void ParseB64BytesOp::print(OpAsmPrinter &p) {
- p << " \"" << llvm::encodeBase64(getB64()) << "\"";
-}
-
-//===----------------------------------------------------------------------===//
-// Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`.
-
-ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
- OperationState &result) {
- if (parser.parseKeyword("wraps"))
- return failure();
-
- // Parse the wrapped op in a region
- Region &body = *result.addRegion();
- body.push_back(new Block);
- Block &block = body.back();
- Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
- if (!wrappedOp)
- return failure();
-
- // Create a return terminator in the inner region, pass as operand to the
- // terminator the returned values from the wrapped operation.
- SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
- OpBuilder builder(parser.getContext());
- builder.setInsertionPointToEnd(&block);
- builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
-
- // Get the results type for the wrapping op from the terminator operands.
- Operation &returnOp = body.back().back();
- result.types.append(returnOp.operand_type_begin(),
- returnOp.operand_type_end());
-
- // Use the location of the wrapped op for the "test.wrapping_region" op.
- result.location = wrappedOp->getLoc();
-
- return success();
-}
-
-void WrappingRegionOp::print(OpAsmPrinter &p) {
- p << " wraps ";
- p.printGenericOp(&getRegion().front().front());
-}
-
-//===----------------------------------------------------------------------===//
-// Test PrettyPrintedRegionOp - exercising the following parser APIs
-// parseGenericOperationAfterOpName
-// parseCustomOperationName
-//===----------------------------------------------------------------------===//
-
-ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
- OperationState &result) {
-
- SMLoc loc = parser.getCurrentLocation();
- Location currLocation = parser.getEncodedSourceLoc(loc);
-
- // Parse the operands.
- SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
- if (parser.parseOperandList(operands))
- return failure();
-
- // Check if we are parsing the pretty-printed version
- // test.pretty_printed_region start <inner-op> end : <functional-type>
- // Else fallback to parsing the "non pretty-printed" version.
- if (!succeeded(parser.parseOptionalKeyword("start")))
- return parser.parseGenericOperationAfterOpName(result,
- llvm::ArrayRef(operands));
-
- FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
- if (failed(parseOpNameInfo))
- return failure();
-
- StringAttr innerOpName = parseOpNameInfo->getIdentifier();
-
- FunctionType opFntype;
- std::optional<Location> explicitLoc;
- if (parser.parseKeyword("end") || parser.parseColon() ||
- parser.parseType(opFntype) ||
- parser.parseOptionalLocationSpecifier(explicitLoc))
- return failure();
-
- // If location of the op is explicitly provided, then use it; Else use
- // the parser's current location.
- Location opLoc = explicitLoc.value_or(currLocation);
-
- // Derive the SSA-values for op's operands.
- if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
- result.operands))
- return failure();
-
- // Add a region for op.
- Region &region = *result.addRegion();
-
- // Create a basic-block inside op's region.
- Block &block = region.emplaceBlock();
-
- // Create and insert an "inner-op" operation in the block.
- // Just for testing purposes, we can assume that inner op is a binary op with
- // result and operand types all same as the test-op's first operand.
- Type innerOpType = opFntype.getInput(0);
- Value lhs = block.addArgument(innerOpType, opLoc);
- Value rhs = block.addArgument(innerOpType, opLoc);
-
- OpBuilder builder(parser.getBuilder().getContext());
- builder.setInsertionPointToStart(&block);
-
- Operation *innerOp =
- builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
-
- // Insert a return statement in the block returning the inner-op's result.
- builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
-
- // Populate the op operation-state with result-type and location.
- result.addTypes(opFntype.getResults());
- result.location = innerOp->getLoc();
-
- return success();
-}
-
-void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
- p << ' ';
- p.printOperands(getOperands());
-
- Operation &innerOp = getRegion().front().front();
- // Assuming that region has a single non-terminator inner-op, if the inner-op
- // meets some criteria (which in this case is a simple one based on the name
- // of inner-op), then we can print the entire region in a succinct way.
- // Here we assume that the prototype of "test.special.op" can be trivially
- // derived while parsing it back.
- if (innerOp.getName().getStringRef().equals("test.special.op")) {
- p << " start test.special.op end";
- } else {
- p << " (";
- p.printRegion(getRegion());
- p << ")";
- }
-
- p << " : ";
- p.printFunctionalType(*this);
-}
-
-//===----------------------------------------------------------------------===//
-// Test PolyForOp - parse list of region arguments.
-//===----------------------------------------------------------------------===//
-
-ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
- SmallVector<OpAsmParser::Argument, 4> ivsInfo;
- // Parse list of region arguments without a delimiter.
- if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None))
- return failure();
-
- // Parse the body region.
- Region *body = result.addRegion();
- for (auto &iv : ivsInfo)
- iv.type = parser.getBuilder().getIndexType();
- return parser.parseRegion(*body, ivsInfo);
-}
-
-void PolyForOp::print(OpAsmPrinter &p) {
- p << " ";
- llvm::interleaveComma(getRegion().getArguments(), p, [&](auto arg) {
- p.printRegionArgument(arg, /*argAttrs =*/{}, /*omitType=*/true);
- });
- p << " ";
- p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
-}
-
-void PolyForOp::getAsmBlockArgumentNames(Region &region,
- OpAsmSetValueNameFn setNameFn) {
- auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
- if (!arrayAttr)
- return;
- auto args = getRegion().front().getArguments();
- auto e = std::min(arrayAttr.size(), args.size());
- for (unsigned i = 0; i < e; ++i) {
- if (auto strAttr = dyn_cast<StringAttr>(arrayAttr[i]))
- setNameFn(args[i], strAttr.getValue());
- }
-}
-
-//===----------------------------------------------------------------------===//
-// TestAttrWithLoc - parse/printOptionalLocationSpecifier
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseOptionalLoc(OpAsmParser &p, Attribute &loc) {
- std::optional<Location> result;
- SMLoc sourceLoc = p.getCurrentLocation();
- if (p.parseOptionalLocationSpecifier(result))
- return failure();
- if (result)
- loc = *result;
- else
- loc = p.getEncodedSourceLoc(sourceLoc);
- return success();
-}
-
-static void printOptionalLoc(OpAsmPrinter &p, Operation *op, Attribute loc) {
- p.printOptionalLocationSpecifier(cast<LocationAttr>(loc));
-}
-
//===----------------------------------------------------------------------===//
// Test removing op with inner ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index 01116a29367b27..5dcea10a57fe06 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -29,7 +29,9 @@ def Test_Dialect : Dialect {
let extraClassDeclaration = [{
void registerAttributes();
+ void registerInterfaces();
void registerTypes();
+ void registerOpsSyntax();
// Provides a custom printing/parsing for some operations.
::std::optional<ParseOpHook>
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
new file mode 100644
index 00000000000000..7315b253df998e
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -0,0 +1,374 @@
+//===- TestDialectInterfaces.cpp - Test dialect interface definitions -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Interfaces/FoldInterfaces.h"
+#include "mlir/Reducer/ReductionPatternInterface.h"
+#include "mlir/Transforms/InliningUtils.h"
+
+using namespace mlir;
+using namespace test;
+
+//===----------------------------------------------------------------------===//
+// TestDialect version utilities
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Version of the test dialect as written to and read back from bytecode.
+/// A default-constructed value denotes the current version, 2.0.
+///
+/// This type is only used in this file, so it lives in an anonymous namespace
+/// to get internal linkage and avoid colliding with a same-named type in
+/// another translation unit of the test library.
+struct TestDialectVersion : public DialectVersion {
+  uint32_t major = 2;
+  uint32_t minor = 0;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// TestDialect Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Compile-time checks that both trait-detection utilities agree that
+// SingleBlockImplicitTerminatorOp has a single-block implicit terminator.
+static_assert(OpTrait::hasSingleBlockImplicitTerminator<
+                  SingleBlockImplicitTerminatorOp>::value,
+              "hasSingleBlockImplicitTerminator does not match SingleBlockImplicitTerminatorOp");
+static_assert(
+    llvm::is_detected<OpTrait::has_implicit_terminator_t,
+                      SingleBlockImplicitTerminatorOp>::value,
+    "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
+
+/// Dialect interface that owns the test dialect's resource blobs. All of its
+/// behavior comes from the generic resource-blob-manager base class.
+struct TestResourceBlobManagerInterface
+    : public ResourceBlobManagerDialectInterfaceBase<
+          TestDialectResourceBlobHandle> {
+  using BlobManagerBase =
+      ResourceBlobManagerDialectInterfaceBase<TestDialectResourceBlobHandle>;
+  using BlobManagerBase::BlobManagerBase;
+};
+
+/// Tags identifying attribute encodings written to and read from bytecode by
+/// TestBytecodeDialectInterface. This already sits inside the file's
+/// anonymous namespace, so it needs no nested one of its own.
+enum test_encoding { k_attr_params = 0 };
+
+/// Test support for interacting with the bytecode reader/writer.
+///
+/// TestAttrParamsAttr is encoded as the tag `k_attr_params` followed by its
+/// two parameters. Version 2.0 (the version this interface writes) stores them
+/// as (v0, v1); versions before 2.0 stored them as (v1, v0). Versions newer
+/// than 2.0 are rejected.
+struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
+  using BytecodeDialectInterface::BytecodeDialectInterface;
+  TestBytecodeDialectInterface(Dialect *dialect)
+      : BytecodeDialectInterface(dialect) {}
+
+  /// Writes TestAttrParamsAttr using the 2.0 encoding (tag, v0, v1). Returns
+  /// failure for any other attribute.
+  LogicalResult writeAttribute(Attribute attr,
+                               DialectBytecodeWriter &writer) const final {
+    if (auto concreteAttr = llvm::dyn_cast<TestAttrParamsAttr>(attr)) {
+      writer.writeVarInt(test_encoding::k_attr_params);
+      writer.writeVarInt(concreteAttr.getV0());
+      writer.writeVarInt(concreteAttr.getV1());
+      return success();
+    }
+    return failure();
+  }
+
+  /// Reads an attribute with the encoding that matches `version_`. Returns a
+  /// null attribute on malformed input or for versions newer than 2.0.
+  Attribute readAttribute(DialectBytecodeReader &reader,
+                          const DialectVersion &version_) const final {
+    // Assumes `version_` was produced by readVersion below.
+    const auto &version = static_cast<const TestDialectVersion &>(version_);
+    if (version.major < 2)
+      return readAttrOldEncoding(reader);
+    if (version.major == 2 && version.minor == 0)
+      return readAttrNewEncoding(reader);
+    // Forbid reading future versions by returning nullptr.
+    return Attribute();
+  }
+
+  /// Emits the current dialect version, 2.0, as two varints (major, minor).
+  void writeVersion(DialectBytecodeWriter &writer) const final {
+    auto version = TestDialectVersion();
+    writer.writeVarInt(version.major); // major
+    writer.writeVarInt(version.minor); // minor
+  }
+
+  /// Reads a (major, minor) version pair. Returns nullptr if either field is
+  /// missing or does not fit in the 32-bit fields of TestDialectVersion.
+  std::unique_ptr<DialectVersion>
+  readVersion(DialectBytecodeReader &reader) const final {
+    uint64_t major, minor;
+    if (failed(reader.readVarInt(major)) || failed(reader.readVarInt(minor)))
+      return nullptr;
+    // Reject out-of-range values instead of silently truncating them: e.g. a
+    // major version of 2^32 + 2 would otherwise be mistaken for version 2.
+    if (static_cast<uint32_t>(major) != major ||
+        static_cast<uint32_t>(minor) != minor)
+      return nullptr;
+    auto version = std::make_unique<TestDialectVersion>();
+    version->major = static_cast<uint32_t>(major);
+    version->minor = static_cast<uint32_t>(minor);
+    return version;
+  }
+
+  /// Upgrades IR read from an older dialect version in place. Version 2.0
+  /// needs no upgrade; versions newer than 2.0 produce an error.
+  LogicalResult upgradeFromVersion(Operation *topLevelOp,
+                                   const DialectVersion &version_) const final {
+    const auto &version = static_cast<const TestDialectVersion &>(version_);
+    if ((version.major == 2) && (version.minor == 0))
+      return success();
+    if (version.major > 2 || (version.major == 2 && version.minor > 0)) {
+      return topLevelOp->emitError()
+             << "current test dialect version is 2.0, can't parse version: "
+             << version.major << "." << version.minor;
+    }
+    // Prior to version 2.0, the op supported only a single attribute called
+    // "dimensions". Upgrade by renaming it to "dims" and adding a false
+    // "modifier" attribute.
+    topLevelOp->walk([](TestVersionedOpA op) {
+      if (auto dims = op->getAttr("dimensions")) {
+        op->removeAttr("dimensions");
+        op->setAttr("dims", dims);
+      }
+      op->setAttr("modifier", BoolAttr::get(op->getContext(), false));
+    });
+    return success();
+  }
+
+private:
+  /// Reads TestAttrParamsAttr in the 2.0 layout: tag, v0, v1.
+  Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const {
+    uint64_t encoding;
+    if (failed(reader.readVarInt(encoding)) ||
+        encoding != test_encoding::k_attr_params)
+      return Attribute();
+    // The new encoding has v0 first, v1 second.
+    uint64_t v0, v1;
+    if (failed(reader.readVarInt(v0)) || failed(reader.readVarInt(v1)))
+      return Attribute();
+    return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
+                                   static_cast<int>(v1));
+  }
+
+  /// Reads TestAttrParamsAttr in the pre-2.0 layout: tag, v1, v0.
+  Attribute readAttrOldEncoding(DialectBytecodeReader &reader) const {
+    uint64_t encoding;
+    if (failed(reader.readVarInt(encoding)) ||
+        encoding != test_encoding::k_attr_params)
+      return Attribute();
+    // The old encoding has v1 first, v0 second.
+    uint64_t v0, v1;
+    if (failed(reader.readVarInt(v1)) || failed(reader.readVarInt(v0)))
+      return Attribute();
+    return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
+                                   static_cast<int>(v1));
+  }
+};
+
+/// Test support for interacting with the AsmPrinter/AsmParser. It provides
+/// aliases for selected attributes and types, and dialect resource (blob)
+/// handling backed by TestResourceBlobManagerInterface.
+struct TestOpAsmInterface : public OpAsmDialectInterface {
+  /// `mgr` is held by reference and must outlive this interface. No
+  /// single-argument constructor is inherited from OpAsmDialectInterface,
+  /// because it could not initialize the `blobManager` reference.
+  TestOpAsmInterface(Dialect *dialect, TestResourceBlobManagerInterface &mgr)
+      : OpAsmDialectInterface(dialect), blobManager(mgr) {}
+
+  //===------------------------------------------------------------------===//
+  // Aliases
+  //===------------------------------------------------------------------===//
+
+  /// Assigns a fixed alias to string attributes whose value is one of the
+  /// `alias_test:*` keys below; every other attribute gets no alias.
+  AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
+    StringAttr strAttr = dyn_cast<StringAttr>(attr);
+    if (!strAttr)
+      return AliasResult::NoAlias;
+
+    // Check the contents of the string attribute to see what the test alias
+    // should be named.
+    std::optional<StringRef> aliasName =
+        StringSwitch<std::optional<StringRef>>(strAttr.getValue())
+            .Case("alias_test:dot_in_name", StringRef("test.alias"))
+            .Case("alias_test:trailing_digit", StringRef("test_alias0"))
+            .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
+            .Case("alias_test:sanitize_conflict_a",
+                  StringRef("test_alias_conflict0"))
+            .Case("alias_test:sanitize_conflict_b",
+                  StringRef("test_alias_conflict0_"))
+            .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
+            .Default(std::nullopt);
+    if (!aliasName)
+      return AliasResult::NoAlias;
+
+    os << *aliasName;
+    return AliasResult::FinalAlias;
+  }
+
+  /// Aliases are given to non-empty tuples of SimpleAType, to unsigned 8-bit
+  /// TestIntegerType, and to the TestRecursiveType named "type_to_alias".
+  AliasResult getAlias(Type type, raw_ostream &os) const final {
+    if (auto tupleType = dyn_cast<TupleType>(type)) {
+      if (tupleType.size() > 0 &&
+          llvm::all_of(tupleType.getTypes(), [](Type elemType) {
+            return isa<SimpleAType>(elemType);
+          })) {
+        os << "test_tuple";
+        return AliasResult::FinalAlias;
+      }
+    }
+    if (auto intType = dyn_cast<TestIntegerType>(type)) {
+      if (intType.getSignedness() ==
+              TestIntegerType::SignednessSemantics::Unsigned &&
+          intType.getWidth() == 8) {
+        os << "test_ui8";
+        return AliasResult::FinalAlias;
+      }
+    }
+    if (auto recType = dyn_cast<TestRecursiveType>(type)) {
+      if (recType.getName() == "type_to_alias") {
+        // We only make alias for a specific recursive type.
+        os << "testrec";
+        return AliasResult::FinalAlias;
+      }
+    }
+    return AliasResult::NoAlias;
+  }
+
+  //===------------------------------------------------------------------===//
+  // Resources
+  //===------------------------------------------------------------------===//
+
+  /// The printed key of a resource is the key of its blob handle.
+  std::string
+  getResourceKey(const AsmDialectResourceHandle &handle) const final {
+    return cast<TestDialectResourceBlobHandle>(handle).getKey().str();
+  }
+
+  /// Declares a resource named `key` in the blob manager.
+  FailureOr<AsmDialectResourceHandle>
+  declareResource(StringRef key) const final {
+    return blobManager.insert(key);
+  }
+
+  /// Parses a resource entry as a blob and stores it under the entry's key.
+  LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
+    FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
+    if (failed(blob))
+      return failure();
+
+    // Update the blob for this entry.
+    blobManager.update(entry.getKey(), std::move(*blob));
+    return success();
+  }
+
+  /// Emits the referenced resources through the blob manager.
+  void
+  buildResources(Operation *op,
+                 const SetVector<AsmDialectResourceHandle> &referencedResources,
+                 AsmResourceBuilder &provider) const final {
+    blobManager.buildResources(provider, referencedResources.getArrayRef());
+  }
+
+private:
+  /// The blob manager for the dialect.
+  TestResourceBlobManagerInterface &blobManager;
+};
+
+/// Folding hooks for the test dialect.
+struct TestDialectFoldInterface : public DialectFoldInterface {
+  using DialectFoldInterface::DialectFoldInterface;
+
+  /// Asked for a region whose parent op is *not* isolated from above: should
+  /// constants be materialized inside it? Only regions owned by a OneRegionOp
+  /// qualify.
+  bool shouldMaterializeInto(Region *region) const final {
+    Operation *owner = region->getParentOp();
+    return isa<OneRegionOp>(owner);
+  }
+};
+
+/// Inliner hooks for the test dialect:
+/// - calls marked `noinline` are never inlined;
+/// - `test.return` terminators are rewired;
+/// - test ops are created for type conversions and argument/result handling.
+struct TestInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  //===--------------------------------------------------------------------===//
+  // Analysis Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// A call carrying a `noinline` attribute is never inlined.
+  bool isLegalToInline(Operation *call, Operation *callable,
+                       bool wouldBeCloned) const final {
+    bool markedNoInline = call->hasAttr("noinline");
+    return !markedNoInline;
+  }
+
+  /// Inlining a region into a test dialect region is always legal.
+  bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
+    return true;
+  }
+
+  /// Every operation is considered legal to inline.
+  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
+    return true;
+  }
+
+  /// A FunctionalRegionOp forms a separate functional scope, so the inliner
+  /// does not analyze it recursively. Every other op is analyzed.
+  bool shouldAnalyzeRecursively(Operation *op) const final {
+    if (isa<FunctionalRegionOp>(op))
+      return false;
+    return true;
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Transformation Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// For an inlined `test.return`, replaces all uses of each value in
+  /// `valuesToRepl` with the matching return operand. Other terminators are
+  /// left untouched.
+  void handleTerminator(Operation *op,
+                        ArrayRef<Value> valuesToRepl) const final {
+    auto returnOp = dyn_cast<TestReturnOp>(op);
+    if (!returnOp)
+      return;
+
+    assert(returnOp.getNumOperands() == valuesToRepl.size());
+    for (unsigned idx = 0, end = op->getNumOperands(); idx < end; ++idx)
+      valuesToRepl[idx].replaceAllUsesWith(op->getOperand(idx));
+  }
+
+  /// Materializes a conversion between a call from this dialect and a callable
+  /// region with a mismatched type. Only signless i16/i32 to signless i16/i32
+  /// is supported, by creating a TestCastOp; otherwise returns nullptr.
+  Operation *materializeCallConversion(OpBuilder &builder, Value input,
+                                       Type resultType,
+                                       Location conversionLoc) const final {
+    auto isI16OrI32 = [](Type type) {
+      return type.isSignlessInteger(16) || type.isSignlessInteger(32);
+    };
+    if (isI16OrI32(resultType) && isI16OrI32(input.getType()))
+      return builder.create<TestCastOp>(conversionLoc, resultType, input);
+    return nullptr;
+  }
+
+  /// Wraps an argument tagged `test.handle_argument` in a TestTypeChangerOp.
+  /// Other arguments are returned unchanged.
+  Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
+                       Value argument,
+                       DictionaryAttr argumentAttrs) const final {
+    if (argumentAttrs.contains("test.handle_argument"))
+      return builder.create<TestTypeChangerOp>(call->getLoc(),
+                                               argument.getType(), argument);
+    return argument;
+  }
+
+  /// Wraps a result tagged `test.handle_result` in a TestTypeChangerOp.
+  /// Other results are returned unchanged.
+  Value handleResult(OpBuilder &builder, Operation *call, Operation *callable,
+                     Value result, DictionaryAttr resultAttrs) const final {
+    if (resultAttrs.contains("test.handle_result"))
+      return builder.create<TestTypeChangerOp>(call->getLoc(),
+                                               result.getType(), result);
+    return result;
+  }
+
+  /// After inlining a ConversionCallOp, sets the unit attribute
+  /// `inlined_conversion` on every op in the inlined blocks.
+  void processInlinedCallBlocks(
+      Operation *call,
+      iterator_range<Region::iterator> inlinedBlocks) const final {
+    if (!isa<ConversionCallOp>(call))
+      return;
+
+    auto marker = UnitAttr::get(call->getContext());
+    for (Block &block : inlinedBlocks)
+      block.walk(
+          [&](Operation *op) { op->setAttr("inlined_conversion", marker); });
+  }
+};
+
+/// Supplies the test dialect's reduction patterns through
+/// DialectReductionPatternInterface. The constructor is spelled out rather
+/// than inherited so that it is publicly accessible.
+struct TestReductionPatternInterface : public DialectReductionPatternInterface {
+  TestReductionPatternInterface(Dialect *dialect)
+      : DialectReductionPatternInterface(dialect) {}
+
+  /// Appends the patterns from populateTestReductionPatterns to `patterns`.
+  void populateReductionPatterns(RewritePatternSet &patterns) const final {
+    populateTestReductionPatterns(patterns);
+  }
+};
+
+} // namespace
+
+/// Attaches all dialect interfaces defined in this file to the test dialect.
+/// The blob manager is registered first because TestOpAsmInterface is
+/// constructed with a reference to it.
+void TestDialect::registerInterfaces() {
+  TestResourceBlobManagerInterface &blobManager =
+      addInterface<TestResourceBlobManagerInterface>();
+  addInterface<TestOpAsmInterface>(blobManager);
+
+  addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
+                TestReductionPatternInterface, TestBytecodeDialectInterface>();
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 4eb19e6dd6fe27..9f897a6a30f541 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -863,72 +863,7 @@ def AttrSizedResultCompileTestOp : TEST_Op<"attr_sized_results_compile_test",
let results = (outs Variadic<I32>:$a, I32:$b, Optional<I32>:$c);
}
-// This is used to test that the fallback for a custom op's parser and printer
-// is the dialect parser and printer hooks.
-def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">;
-
-// Ops related to OIList primitive
-def OIListTrivial : TEST_Op<"oilist_with_keywords_only"> {
- let arguments = (ins UnitAttr:$keyword, UnitAttr:$otherKeyword,
- UnitAttr:$diffNameUnitAttrKeyword);
- let assemblyFormat = [{
- oilist( `keyword` $keyword
- | `otherKeyword` $otherKeyword
- | `thirdKeyword` $diffNameUnitAttrKeyword) attr-dict
- }];
-}
-
-def OIListSimple : TEST_Op<"oilist_with_simple_args", [AttrSizedOperandSegments]> {
- let arguments = (ins Optional<AnyType>:$arg0,
- Optional<AnyType>:$arg1,
- Optional<AnyType>:$arg2);
- let assemblyFormat = [{
- oilist( `keyword` $arg0 `:` type($arg0)
- | `otherKeyword` $arg1 `:` type($arg1)
- | `thirdKeyword` $arg2 `:` type($arg2) ) attr-dict
- }];
-}
-
-def OIListVariadic : TEST_Op<"oilist_variadic_with_parens", [AttrSizedOperandSegments]> {
- let arguments = (ins Variadic<AnyType>:$arg0,
- Variadic<AnyType>:$arg1,
- Variadic<AnyType>:$arg2);
- let assemblyFormat = [{
- oilist( `keyword` `(` $arg0 `:` type($arg0) `)`
- | `otherKeyword` `(` $arg1 `:` type($arg1) `)`
- | `thirdKeyword` `(` $arg2 `:` type($arg2) `)`) attr-dict
- }];
-}
-
-def OIListCustom : TEST_Op<"oilist_custom", [AttrSizedOperandSegments]> {
- let arguments = (ins Variadic<AnyType>:$arg0,
- Optional<I32>:$optOperand,
- UnitAttr:$nowait);
- let assemblyFormat = [{
- oilist( `private` `(` $arg0 `:` type($arg0) `)`
- | `reduction` custom<CustomOptionalOperand>($optOperand)
- | `nowait` $nowait
- ) attr-dict
- }];
-}
-
-def OIListAllowedLiteral : TEST_Op<"oilist_allowed_literal"> {
- let assemblyFormat = [{
- oilist( `foo` | `bar` ) `buzz` attr-dict
- }];
-}
-def TestEllipsisOp : TEST_Op<"ellipsis"> {
- let arguments = (ins Variadic<AnyType>:$operands, UnitAttr:$variadic);
- let assemblyFormat = [{
- `(` $operands (`...` $variadic^)? `)` attr-dict `:` type($operands) `...`
- }];
-}
-
-def ElseAnchorOp : TEST_Op<"else_anchor"> {
- let arguments = (ins Optional<AnyType>:$a);
- let assemblyFormat = "`(` (`?`) : (`` $a^ `:` type($a))? `)` attr-dict";
-}
// This is used to test encoding of a string attribute into an SSA name of a
// pretty printed value name.
@@ -967,11 +902,6 @@ def DefaultDialectOp : TEST_Op<"default_dialect", [OpAsmOpInterface]> {
let assemblyFormat = "regions attr-dict-with-keyword";
}
-// This is used to test that the default dialect is not elided when printing an
-// op with dots in the name to avoid parsing ambiguity.
-def OpWithDotInNameOp : TEST_Op<"op.with_dot_in_name"> {
- let assemblyFormat = "attr-dict";
-}
// This is used to test the OpAsmOpInterface::getAsmBlockName() feature:
// blocks nested in a region under this op will have a name defined by the
@@ -1997,56 +1927,6 @@ def AffineScopeOp : TEST_Op<"affine_scope", [AffineScope]> {
let hasCustomAssemblyFormat = 1;
}
-def WrappingRegionOp : TEST_Op<"wrapping_region",
- [SingleBlockImplicitTerminator<"TestReturnOp">]> {
- let summary = "wrapping region operation";
- let description = [{
- Test op wrapping another op in a region, to test calling
- parseGenericOperation from the custom parser.
- }];
-
- let results = (outs Variadic<AnyType>);
- let regions = (region SizedRegion<1>:$region);
- let hasCustomAssemblyFormat = 1;
-}
-
-def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region",
- [SingleBlockImplicitTerminator<"TestReturnOp">]> {
- let summary = "pretty_printed_region operation";
- let description = [{
- Test-op can be printed either in a "pretty" or "non-pretty" way based on
- some criteria. The custom parser parsers both the versions while testing
- APIs: parseCustomOperationName & parseGenericOperationAfterOpName.
- }];
- let arguments = (ins
- AnyType:$input1,
- AnyType:$input2
- );
-
- let results = (outs AnyType);
- let regions = (region SizedRegion<1>:$region);
- let hasCustomAssemblyFormat = 1;
-}
-
-def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]> {
- let summary = "polyfor operation";
- let description = [{
- Test op with multiple region arguments, each argument of index type.
- }];
- let extraClassDeclaration = [{
- void getAsmBlockArgumentNames(mlir::Region &region,
- mlir::OpAsmSetValueNameFn setNameFn);
- }];
- let regions = (region SizedRegion<1>:$region);
- let hasCustomAssemblyFormat = 1;
-}
-
-def TestAttrWithLoc : TEST_Op<"attr_with_loc"> {
- let summary = "op's attribute has a location";
- let arguments = (ins AnyAttr:$loc, AnyAttr:$value);
- let assemblyFormat = "`(` $value `` custom<OptionalLoc>($loc) `)` attr-dict";
-}
-
//===----------------------------------------------------------------------===//
// Test OpAsmInterface.
@@ -2059,598 +1939,6 @@ def AsmDialectInterfaceOp : TEST_Op<"asm_dialect_interface_op"> {
let results = (outs AnyType);
}
-//===----------------------------------------------------------------------===//
-// Test Op Asm Format
-//===----------------------------------------------------------------------===//
-
-def FormatLiteralOp : TEST_Op<"format_literal_op"> {
- let assemblyFormat = [{
- `keyword_$.` `->` `:` `,` `=` `<` `>` `(` `)` `[` `]` `` `(` ` ` `)`
- `?` `+` `*` `{` `\n` `}` attr-dict
- }];
-}
-
-// Test that we elide attributes that are within the syntax.
-def FormatAttrOp : TEST_Op<"format_attr_op"> {
- let arguments = (ins I64Attr:$attr);
- let assemblyFormat = "$attr attr-dict";
-}
-
-// Test that we elide optional attributes that are within the syntax.
-def FormatOptAttrAOp : TEST_Op<"format_opt_attr_op_a"> {
- let arguments = (ins OptionalAttr<I64Attr>:$opt_attr);
- let assemblyFormat = "(`(` $opt_attr^ `)` )? attr-dict";
-}
-def FormatOptAttrBOp : TEST_Op<"format_opt_attr_op_b"> {
- let arguments = (ins OptionalAttr<I64Attr>:$opt_attr);
- let assemblyFormat = "($opt_attr^)? attr-dict";
-}
-
-// Test that we format symbol name attributes properly.
-def FormatSymbolNameAttrOp : TEST_Op<"format_symbol_name_attr_op"> {
- let arguments = (ins SymbolNameAttr:$attr);
- let assemblyFormat = "$attr attr-dict";
-}
-
-// Test that we format optional symbol name attributes properly.
-def FormatOptSymbolNameAttrOp : TEST_Op<"format_opt_symbol_name_attr_op"> {
- let arguments = (ins OptionalAttr<SymbolNameAttr>:$opt_attr);
- let assemblyFormat = "($opt_attr^)? attr-dict";
-}
-
-// Test that we format optional symbol reference attributes properly.
-def FormatOptSymbolRefAttrOp : TEST_Op<"format_opt_symbol_ref_attr_op"> {
- let arguments = (ins OptionalAttr<SymbolRefAttr>:$opt_attr);
- let assemblyFormat = "($opt_attr^)? attr-dict";
-}
-
-// Test that we elide attributes that are within the syntax.
-def FormatAttrDictWithKeywordOp : TEST_Op<"format_attr_dict_w_keyword"> {
- let arguments = (ins I64Attr:$attr, OptionalAttr<I64Attr>:$opt_attr);
- let assemblyFormat = "attr-dict-with-keyword";
-}
-
-// Test that we don't need to provide types in the format if they are buildable.
-def FormatBuildableTypeOp : TEST_Op<"format_buildable_type_op"> {
- let arguments = (ins I64:$buildable);
- let results = (outs I64:$buildable_res);
- let assemblyFormat = "$buildable attr-dict";
-}
-
-// Test various mixings of region formatting.
-class FormatRegionBase<string suffix, string fmt>
- : TEST_Op<"format_region_" # suffix # "_op"> {
- let regions = (region AnyRegion:$region);
- let assemblyFormat = fmt;
-}
-def FormatRegionAOp : FormatRegionBase<"a", [{
- regions attr-dict
-}]>;
-def FormatRegionBOp : FormatRegionBase<"b", [{
- $region attr-dict
-}]>;
-def FormatRegionCOp : FormatRegionBase<"c", [{
- (`region` $region^)? attr-dict
-}]>;
-class FormatVariadicRegionBase<string suffix, string fmt>
- : TEST_Op<"format_variadic_region_" # suffix # "_op"> {
- let regions = (region VariadicRegion<AnyRegion>:$regions);
- let assemblyFormat = fmt;
-}
-def FormatVariadicRegionAOp : FormatVariadicRegionBase<"a", [{
- $regions attr-dict
-}]>;
-def FormatVariadicRegionBOp : FormatVariadicRegionBase<"b", [{
- ($regions^ `found_regions`)? attr-dict
-}]>;
-class FormatRegionImplicitTerminatorBase<string suffix, string fmt>
- : TEST_Op<"format_implicit_terminator_region_" # suffix # "_op",
- [SingleBlockImplicitTerminator<"TestReturnOp">]> {
- let regions = (region AnyRegion:$region);
- let assemblyFormat = fmt;
-}
-def FormatFormatRegionImplicitTerminatorAOp
- : FormatRegionImplicitTerminatorBase<"a", [{
- $region attr-dict
-}]>;
-
-// Test various mixings of result type formatting.
-class FormatResultBase<string suffix, string fmt>
- : TEST_Op<"format_result_" # suffix # "_op"> {
- let results = (outs I64:$buildable_res, AnyMemRef:$result);
- let assemblyFormat = fmt;
-}
-def FormatResultAOp : FormatResultBase<"a", [{
- type($result) attr-dict
-}]>;
-def FormatResultBOp : FormatResultBase<"b", [{
- type(results) attr-dict
-}]>;
-def FormatResultCOp : FormatResultBase<"c", [{
- functional-type($buildable_res, $result) attr-dict
-}]>;
-
-def FormatVariadicResult : TEST_Op<"format_variadic_result"> {
- let results = (outs Variadic<I64>:$result);
- let assemblyFormat = [{ `:` type($result) attr-dict}];
-}
-
-def FormatMultipleVariadicResults : TEST_Op<"format_multiple_variadic_results",
- [AttrSizedResultSegments]> {
- let results = (outs Variadic<I64>:$result0, Variadic<AnyType>:$result1);
- let assemblyFormat = [{
- `:` `(` type($result0) `)` `,` `(` type($result1) `)` attr-dict
- }];
-}
-
-// Test various mixings of operand type formatting.
-class FormatOperandBase<string suffix, string fmt>
- : TEST_Op<"format_operand_" # suffix # "_op"> {
- let arguments = (ins I64:$buildable, AnyMemRef:$operand);
- let assemblyFormat = fmt;
-}
-
-def FormatOperandAOp : FormatOperandBase<"a", [{
- operands `:` type(operands) attr-dict
-}]>;
-def FormatOperandBOp : FormatOperandBase<"b", [{
- operands `:` type($operand) attr-dict
-}]>;
-def FormatOperandCOp : FormatOperandBase<"c", [{
- $buildable `,` $operand `:` type(operands) attr-dict
-}]>;
-def FormatOperandDOp : FormatOperandBase<"d", [{
- $buildable `,` $operand `:` type($operand) attr-dict
-}]>;
-def FormatOperandEOp : FormatOperandBase<"e", [{
- $buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict
-}]>;
-
-def FormatSuccessorAOp : TEST_Op<"format_successor_a_op", [Terminator]> {
- let successors = (successor VariadicSuccessor<AnySuccessor>:$targets);
- let assemblyFormat = "$targets attr-dict";
-}
-
-def FormatVariadicOperand : TEST_Op<"format_variadic_operand"> {
- let arguments = (ins Variadic<I64>:$operand);
- let assemblyFormat = [{ $operand `:` type($operand) attr-dict}];
-}
-def FormatVariadicOfVariadicOperand
- : TEST_Op<"format_variadic_of_variadic_operand"> {
- let arguments = (ins
- VariadicOfVariadic<I64, "operand_segments">:$operand,
- DenseI32ArrayAttr:$operand_segments
- );
- let assemblyFormat = [{ $operand `:` type($operand) attr-dict}];
-}
-
-def FormatMultipleVariadicOperands :
- TEST_Op<"format_multiple_variadic_operands", [AttrSizedOperandSegments]> {
- let arguments = (ins Variadic<I64>:$operand0, Variadic<AnyType>:$operand1);
- let assemblyFormat = [{
- ` ` `(` $operand0 `)` `,` `(` $operand1 `:` type($operand1) `)` attr-dict
- }];
-}
-
-// Test various mixings of optional operand and result type formatting.
-class FormatOptionalOperandResultOpBase<string suffix, string fmt>
- : TEST_Op<"format_optional_operand_result_" # suffix # "_op",
- [AttrSizedOperandSegments]> {
- let arguments = (ins Optional<I64>:$optional, Variadic<I64>:$variadic);
- let results = (outs Optional<I64>:$optional_res);
- let assemblyFormat = fmt;
-}
-
-def FormatOptionalOperandResultAOp : FormatOptionalOperandResultOpBase<"a", [{
- `(` $optional `:` type($optional) `)` `:` type($optional_res)
- (`[` $variadic^ `]`)? attr-dict
-}]>;
-
-def FormatOptionalOperandResultBOp : FormatOptionalOperandResultOpBase<"b", [{
- (`(` $optional^ `:` type($optional) `)`)? `:` type($optional_res)
- (`[` $variadic^ `]`)? attr-dict
-}]>;
-
-// Test optional result type formatting.
-class FormatOptionalResultOpBase<string suffix, string fmt>
- : TEST_Op<"format_optional_result_" # suffix # "_op",
- [AttrSizedResultSegments]> {
- let results = (outs Optional<I64>:$optional, Variadic<I64>:$variadic);
- let assemblyFormat = fmt;
-}
-def FormatOptionalResultAOp : FormatOptionalResultOpBase<"a", [{
- (`:` type($optional)^ `->` type($variadic))? attr-dict
-}]>;
-
-def FormatOptionalResultBOp : FormatOptionalResultOpBase<"b", [{
- (`:` type($optional) `->` type($variadic)^)? attr-dict
-}]>;
-
-def FormatOptionalResultCOp : FormatOptionalResultOpBase<"c", [{
- (`:` functional-type($optional, $variadic)^)? attr-dict
-}]>;
-
-def FormatOptionalResultDOp
- : TEST_Op<"format_optional_result_d_op" > {
- let results = (outs Optional<F80>:$optional);
- let assemblyFormat = "(`:` type($optional)^)? attr-dict";
-}
-
-def FormatTwoVariadicOperandsNoBuildableTypeOp
- : TEST_Op<"format_two_variadic_operands_no_buildable_type_op",
- [AttrSizedOperandSegments]> {
- let arguments = (ins Variadic<AnyType>:$a,
- Variadic<AnyType>:$b);
- let assemblyFormat = [{
- `(` $a `:` type($a) `)` `->` `(` $b `:` type($b) `)` attr-dict
- }];
-}
-
-def FormatInferVariadicTypeFromNonVariadic
- : TEST_Op<"format_infer_variadic_type_from_non_variadic",
- [SameOperandsAndResultType]> {
- let arguments = (ins Variadic<AnyType>:$args);
- let results = (outs AnyType:$result);
- let assemblyFormat = "operands attr-dict `:` type($result)";
-}
-
-def FormatOptionalUnitAttr : TEST_Op<"format_optional_unit_attribute"> {
- let arguments = (ins UnitAttr:$is_optional);
- let assemblyFormat = "(`is_optional` $is_optional^)? attr-dict";
-}
-
-def FormatOptionalUnitAttrNoElide
- : TEST_Op<"format_optional_unit_attribute_no_elide"> {
- let arguments = (ins UnitAttr:$is_optional);
- let assemblyFormat = "($is_optional^)? attr-dict";
-}
-
-def FormatOptionalEnumAttr : TEST_Op<"format_optional_enum_attr"> {
- let arguments = (ins OptionalAttr<SomeI64Enum>:$attr);
- let assemblyFormat = "($attr^)? attr-dict";
-}
-
-def FormatOptionalDefaultAttrs : TEST_Op<"format_optional_default_attrs"> {
- let arguments = (ins DefaultValuedStrAttr<StrAttr, "default">:$str,
- DefaultValuedStrAttr<SymbolNameAttr, "default">:$sym,
- DefaultValuedAttr<SomeI64Enum, "SomeI64Enum::case5">:$e);
- let assemblyFormat = "($str^)? ($sym^)? ($e^)? attr-dict";
-}
-
-def FormatOptionalWithElse : TEST_Op<"format_optional_else"> {
- let arguments = (ins UnitAttr:$isFirstBranchPresent);
- let assemblyFormat = "(`then` $isFirstBranchPresent^):(`else`)? attr-dict";
-}
-
-def FormatCompoundAttr : TEST_Op<"format_compound_attr"> {
- let arguments = (ins CompoundAttrA:$compound);
- let assemblyFormat = "$compound attr-dict-with-keyword";
-}
-
-def FormatNestedAttr : TEST_Op<"format_nested_attr"> {
- let arguments = (ins CompoundAttrNested:$nested);
- let assemblyFormat = "$nested attr-dict-with-keyword";
-}
-
-def FormatNestedCompoundAttr : TEST_Op<"format_cpmd_nested_attr"> {
- let arguments = (ins CompoundNestedOuter:$nested);
- let assemblyFormat = "`nested` $nested attr-dict-with-keyword";
-}
-
-def FormatMaybeEmptyType : TEST_Op<"format_maybe_empty_type"> {
- let arguments = (ins TestTypeOptionalValueType:$in);
- let assemblyFormat = "$in `:` type($in) attr-dict";
-}
-
-def FormatQualifiedCompoundAttr : TEST_Op<"format_qual_cpmd_nested_attr"> {
- let arguments = (ins CompoundNestedOuter:$nested);
- let assemblyFormat = "`nested` qualified($nested) attr-dict-with-keyword";
-}
-
-def FormatNestedType : TEST_Op<"format_cpmd_nested_type"> {
- let arguments = (ins CompoundNestedOuterType:$nested);
- let assemblyFormat = "$nested `nested` type($nested) attr-dict-with-keyword";
-}
-
-def FormatQualifiedNestedType : TEST_Op<"format_qual_cpmd_nested_type"> {
- let arguments = (ins CompoundNestedOuterType:$nested);
- let assemblyFormat = "$nested `nested` qualified(type($nested)) attr-dict-with-keyword";
-}
-
-//===----------------------------------------------------------------------===//
-// Custom Directives
-
-def FormatCustomDirectiveOperands
- : TEST_Op<"format_custom_directive_operands", [AttrSizedOperandSegments]> {
- let arguments = (ins I64:$operand, Optional<I64>:$optOperand,
- Variadic<I64>:$varOperands);
- let assemblyFormat = [{
- custom<CustomDirectiveOperands>(
- $operand, $optOperand, $varOperands
- )
- attr-dict
- }];
-}
-
-def FormatCustomDirectiveOperandsAndTypes
- : TEST_Op<"format_custom_directive_operands_and_types",
- [AttrSizedOperandSegments]> {
- let arguments = (ins AnyType:$operand, Optional<AnyType>:$optOperand,
- Variadic<AnyType>:$varOperands);
- let assemblyFormat = [{
- custom<CustomDirectiveOperandsAndTypes>(
- $operand, $optOperand, $varOperands,
- type($operand), type($optOperand), type($varOperands)
- )
- attr-dict
- }];
-}
-
-def FormatCustomDirectiveRegions : TEST_Op<"format_custom_directive_regions"> {
- let regions = (region AnyRegion:$region, VariadicRegion<AnyRegion>:$other_regions);
- let assemblyFormat = [{
- custom<CustomDirectiveRegions>(
- $region, $other_regions
- )
- attr-dict
- }];
-}
-
-def FormatCustomDirectiveResults
- : TEST_Op<"format_custom_directive_results", [AttrSizedResultSegments]> {
- let results = (outs AnyType:$result, Optional<AnyType>:$optResult,
- Variadic<AnyType>:$varResults);
- let assemblyFormat = [{
- custom<CustomDirectiveResults>(
- type($result), type($optResult), type($varResults)
- )
- attr-dict
- }];
-}
-
-def FormatCustomDirectiveResultsWithTypeRefs
- : TEST_Op<"format_custom_directive_results_with_type_refs",
- [AttrSizedResultSegments]> {
- let results = (outs AnyType:$result, Optional<AnyType>:$optResult,
- Variadic<AnyType>:$varResults);
- let assemblyFormat = [{
- custom<CustomDirectiveResults>(
- type($result), type($optResult), type($varResults)
- )
- custom<CustomDirectiveWithTypeRefs>(
- ref(type($result)), ref(type($optResult)), ref(type($varResults))
- )
- attr-dict
- }];
-}
-
-def FormatCustomDirectiveWithOptionalOperandRef
- : TEST_Op<"format_custom_directive_with_optional_operand_ref"> {
- let arguments = (ins Optional<I64>:$optOperand);
- let assemblyFormat = [{
- ($optOperand^)? `:`
- custom<CustomDirectiveOptionalOperandRef>(ref($optOperand))
- attr-dict
- }];
-}
-
-def FormatCustomDirectiveSuccessors
- : TEST_Op<"format_custom_directive_successors", [Terminator]> {
- let successors = (successor AnySuccessor:$successor,
- VariadicSuccessor<AnySuccessor>:$successors);
- let assemblyFormat = [{
- custom<CustomDirectiveSuccessors>(
- $successor, $successors
- )
- attr-dict
- }];
-}
-
-def FormatCustomDirectiveAttributes
- : TEST_Op<"format_custom_directive_attributes"> {
- let arguments = (ins I64Attr:$attr, OptionalAttr<I64Attr>:$optAttr);
- let assemblyFormat = [{
- custom<CustomDirectiveAttributes>(
- $attr, $optAttr
- )
- attr-dict
- }];
-}
-
-def FormatCustomDirectiveSpacing
- : TEST_Op<"format_custom_directive_spacing"> {
- let arguments = (ins StrAttr:$attr1, StrAttr:$attr2);
- let assemblyFormat = [{
- custom<CustomDirectiveSpacing>($attr1)
- custom<CustomDirectiveSpacing>($attr2)
- attr-dict
- }];
-}
-
-def FormatCustomDirectiveAttrDict
- : TEST_Op<"format_custom_directive_attrdict"> {
- let arguments = (ins I64Attr:$attr, OptionalAttr<I64Attr>:$optAttr);
- let assemblyFormat = [{
- custom<CustomDirectiveAttrDict>( attr-dict )
- }];
-}
-
-def FormatLiteralFollowingOptionalGroup
- : TEST_Op<"format_literal_following_optional_group"> {
- let arguments = (ins TypeAttr:$type, OptionalAttr<AnyAttr>:$value);
- let assemblyFormat = "(`(` $value^ `)`)? `:` $type attr-dict";
-}
-
-//===----------------------------------------------------------------------===//
-// AllTypesMatch type inference
-
-def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [
- AllTypesMatch<["value1", "value2", "result"]>
- ]> {
- let arguments = (ins AnyType:$value1, AnyType:$value2);
- let results = (outs AnyType:$result);
- let assemblyFormat = "attr-dict $value1 `,` $value2 `:` type($value1)";
-}
-
-def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [
- AllTypesMatch<["value1", "value2", "result"]>
- ]> {
- let arguments = (ins TypedAttrInterface:$value1, AnyType:$value2);
- let results = (outs AnyType:$result);
- let assemblyFormat = "attr-dict $value1 `,` $value2";
-}
-
-//===----------------------------------------------------------------------===//
-// TypesMatchWith type inference
-
-def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [
- TypesMatchWith<"result type matches operand", "value", "result", "$_self">
- ]> {
- let arguments = (ins AnyType:$value);
- let results = (outs AnyType:$result);
- let assemblyFormat = "attr-dict $value `:` type($value)";
-}
-
-def FormatTypesMatchVariadicOp : TEST_Op<"format_types_match_variadic", [
- RangedTypesMatchWith<"result type matches operand", "value", "result",
- "llvm::make_range($_self.begin(), $_self.end())">
- ]> {
- let arguments = (ins Variadic<AnyType>:$value);
- let results = (outs Variadic<AnyType>:$result);
- let assemblyFormat = "attr-dict $value `:` type($value)";
-}
-
-def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [
- TypesMatchWith<"result type matches constant", "value", "result", "$_self">
- ]> {
- let arguments = (ins TypedAttrInterface:$value);
- let results = (outs AnyType:$result);
- let assemblyFormat = "attr-dict $value";
-}
-
-def FormatTypesMatchContextOp : TEST_Op<"format_types_match_context", [
- TypesMatchWith<"tuple result type matches operand type", "value", "result",
- "::mlir::TupleType::get($_ctxt, $_self)">
- ]> {
- let arguments = (ins AnyType:$value);
- let results = (outs AnyType:$result);
- let assemblyFormat = "attr-dict $value `:` type($value)";
-}
-
-//===----------------------------------------------------------------------===//
-// InferTypeOpInterface type inference in assembly format
-
-def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> {
- let results = (outs AnyType);
- let assemblyFormat = "attr-dict";
-
- let extraClassDeclaration = [{
- static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
- ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
- ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
- ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
- inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
- return ::mlir::success();
- }
- }];
-}
-
-// Check that formatget supports DeclareOpInterfaceMethods.
-def FormatInferType2Op : TEST_Op<"format_infer_type2", [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
- let results = (outs AnyType);
- let assemblyFormat = "attr-dict";
-}
-
-// Base class for testing mixing allOperandTypes, allOperands, and
-// inferResultTypes.
-class FormatInferAllTypesBaseOp<string mnemonic, list<Trait> traits = []>
- : TEST_Op<mnemonic, [InferTypeOpInterface] # traits> {
- let arguments = (ins Variadic<AnyType>:$args);
- let results = (outs Variadic<AnyType>:$outs);
- let extraClassDeclaration = [{
- static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
- ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
- ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
- ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
- ::mlir::TypeRange operandTypes = operands.getTypes();
- inferredReturnTypes.assign(operandTypes.begin(), operandTypes.end());
- return ::mlir::success();
- }
- }];
-}
-
-// Test inferReturnTypes is called when allOperandTypes and allOperands is true.
-def FormatInferTypeAllOperandsAndTypesOp
- : FormatInferAllTypesBaseOp<"format_infer_type_all_operands_and_types"> {
- let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)";
-}
-
-// Test inferReturnTypes is called when allOperandTypes is true and there is one
-// ODS operand.
-def FormatInferTypeAllOperandsAndTypesOneOperandOp
- : FormatInferAllTypesBaseOp<"format_infer_type_all_types_one_operand"> {
- let assemblyFormat = "`(` $args `)` attr-dict `:` type(operands)";
-}
-
-// Test inferReturnTypes is called when allOperandTypes is true and there are
-// more than one ODS operands.
-def FormatInferTypeAllOperandsAndTypesTwoOperandsOp
- : FormatInferAllTypesBaseOp<"format_infer_type_all_types_two_operands",
- [SameVariadicOperandSize]> {
- let arguments = (ins Variadic<AnyType>:$args0, Variadic<AnyType>:$args1);
- let assemblyFormat = "`(` $args0 `)` `(` $args1 `)` attr-dict `:` type(operands)";
-}
-
-// Test inferReturnTypes is called when allOperands is true and operand types
-// are separately specified.
-def FormatInferTypeAllTypesOp
- : FormatInferAllTypesBaseOp<"format_infer_type_all_types"> {
- let assemblyFormat = "`(` operands `)` attr-dict `:` type($args)";
-}
-
-// Test inferReturnTypes coupled with regions.
-def FormatInferTypeRegionsOp
- : TEST_Op<"format_infer_type_regions", [InferTypeOpInterface]> {
- let results = (outs Variadic<AnyType>:$outs);
- let regions = (region AnyRegion:$region);
- let assemblyFormat = "$region attr-dict";
- let extraClassDeclaration = [{
- static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
- ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
- ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
- ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
- if (regions.empty())
- return ::mlir::failure();
- auto types = regions.front()->getArgumentTypes();
- inferredReturnTypes.assign(types.begin(), types.end());
- return ::mlir::success();
- }
- }];
-}
-
-// Test inferReturnTypes coupled with variadic operands (operand_segment_sizes).
-def FormatInferTypeVariadicOperandsOp
- : TEST_Op<"format_infer_type_variadic_operands",
- [InferTypeOpInterface, AttrSizedOperandSegments]> {
- let arguments = (ins Variadic<I32>:$a, Variadic<I64>:$b);
- let results = (outs Variadic<AnyType>:$outs);
- let assemblyFormat = "`(` $a `:` type($a) `)` `(` $b `:` type($b) `)` attr-dict";
- let extraClassDeclaration = [{
- static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
- ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
- ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
- ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
- FormatInferTypeVariadicOperandsOpAdaptor adaptor(
- operands, attributes, *properties.as<Properties *>(), {});
- auto aTypes = adaptor.getA().getTypes();
- auto bTypes = adaptor.getB().getTypes();
- inferredReturnTypes.append(aTypes.begin(), aTypes.end());
- inferredReturnTypes.append(bTypes.begin(), bTypes.end());
- return ::mlir::success();
- }
- }];
-}
-
//===----------------------------------------------------------------------===//
// Test ArrayOfAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
new file mode 100644
index 00000000000000..84e6a43655cacd
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
@@ -0,0 +1,494 @@
+//===- TestOpsSyntax.cpp - Operations for testing syntax ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestOpsSyntax.h"
+#include "TestDialect.h"
+#include "mlir/IR/OpImplementation.h"
+#include "llvm/Support/Base64.h"
+
+using namespace mlir;
+using namespace test;
+
+//===----------------------------------------------------------------------===//
+// Test Format* operations
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Parsing
+
+/// Parses an optional `(%operand)` group. `optOperand` stays empty when no
+/// opening parenthesis is present.
+static ParseResult parseCustomOptionalOperand(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
+  // No parenthesis: the operand is simply absent.
+  if (failed(parser.parseOptionalLParen()))
+    return success();
+  OpAsmParser::UnresolvedOperand &operand = optOperand.emplace();
+  return failure(parser.parseOperand(operand) || parser.parseRParen());
+}
+
+/// Parses `%operand [, %optOperand] -> (%varOperands...)`.
+static ParseResult parseCustomDirectiveOperands(
+    OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
+  if (parser.parseOperand(operand))
+    return failure();
+
+  // A comma introduces the optional operand.
+  bool hasOptOperand = succeeded(parser.parseOptionalComma());
+  if (hasOptOperand && parser.parseOperand(optOperand.emplace()))
+    return failure();
+
+  // The variadic operands are wrapped in `-> (...)`.
+  return failure(parser.parseArrow() || parser.parseLParen() ||
+                 parser.parseOperandList(varOperands) ||
+                 parser.parseRParen());
+}
+/// Parses ` : type [, optType] -> (types...)`.
+static ParseResult
+parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
+                            Type &optOperandType,
+                            SmallVectorImpl<Type> &varOperandTypes) {
+  if (parser.parseColon() || parser.parseType(operandType))
+    return failure();
+
+  // The optional type is introduced by a comma.
+  if (succeeded(parser.parseOptionalComma()) &&
+      parser.parseType(optOperandType))
+    return failure();
+
+  return failure(parser.parseArrow() || parser.parseLParen() ||
+                 parser.parseTypeList(varOperandTypes) ||
+                 parser.parseRParen());
+}
+/// Parses `type_refs_capture` followed by a result-type group, and succeeds
+/// only if the re-parsed types equal the ones captured through `ref(...)`.
+static ParseResult
+parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
+                                 Type optOperandType,
+                                 const SmallVectorImpl<Type> &varOperandTypes) {
+  Type reparsedType, reparsedOptType;
+  SmallVector<Type, 1> reparsedVarTypes;
+  if (parser.parseKeyword("type_refs_capture") ||
+      parseCustomDirectiveResults(parser, reparsedType, reparsedOptType,
+                                  reparsedVarTypes))
+    return failure();
+
+  // Every re-parsed type must match its referenced counterpart.
+  bool allMatch = operandType == reparsedType &&
+                  optOperandType == reparsedOptType &&
+                  varOperandTypes == reparsedVarTypes;
+  return success(allMatch);
+}
+/// Parses the operand group and then the matching type group.
+static ParseResult parseCustomDirectiveOperandsAndTypes(
+    OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
+    Type &operandType, Type &optOperandType,
+    SmallVectorImpl<Type> &varOperandTypes) {
+  if (failed(parseCustomDirectiveOperands(parser, operand, optOperand,
+                                          varOperands)))
+    return failure();
+  return parseCustomDirectiveResults(parser, operandType, optOperandType,
+                                     varOperandTypes);
+}
+/// Parses `region [, region]*`. The first region fills `region`; every
+/// additional comma-separated region is appended to `varRegions`.
+static ParseResult parseCustomDirectiveRegions(
+    OpAsmParser &parser, Region &region,
+    SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
+  if (parser.parseRegion(region))
+    return failure();
+  // Accept a comma-separated list of variadic regions rather than at most
+  // one, so ops carrying several `other_regions` can be parsed.
+  while (succeeded(parser.parseOptionalComma())) {
+    auto varRegion = std::make_unique<Region>();
+    if (parser.parseRegion(*varRegion))
+      return failure();
+    varRegions.emplace_back(std::move(varRegion));
+  }
+  return success();
+}
+/// Parses `^successor [, ^varSuccessor]`. When the variadic successor is
+/// present it is appended to `varSuccessors` twice.
+/// NOTE(review): printCustomDirectiveSuccessors prints only the first
+/// variadic entry, so the duplication looks deliberate for the test — confirm.
+static ParseResult
+parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
+                               SmallVectorImpl<Block *> &varSuccessors) {
+  if (parser.parseSuccessor(successor))
+    return failure();
+  // Without a comma there are no variadic successors.
+  if (succeeded(parser.parseOptionalComma())) {
+    Block *varSuccessor = nullptr;
+    if (parser.parseSuccessor(varSuccessor))
+      return failure();
+    varSuccessors.push_back(varSuccessor);
+    varSuccessors.push_back(varSuccessor);
+  }
+  return success();
+}
+/// Parses `attr [, optAttr]` into two integer attributes.
+static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
+                                                  IntegerAttr &attr,
+                                                  IntegerAttr &optAttr) {
+  if (parser.parseAttribute(attr))
+    return failure();
+  // A comma introduces the optional attribute.
+  bool hasOptAttr = succeeded(parser.parseOptionalComma());
+  return failure(hasOptAttr && parser.parseAttribute(optAttr));
+}
+/// Parses the single string attribute of one `CustomDirectiveSpacing` use.
+static ParseResult parseCustomDirectiveSpacing(OpAsmParser &parser,
+                                               mlir::StringAttr &attr) {
+  if (parser.parseAttribute(attr))
+    return failure();
+  return success();
+}
+/// Parses an optional attribute dictionary on behalf of the custom directive.
+static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
+                                                NamedAttrList &attrs) {
+  if (parser.parseOptionalAttrDict(attrs))
+    return failure();
+  return success();
+}
+/// Parses an integer that must be non-zero exactly when the referenced
+/// optional operand was parsed (the matching printer emits `1` or `0`).
+static ParseResult parseCustomDirectiveOptionalOperandRef(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
+  int64_t operandCount = 0;
+  if (parser.parseInteger(operandCount))
+    return failure();
+  // The printed count must agree with the operand's actual presence.
+  bool countSaysPresent = operandCount != 0;
+  return success(countSaysPresent == optOperand.has_value());
+}
+
+//===----------------------------------------------------------------------===//
+// Printing
+
+/// Prints `(%operand) ` when the optional operand is present; nothing
+/// otherwise.
+static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
+                                       Value optOperand) {
+  if (!optOperand)
+    return;
+  printer << "(" << optOperand << ") ";
+}
+
+/// Prints `%operand [, %optOperand] -> (%varOperands...)`.
+static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
+                                         Value operand, Value optOperand,
+                                         OperandRange varOperands) {
+  printer << operand;
+  if (optOperand) {
+    printer << ", ";
+    printer << optOperand;
+  }
+  printer << " -> (";
+  printer << varOperands;
+  printer << ")";
+}
+/// Prints ` : type [, optType] -> (types...)`.
+static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
+                                        Type operandType, Type optOperandType,
+                                        TypeRange varOperandTypes) {
+  printer << " : ";
+  printer << operandType;
+  if (optOperandType) {
+    printer << ", ";
+    printer << optOperandType;
+  }
+  printer << " -> (";
+  printer << varOperandTypes;
+  printer << ")";
+}
+/// Prints the `type_refs_capture` keyword, then the same type group that
+/// printCustomDirectiveResults emits.
+static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
+                                             Operation *op, Type operandType,
+                                             Type optOperandType,
+                                             TypeRange varOperandTypes) {
+  const char *keyword = " type_refs_capture ";
+  printer << keyword;
+  printCustomDirectiveResults(printer, op, operandType, optOperandType,
+                              varOperandTypes);
+}
+/// Prints the operand group followed by its type group; the inverse of
+/// parseCustomDirectiveOperandsAndTypes.
+static void printCustomDirectiveOperandsAndTypes(
+    OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
+    OperandRange varOperands, Type operandType, Type optOperandType,
+    TypeRange varOperandTypes) {
+  // Operands first ...
+  printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
+  // ... then their types.
+  printCustomDirectiveResults(printer, op, operandType, optOperandType,
+                              varOperandTypes);
+}
+/// Prints `region [, region]*`: the first region, then each variadic region
+/// preceded by a comma.
+static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
+                                        Region &region,
+                                        MutableArrayRef<Region> varRegions) {
+  printer.printRegion(region);
+  // Emit a separator before every variadic region (not only the first) so
+  // the output stays a comma-separated list; the loop variable is named
+  // distinctly to avoid shadowing `region`.
+  for (Region &varRegion : varRegions) {
+    printer << ", ";
+    printer.printRegion(varRegion);
+  }
+}
+/// Prints `^successor [, ^firstVarSuccessor]`; only the first variadic
+/// successor is printed.
+static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
+                                           Block *successor,
+                                           SuccessorRange varSuccessors) {
+  printer << successor;
+  if (varSuccessors.empty())
+    return;
+  printer << ", " << varSuccessors.front();
+}
+/// Prints `attr [, optAttr]`.
+static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
+                                           Attribute attribute,
+                                           Attribute optAttribute) {
+  printer << attribute;
+  if (!optAttribute)
+    return;
+  printer << ", " << optAttribute;
+}
+/// Prints a single attribute with no added punctuation or spacing.
+static void printCustomDirectiveSpacing(OpAsmPrinter &printer, Operation *op,
+                                        Attribute attribute) {
+  OpAsmPrinter &out = printer;
+  out << attribute;
+}
+/// Prints the given attribute dictionary through printOptionalAttrDict.
+static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
+                                         DictionaryAttr attrs) {
+  ArrayRef<NamedAttribute> namedAttrs = attrs.getValue();
+  printer.printOptionalAttrDict(namedAttrs);
+}
+
+/// Prints `1` when the referenced optional operand is present, `0` otherwise.
+static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
+                                                   Operation *op,
+                                                   Value optOperand) {
+  const char *count = optOperand ? "1" : "0";
+  printer << count;
+}
+//===----------------------------------------------------------------------===//
+// Test parser.
+//===----------------------------------------------------------------------===//
+
+/// Parses an optional `: N` suffix and gives the op N results of index type.
+/// Without the suffix the op has no results.
+ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+  if (failed(parser.parseOptionalColon()))
+    return success();
+  uint64_t numResults = 0;
+  if (parser.parseInteger(numResults))
+    return failure();
+
+  IndexType type = parser.getBuilder().getIndexType();
+  // The counter must be as wide as `numResults`: an `unsigned` index wraps
+  // before reaching a count above UINT_MAX and the loop never terminates.
+  for (uint64_t i = 0; i < numResults; ++i)
+    result.addTypes(type);
+  return success();
+}
+
+void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
+  // The zero-result form has no suffix at all.
+  unsigned numResults = getNumResults();
+  if (numResults == 0)
+    return;
+  p << " : " << numResults;
+}
+
+/// Parses a bare keyword and records it in the `keyword` string attribute.
+ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+  StringRef keyword;
+  if (failed(parser.parseKeyword(&keyword)))
+    return failure();
+  StringAttr keywordAttr = parser.getBuilder().getStringAttr(keyword);
+  result.addAttribute("keyword", keywordAttr);
+  return success();
+}
+
+void ParseWrappedKeywordOp::print(OpAsmPrinter &p) {
+  // A single space, then the stored keyword.
+  p << " ";
+  p << getKeyword();
+}
+
+/// Parses a base64-encoded string literal and stores the decoded bytes in
+/// the `b64` string attribute.
+ParseResult ParseB64BytesOp::parse(OpAsmParser &parser,
+                                   OperationState &result) {
+  std::vector<char> bytes;
+  if (parser.parseBase64Bytes(&bytes))
+    return failure();
+  // Use data() rather than &front(): front() on an empty vector is undefined
+  // behavior, and the decoded payload may legitimately be empty.
+  result.addAttribute("b64", parser.getBuilder().getStringAttr(
+                                 StringRef(bytes.data(), bytes.size())));
+  return success();
+}
+
+void ParseB64BytesOp::print(OpAsmPrinter &p) {
+  // Re-encode the stored bytes and wrap them in double quotes.
+  std::string encoded = llvm::encodeBase64(getB64());
+  p << " \"" << encoded << "\"";
+}
+
+/// Always infers exactly one 16-bit integer result, regardless of operands,
+/// attributes, properties or regions.
+::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
+    ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
+    ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+    OpaqueProperties properties, ::mlir::RegionRange regions,
+    ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+  ::mlir::Type i16Type = ::mlir::IntegerType::get(context, 16);
+  inferredReturnTypes.assign(1, i16Type);
+  return ::mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
+
+ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  if (parser.parseKeyword("wraps"))
+    return failure();
+
+  // The wrapped op is parsed in generic form into a fresh block of the
+  // op's (only) region.
+  Region &body = *result.addRegion();
+  Block &block = body.emplaceBlock();
+  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
+  if (wrappedOp == nullptr)
+    return failure();
+
+  // Terminate the block with a TestReturnOp that forwards every result of
+  // the wrapped op, and give the wrapping op the same result types.
+  Location wrappedLoc = wrappedOp->getLoc();
+  SmallVector<Value, 8> forwarded(wrappedOp->getResults());
+  OpBuilder builder(parser.getContext());
+  builder.setInsertionPointToEnd(&block);
+  Operation *terminator =
+      builder.create<TestReturnOp>(wrappedLoc, forwarded).getOperation();
+  result.types.append(terminator->operand_type_begin(),
+                      terminator->operand_type_end());
+
+  // Report the "test.wrapping_region" op at the wrapped op's location.
+  result.location = wrappedLoc;
+  return success();
+}
+
+void WrappingRegionOp::print(OpAsmPrinter &p) {
+  // Only the first op of the body (the wrapped op) is printed, in generic
+  // form; the terminator is rebuilt by the parser.
+  Operation *wrapped = &getRegion().front().front();
+  p << " wraps ";
+  p.printGenericOp(wrapped);
+}
+
+//===----------------------------------------------------------------------===//
+// Test PrettyPrintedRegionOp - exercising the following parser APIs
+// parseGenericOperationAfterOpName
+// parseCustomOperationName
+//===----------------------------------------------------------------------===//
+
+/// Parses either the short `start <inner-op> end : <fn-type>` form or falls
+/// back to the generic form after the op name.
+ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+  SMLoc loc = parser.getCurrentLocation();
+  Location currLocation = parser.getEncodedSourceLoc(loc);
+
+  // Parse the operands.
+  SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
+  if (parser.parseOperandList(operands))
+    return failure();
+
+  // Check if we are parsing the pretty-printed version
+  //   test.pretty_printed_region start <inner-op> end : <functional-type>
+  // Else fallback to parsing the "non pretty-printed" version.
+  if (failed(parser.parseOptionalKeyword("start")))
+    return parser.parseGenericOperationAfterOpName(result,
+                                                   llvm::ArrayRef(operands));
+
+  FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
+  if (failed(parseOpNameInfo))
+    return failure();
+
+  StringAttr innerOpName = parseOpNameInfo->getIdentifier();
+
+  FunctionType opFntype;
+  std::optional<Location> explicitLoc;
+  if (parser.parseKeyword("end") || parser.parseColon() ||
+      parser.parseType(opFntype) ||
+      parser.parseOptionalLocationSpecifier(explicitLoc))
+    return failure();
+
+  // If location of the op is explicitly provided, then use it; else use the
+  // location recorded at the start of the op.
+  Location opLoc = explicitLoc.value_or(currLocation);
+
+  // Derive the SSA-values for op's operands.
+  if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
+                             result.operands))
+    return failure();
+
+  // The inner op's type is taken from the first input below; reject a
+  // functional type without inputs instead of indexing an empty list.
+  if (opFntype.getNumInputs() == 0)
+    return parser.emitError(loc, "expected at least one operand type");
+
+  // Add a region for op.
+  Region &region = *result.addRegion();
+
+  // Create a basic-block inside op's region.
+  Block &block = region.emplaceBlock();
+
+  // Create and insert an "inner-op" operation in the block.
+  // Just for testing purposes, we can assume that inner op is a binary op with
+  // result and operand types all same as the test-op's first operand.
+  Type innerOpType = opFntype.getInput(0);
+  Value lhs = block.addArgument(innerOpType, opLoc);
+  Value rhs = block.addArgument(innerOpType, opLoc);
+
+  OpBuilder builder(parser.getBuilder().getContext());
+  builder.setInsertionPointToStart(&block);
+
+  Operation *innerOp =
+      builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
+
+  // Insert a return statement in the block returning the inner-op's result.
+  builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
+
+  // Populate the op operation-state with result-type and location.
+  result.addTypes(opFntype.getResults());
+  result.location = innerOp->getLoc();
+
+  return success();
+}
+
+void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
+  p << ' ';
+  p.printOperands(getOperands());
+
+  // The region is assumed to hold a single non-terminator inner op. When it
+  // is "test.special.op", whose prototype the parser can re-derive, use the
+  // short `start ... end` form; otherwise print the whole region in
+  // parentheses.
+  Operation &innerOp = getRegion().front().front();
+  bool isSpecial = innerOp.getName().getStringRef() == "test.special.op";
+  if (!isSpecial) {
+    p << " (";
+    p.printRegion(getRegion());
+    p << ")";
+  } else {
+    p << " start test.special.op end";
+  }
+
+  p << " : ";
+  p.printFunctionalType(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// Test PolyForOp - parse list of region arguments.
+//===----------------------------------------------------------------------===//
+
+ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Induction variables are listed bare, without surrounding delimiters.
+  SmallVector<OpAsmParser::Argument, 4> ivs;
+  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::None))
+    return failure();
+
+  // Every induction variable has index type.
+  Type indexType = parser.getBuilder().getIndexType();
+  for (OpAsmParser::Argument &iv : ivs)
+    iv.type = indexType;
+
+  // The body region takes the induction variables as its entry arguments.
+  Region *body = result.addRegion();
+  return parser.parseRegion(*body, ivs);
+}
+
+void PolyForOp::print(OpAsmPrinter &p) {
+  Region &body = getRegion();
+  // Induction variables are printed comma-separated and without types; the
+  // entry block arguments are then omitted from the region itself.
+  p << " ";
+  llvm::interleaveComma(body.getArguments(), p, [&](BlockArgument arg) {
+    p.printRegionArgument(arg, /*argAttrs=*/{}, /*omitType=*/true);
+  });
+  p << " ";
+  p.printRegion(body, /*printEntryBlockArgs=*/false);
+}
+
+/// Names the entry block arguments from the optional `arg_names` array
+/// attribute. Only string entries are used; extra names or arguments are
+/// ignored.
+void PolyForOp::getAsmBlockArgumentNames(Region &region,
+                                         OpAsmSetValueNameFn setNameFn) {
+  auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
+  if (!arrayAttr)
+    return;
+  auto args = getRegion().front().getArguments();
+  // Index with size_t to match both container sizes.
+  size_t e = std::min(arrayAttr.size(), args.size());
+  for (size_t i = 0; i < e; ++i) {
+    if (auto strAttr = dyn_cast<StringAttr>(arrayAttr[i]))
+      setNameFn(args[i], strAttr.getValue());
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// TestAttrWithLoc - parse/printOptionalLocationSpecifier
+//===----------------------------------------------------------------------===//
+
+/// Parses an optional location specifier into `loc`. When it is absent, the
+/// location of the current parser position is stored instead.
+static ParseResult parseOptionalLoc(OpAsmParser &p, Attribute &loc) {
+  SMLoc sourceLoc = p.getCurrentLocation();
+  std::optional<Location> result;
+  if (failed(p.parseOptionalLocationSpecifier(result)))
+    return failure();
+  loc = result ? *result : p.getEncodedSourceLoc(sourceLoc);
+  return success();
+}
+
+/// Prints `loc` as an optional location specifier; `loc` must be a
+/// LocationAttr (the cast asserts otherwise).
+static void printOptionalLoc(OpAsmPrinter &p, Operation *op, Attribute loc) {
+  auto locAttr = cast<LocationAttr>(loc);
+  p.printOptionalLocationSpecifier(locAttr);
+}
+
+#define GET_OP_CLASSES
+#include "TestOpsSyntax.cpp.inc"
+
+/// Registers the operations generated from TestOpsSyntax.td with the test
+/// dialect. With GET_OP_LIST defined, the generated TestOpsSyntax.cpp.inc
+/// supplies the op class list used as the template arguments below.
+void TestDialect::registerOpsSyntax() {
+ addOperations<
+#define GET_OP_LIST
+#include "TestOpsSyntax.cpp.inc"
+ >();
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.h b/mlir/test/lib/Dialect/Test/TestOpsSyntax.h
new file mode 100644
index 00000000000000..7f32a3bae6c5f5
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.h
@@ -0,0 +1,26 @@
+//===- TestOpsSyntax.h - Operations for testing syntax ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TEST_DIALECT_TEST_TESTOPSSYNTAX_H
+#define MLIR_TEST_DIALECT_TEST_TESTOPSSYNTAX_H
+
+#include "TestAttributes.h"
+#include "TestTypes.h"
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+
+namespace test {
+// Forward declaration only: WrappingRegionOp and PrettyPrintedRegionOp use
+// SingleBlockImplicitTerminator<"TestReturnOp">, so the generated declarations
+// included below need this name; its full definition is not included here.
+class TestReturnOp;
+} // namespace test
+
+#define GET_OP_CLASSES
+#include "TestOpsSyntax.h.inc"
+
+#endif // MLIR_TEST_DIALECT_TEST_TESTOPSSYNTAX_H
diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td
new file mode 100644
index 00000000000000..6d9accd72493ac
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td
@@ -0,0 +1,741 @@
+
+//===-- TestOpsSyntax.td - Operations for testing syntax ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_OPS_SYNTAX
+#define TEST_OPS_SYNTAX
+
+include "TestAttrDefs.td"
+include "TestDialect.td"
+include "TestTypeDefs.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/IR/OpBase.td"
+
+// Base class for every op in this file: places the op in Test_Dialect under
+// the given mnemonic, with optional extra traits.
+class TEST_Op<string mnemonic, list<Trait> traits = []> :
+ Op<Test_Dialect, mnemonic, traits>;
+
+// Op with a single-block region implicitly terminated by TestReturnOp; its
+// parser and printer are hand-written in C++ (hasCustomAssemblyFormat).
+def WrappingRegionOp : TEST_Op<"wrapping_region",
+ [SingleBlockImplicitTerminator<"TestReturnOp">]> {
+ let summary = "wrapping region operation";
+ let description = [{
+ Test op wrapping another op in a region, to test calling
+ parseGenericOperation from the custom parser.
+ }];
+
+ let results = (outs Variadic<AnyType>);
+ let regions = (region SizedRegion<1>:$region);
+ let hasCustomAssemblyFormat = 1;
+}
+
+// Two operands of any type, one result, and a single-block region implicitly
+// terminated by TestReturnOp. Parser and printer are hand-written in C++
+// (hasCustomAssemblyFormat).
+def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region",
+ [SingleBlockImplicitTerminator<"TestReturnOp">]> {
+ let summary = "pretty_printed_region operation";
+ let description = [{
+ This test op can be printed either in a "pretty" or "non-pretty" way based
+ on some criteria. The custom parser parses both versions while testing the
+ APIs parseCustomOperationName and parseGenericOperationAfterOpName.
+ }];
+ let arguments = (ins
+ AnyType:$input1,
+ AnyType:$input2
+ );
+
+ let results = (outs AnyType);
+ let regions = (region SizedRegion<1>:$region);
+ let hasCustomAssemblyFormat = 1;
+}
+
+// Op with a single-block region. It declares the OpAsmOpInterface hook
+// getAsmBlockArgumentNames (implemented in C++) so region arguments get
+// custom names. Parser and printer are hand-written (hasCustomAssemblyFormat).
+def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]> {
+ let summary = "polyfor operation";
+ let description = [{
+ Test op with multiple region arguments, each argument of index type.
+ }];
+ let extraClassDeclaration = [{
+ void getAsmBlockArgumentNames(mlir::Region &region,
+ mlir::OpAsmSetValueNameFn setNameFn);
+ }];
+ let regions = (region SizedRegion<1>:$region);
+ let hasCustomAssemblyFormat = 1;
+}
+
+// The $loc attribute goes through the OptionalLoc custom directive
+// (parseOptionalLoc / printOptionalLoc in C++), glued directly after $value.
+def TestAttrWithLoc : TEST_Op<"attr_with_loc"> {
+ let summary = "op's attribute has a location";
+ let arguments = (ins AnyAttr:$loc, AnyAttr:$value);
+ let assemblyFormat = "`(` $value `` custom<OptionalLoc>($loc) `)` attr-dict";
+}
+
+// -----
+
+// This is used to test that the fallback for a custom op's parser and printer
+// is the dialect parser and printer hooks. The op deliberately declares
+// neither an assemblyFormat nor hasCustomAssemblyFormat.
+def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">;
+
+// Ops related to the oilist (order-independent list) format directive.
+//
+// oilist whose clauses are all bare keywords backed by unit attributes; the
+// third clause's keyword intentionally differs from its attribute name.
+def OIListTrivial : TEST_Op<"oilist_with_keywords_only"> {
+ let arguments = (ins UnitAttr:$keyword, UnitAttr:$otherKeyword,
+ UnitAttr:$diffNameUnitAttrKeyword);
+ let assemblyFormat = [{
+ oilist( `keyword` $keyword
+ | `otherKeyword` $otherKeyword
+ | `thirdKeyword` $diffNameUnitAttrKeyword) attr-dict
+ }];
+}
+
+// oilist over optional operands; each clause prints its operand and type.
+def OIListSimple : TEST_Op<"oilist_with_simple_args", [AttrSizedOperandSegments]> {
+ let arguments = (ins Optional<AnyType>:$arg0,
+ Optional<AnyType>:$arg1,
+ Optional<AnyType>:$arg2);
+ let assemblyFormat = [{
+ oilist( `keyword` $arg0 `:` type($arg0)
+ | `otherKeyword` $arg1 `:` type($arg1)
+ | `thirdKeyword` $arg2 `:` type($arg2) ) attr-dict
+ }];
+}
+
+// oilist over variadic operands; each clause wraps its operands in parens.
+def OIListVariadic : TEST_Op<"oilist_variadic_with_parens", [AttrSizedOperandSegments]> {
+ let arguments = (ins Variadic<AnyType>:$arg0,
+ Variadic<AnyType>:$arg1,
+ Variadic<AnyType>:$arg2);
+ let assemblyFormat = [{
+ oilist( `keyword` `(` $arg0 `:` type($arg0) `)`
+ | `otherKeyword` `(` $arg1 `:` type($arg1) `)`
+ | `thirdKeyword` `(` $arg2 `:` type($arg2) `)`) attr-dict
+ }];
+}
+
+// oilist mixing three clause kinds: a variadic operand list, a custom
+// directive, and a unit-attribute keyword.
+def OIListCustom : TEST_Op<"oilist_custom", [AttrSizedOperandSegments]> {
+ let arguments = (ins Variadic<AnyType>:$arg0,
+ Optional<I32>:$optOperand,
+ UnitAttr:$nowait);
+ let assemblyFormat = [{
+ oilist( `private` `(` $arg0 `:` type($arg0) `)`
+ | `reduction` custom<CustomOptionalOperand>($optOperand)
+ | `nowait` $nowait
+ ) attr-dict
+ }];
+}
+
+// oilist made only of literal clauses, followed by a literal outside the list.
+def OIListAllowedLiteral : TEST_Op<"oilist_allowed_literal"> {
+ let assemblyFormat = [{
+ oilist( `foo` | `bar` ) `buzz` attr-dict
+ }];
+}
+
+// Variadic operands followed by an optional `...` marker anchored on the
+// $variadic unit attribute.
+def TestEllipsisOp : TEST_Op<"ellipsis"> {
+ let arguments = (ins Variadic<AnyType>:$operands, UnitAttr:$variadic);
+ let assemblyFormat = [{
+ `(` $operands (`...` $variadic^)? `)` attr-dict `:` type($operands) `...`
+ }];
+}
+
+// Optional group with an else branch whose anchor ($a) sits in the else
+// group rather than the first group.
+def ElseAnchorOp : TEST_Op<"else_anchor"> {
+ let arguments = (ins Optional<AnyType>:$a);
+ let assemblyFormat = "`(` (`?`) : (`` $a^ `:` type($a))? `)` attr-dict";
+}
+
+// This is used to test that the default dialect is not elided when printing an
+// op with dots in the name to avoid parsing ambiguity.
+def OpWithDotInNameOp : TEST_Op<"op.with_dot_in_name"> {
+ let assemblyFormat = "attr-dict";
+}
+
+// --------------
+
+//===----------------------------------------------------------------------===//
+// Test Op Asm Format
+//===----------------------------------------------------------------------===//
+
+// Exercises a wide range of punctuation and keyword literals, including the
+// empty literal and the space / newline literals.
+def FormatLiteralOp : TEST_Op<"format_literal_op"> {
+ let assemblyFormat = [{
+ `keyword_$.` `->` `:` `,` `=` `<` `>` `(` `)` `[` `]` `` `(` ` ` `)`
+ `?` `+` `*` `{` `\n` `}` attr-dict
+ }];
+}
+
+// Test that we elide attributes that are within the syntax.
+def FormatAttrOp : TEST_Op<"format_attr_op"> {
+ let arguments = (ins I64Attr:$attr);
+ let assemblyFormat = "$attr attr-dict";
+}
+
+// Test that we elide optional attributes that are within the syntax.
+def FormatOptAttrAOp : TEST_Op<"format_opt_attr_op_a"> {
+ let arguments = (ins OptionalAttr<I64Attr>:$opt_attr);
+ let assemblyFormat = "(`(` $opt_attr^ `)` )? attr-dict";
+}
+def FormatOptAttrBOp : TEST_Op<"format_opt_attr_op_b"> {
+ let arguments = (ins OptionalAttr<I64Attr>:$opt_attr);
+ let assemblyFormat = "($opt_attr^)? attr-dict";
+}
+
+// Test that we format symbol name attributes properly.
+def FormatSymbolNameAttrOp : TEST_Op<"format_symbol_name_attr_op"> {
+ let arguments = (ins SymbolNameAttr:$attr);
+ let assemblyFormat = "$attr attr-dict";
+}
+
+// Test that we format optional symbol name attributes properly.
+def FormatOptSymbolNameAttrOp : TEST_Op<"format_opt_symbol_name_attr_op"> {
+ let arguments = (ins OptionalAttr<SymbolNameAttr>:$opt_attr);
+ let assemblyFormat = "($opt_attr^)? attr-dict";
+}
+
+// Test that we format optional symbol reference attributes properly.
+def FormatOptSymbolRefAttrOp : TEST_Op<"format_opt_symbol_ref_attr_op"> {
+ let arguments = (ins OptionalAttr<SymbolRefAttr>:$opt_attr);
+ let assemblyFormat = "($opt_attr^)? attr-dict";
+}
+
+// Test the attr-dict-with-keyword directive. Neither attribute is referenced
+// elsewhere in the format, so both go through the keyword-prefixed dictionary.
+def FormatAttrDictWithKeywordOp : TEST_Op<"format_attr_dict_w_keyword"> {
+ let arguments = (ins I64Attr:$attr, OptionalAttr<I64Attr>:$opt_attr);
+ let assemblyFormat = "attr-dict-with-keyword";
+}
+
+// Test that we don't need to provide types in the format if they are buildable.
+def FormatBuildableTypeOp : TEST_Op<"format_buildable_type_op"> {
+ let arguments = (ins I64:$buildable);
+ let results = (outs I64:$buildable_res);
+ let assemblyFormat = "$buildable attr-dict";
+}
+
+// Test various mixings of region formatting.
+class FormatRegionBase<string suffix, string fmt>
+ : TEST_Op<"format_region_" # suffix # "_op"> {
+ let regions = (region AnyRegion:$region);
+ let assemblyFormat = fmt;
+}
+def FormatRegionAOp : FormatRegionBase<"a", [{
+ regions attr-dict
+}]>;
+def FormatRegionBOp : FormatRegionBase<"b", [{
+ $region attr-dict
+}]>;
+def FormatRegionCOp : FormatRegionBase<"c", [{
+ (`region` $region^)? attr-dict
+}]>;
+// Same as above, but with a variadic region list.
+class FormatVariadicRegionBase<string suffix, string fmt>
+ : TEST_Op<"format_variadic_region_" # suffix # "_op"> {
+ let regions = (region VariadicRegion<AnyRegion>:$regions);
+ let assemblyFormat = fmt;
+}
+def FormatVariadicRegionAOp : FormatVariadicRegionBase<"a", [{
+ $regions attr-dict
+}]>;
+def FormatVariadicRegionBOp : FormatVariadicRegionBase<"b", [{
+ ($regions^ `found_regions`)? attr-dict
+}]>;
+// Region formatting for ops whose region has an implicit TestReturnOp
+// terminator.
+class FormatRegionImplicitTerminatorBase<string suffix, string fmt>
+ : TEST_Op<"format_implicit_terminator_region_" # suffix # "_op",
+ [SingleBlockImplicitTerminator<"TestReturnOp">]> {
+ let regions = (region AnyRegion:$region);
+ let assemblyFormat = fmt;
+}
+def FormatFormatRegionImplicitTerminatorAOp
+ : FormatRegionImplicitTerminatorBase<"a", [{
+ $region attr-dict
+}]>;
+
+// Test various mixings of result type formatting.
+class FormatResultBase<string suffix, string fmt>
+ : TEST_Op<"format_result_" # suffix # "_op"> {
+ let results = (outs I64:$buildable_res, AnyMemRef:$result);
+ let assemblyFormat = fmt;
+}
+def FormatResultAOp : FormatResultBase<"a", [{
+ type($result) attr-dict
+}]>;
+def FormatResultBOp : FormatResultBase<"b", [{
+ type(results) attr-dict
+}]>;
+def FormatResultCOp : FormatResultBase<"c", [{
+ functional-type($buildable_res, $result) attr-dict
+}]>;
+
+// A single variadic result whose types are printed after `:`.
+def FormatVariadicResult : TEST_Op<"format_variadic_result"> {
+ let results = (outs Variadic<I64>:$result);
+ let assemblyFormat = [{ `:` type($result) attr-dict}];
+}
+
+// Two variadic result groups, sized via AttrSizedResultSegments.
+def FormatMultipleVariadicResults : TEST_Op<"format_multiple_variadic_results",
+ [AttrSizedResultSegments]> {
+ let results = (outs Variadic<I64>:$result0, Variadic<AnyType>:$result1);
+ let assemblyFormat = [{
+ `:` `(` type($result0) `)` `,` `(` type($result1) `)` attr-dict
+ }];
+}
+
+// Test various mixings of operand type formatting.
+class FormatOperandBase<string suffix, string fmt>
+ : TEST_Op<"format_operand_" # suffix # "_op"> {
+ let arguments = (ins I64:$buildable, AnyMemRef:$operand);
+ let assemblyFormat = fmt;
+}
+
+def FormatOperandAOp : FormatOperandBase<"a", [{
+ operands `:` type(operands) attr-dict
+}]>;
+def FormatOperandBOp : FormatOperandBase<"b", [{
+ operands `:` type($operand) attr-dict
+}]>;
+def FormatOperandCOp : FormatOperandBase<"c", [{
+ $buildable `,` $operand `:` type(operands) attr-dict
+}]>;
+def FormatOperandDOp : FormatOperandBase<"d", [{
+ $buildable `,` $operand `:` type($operand) attr-dict
+}]>;
+def FormatOperandEOp : FormatOperandBase<"e", [{
+ $buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict
+}]>;
+
+// Terminator with a variadic list of successors.
+def FormatSuccessorAOp : TEST_Op<"format_successor_a_op", [Terminator]> {
+ let successors = (successor VariadicSuccessor<AnySuccessor>:$targets);
+ let assemblyFormat = "$targets attr-dict";
+}
+
+// A single variadic operand with its types.
+def FormatVariadicOperand : TEST_Op<"format_variadic_operand"> {
+ let arguments = (ins Variadic<I64>:$operand);
+ let assemblyFormat = [{ $operand `:` type($operand) attr-dict}];
+}
+// Variadic-of-variadic operand whose grouping is stored in $operand_segments.
+def FormatVariadicOfVariadicOperand
+ : TEST_Op<"format_variadic_of_variadic_operand"> {
+ let arguments = (ins
+ VariadicOf
+Variadic<I64, "operand_segments">:$operand,
+ DenseI32ArrayAttr:$operand_segments
+ );
+ let assemblyFormat = [{ $operand `:` type($operand) attr-dict}];
+}
+
+// Two variadic operand groups, sized via AttrSizedOperandSegments.
+def FormatMultipleVariadicOperands :
+ TEST_Op<"format_multiple_variadic_operands", [AttrSizedOperandSegments]> {
+ let arguments = (ins Variadic<I64>:$operand0, Variadic<AnyType>:$operand1);
+ let assemblyFormat = [{
+ ` ` `(` $operand0 `)` `,` `(` $operand1 `:` type($operand1) `)` attr-dict
+ }];
+}
+
+// Test various mixings of optional operand and result type formatting.
+class FormatOptionalOperandResultOpBase<string suffix, string fmt>
+ : TEST_Op<"format_optional_operand_result_" # suffix # "_op",
+ [AttrSizedOperandSegments]> {
+ let arguments = (ins Optional<I64>:$optional, Variadic<I64>:$variadic);
+ let results = (outs Optional<I64>:$optional_res);
+ let assemblyFormat = fmt;
+}
+
+def FormatOptionalOperandResultAOp : FormatOptionalOperandResultOpBase<"a", [{
+ `(` $optional `:` type($optional) `)` `:` type($optional_res)
+ (`[` $variadic^ `]`)? attr-dict
+}]>;
+
+def FormatOptionalOperandResultBOp : FormatOptionalOperandResultOpBase<"b", [{
+ (`(` $optional^ `:` type($optional) `)`)? `:` type($optional_res)
+ (`[` $variadic^ `]`)? attr-dict
+}]>;
+
+// Test optional result type formatting.
+class FormatOptionalResultOpBase<string suffix, string fmt>
+ : TEST_Op<"format_optional_result_" # suffix # "_op",
+ [AttrSizedResultSegments]> {
+ let results = (outs Optional<I64>:$optional, Variadic<I64>:$variadic);
+ let assemblyFormat = fmt;
+}
+def FormatOptionalResultAOp : FormatOptionalResultOpBase<"a", [{
+ (`:` type($optional)^ `->` type($variadic))? attr-dict
+}]>;
+
+def FormatOptionalResultBOp : FormatOptionalResultOpBase<"b", [{
+ (`:` type($optional) `->` type($variadic)^)? attr-dict
+}]>;
+
+def FormatOptionalResultCOp : FormatOptionalResultOpBase<"c", [{
+ (`:` functional-type($optional, $variadic)^)? attr-dict
+}]>;
+
+// Single optional F80 result whose optional group is anchored on its type.
+def FormatOptionalResultDOp
+ : TEST_Op<"format_optional_result_d_op" > {
+ let results = (outs Optional<F80>:$optional);
+ let assemblyFormat = "(`:` type($optional)^)? attr-dict";
+}
+
+// Two variadic operand groups of arbitrary (non-buildable) types.
+def FormatTwoVariadicOperandsNoBuildableTypeOp
+ : TEST_Op<"format_two_variadic_operands_no_buildable_type_op",
+ [AttrSizedOperandSegments]> {
+ let arguments = (ins Variadic<AnyType>:$a,
+ Variadic<AnyType>:$b);
+ let assemblyFormat = [{
+ `(` $a `:` type($a) `)` `->` `(` $b `:` type($b) `)` attr-dict
+ }];
+}
+
+// Only the result type is written; SameOperandsAndResultType supplies the
+// variadic operand types.
+def FormatInferVariadicTypeFromNonVariadic
+ : TEST_Op<"format_infer_variadic_type_from_non_variadic",
+ [SameOperandsAndResultType]> {
+ let arguments = (ins Variadic<AnyType>:$args);
+ let results = (outs AnyType:$result);
+ let assemblyFormat = "operands attr-dict `:` type($result)";
+}
+
+// Unit attribute anchoring an optional group that has a leading keyword.
+def FormatOptionalUnitAttr : TEST_Op<"format_optional_unit_attribute"> {
+ let arguments = (ins UnitAttr:$is_optional);
+ let assemblyFormat = "(`is_optional` $is_optional^)? attr-dict";
+}
+
+// Unit attribute anchoring an optional group with no other elements.
+def FormatOptionalUnitAttrNoElide
+ : TEST_Op<"format_optional_unit_attribute_no_elide"> {
+ let arguments = (ins UnitAttr:$is_optional);
+ let assemblyFormat = "($is_optional^)? attr-dict";
+}
+
+// Optional enum attribute in an optional group.
+def FormatOptionalEnumAttr : TEST_Op<"format_optional_enum_attr"> {
+ let arguments = (ins OptionalAttr<SomeI64Enum>:$attr);
+ let assemblyFormat = "($attr^)? attr-dict";
+}
+
+// Optional groups anchored on attributes that have default values.
+def FormatOptionalDefaultAttrs : TEST_Op<"format_optional_default_attrs"> {
+ let arguments = (ins DefaultValuedStrAttr<StrAttr, "default">:$str,
+ DefaultValuedStrAttr<SymbolNameAttr, "default">:$sym,
+ DefaultValuedAttr<SomeI64Enum, "SomeI64Enum::case5">:$e);
+ let assemblyFormat = "($str^)? ($sym^)? ($e^)? attr-dict";
+}
+
+// Optional group with an else branch (`then` vs `else`), anchored on a unit
+// attribute in the first group.
+def FormatOptionalWithElse : TEST_Op<"format_optional_else"> {
+ let arguments = (ins UnitAttr:$isFirstBranchPresent);
+ let assemblyFormat = "(`then` $isFirstBranchPresent^):(`else`)? attr-dict";
+}
+
+def FormatCompoundAttr : TEST_Op<"format_compound_attr"> {
+ let arguments = (ins CompoundAttrA:$compound);
+ let assemblyFormat = "$compound attr-dict-with-keyword";
+}
+
+def FormatNestedAttr : TEST_Op<"format_nested_attr"> {
+ let arguments = (ins CompoundAttrNested:$nested);
+ let assemblyFormat = "$nested attr-dict-with-keyword";
+}
+
+def FormatNestedCompoundAttr : TEST_Op<"format_cpmd_nested_attr"> {
+ let arguments = (ins CompoundNestedOuter:$nested);
+ let assemblyFormat = "`nested` $nested attr-dict-with-keyword";
+}
+
+// Operand of TestTypeOptionalValueType; the op name suggests the printed type
+// may be empty — confirm against TestTypeDefs.td.
+def FormatMaybeEmptyType : TEST_Op<"format_maybe_empty_type"> {
+ let arguments = (ins TestTypeOptionalValueType:$in);
+ let assemblyFormat = "$in `:` type($in) attr-dict";
+}
+
+// Like FormatNestedCompoundAttr, but the attribute is printed via qualified().
+def FormatQualifiedCompoundAttr : TEST_Op<"format_qual_cpmd_nested_attr"> {
+ let arguments = (ins CompoundNestedOuter:$nested);
+ let assemblyFormat = "`nested` qualified($nested) attr-dict-with-keyword";
+}
+
+def FormatNestedType : TEST_Op<"format_cpmd_nested_type"> {
+ let arguments = (ins CompoundNestedOuterType:$nested);
+ let assemblyFormat = "$nested `nested` type($nested) attr-dict-with-keyword";
+}
+
+// Like FormatNestedType, but the type is printed via qualified().
+def FormatQualifiedNestedType : TEST_Op<"format_qual_cpmd_nested_type"> {
+ let arguments = (ins CompoundNestedOuterType:$nested);
+ let assemblyFormat = "$nested `nested` qualified(type($nested)) attr-dict-with-keyword";
+}
+
+//===----------------------------------------------------------------------===//
+// Custom Directives
+
+// Custom directive receiving a required, an optional and a variadic operand.
+def FormatCustomDirectiveOperands
+ : TEST_Op<"format_custom_directive_operands", [AttrSizedOperandSegments]> {
+ let arguments = (ins I64:$operand, Optional<I64>:$optOperand,
+ Variadic<I64>:$varOperands);
+ let assemblyFormat = [{
+ custom<CustomDirectiveOperands>(
+ $operand, $optOperand, $varOperands
+ )
+ attr-dict
+ }];
+}
+
+// Custom directive receiving operands together with their types.
+def FormatCustomDirectiveOperandsAndTypes
+ : TEST_Op<"format_custom_directive_operands_and_types",
+ [AttrSizedOperandSegments]> {
+ let arguments = (ins AnyType:$operand, Optional<AnyType>:$optOperand,
+ Variadic<AnyType>:$varOperands);
+ let assemblyFormat = [{
+ custom<CustomDirectiveOperandsAndTypes>(
+ $operand, $optOperand, $varOperands,
+ type($operand), type($optOperand), type($varOperands)
+ )
+ attr-dict
+ }];
+}
+
+// Custom directive receiving a single region and a variadic region list.
+def FormatCustomDirectiveRegions : TEST_Op<"format_custom_directive_regions"> {
+ let regions = (region AnyRegion:$region, VariadicRegion<AnyRegion>:$other_regions);
+ let assemblyFormat = [{
+ custom<CustomDirectiveRegions>(
+ $region, $other_regions
+ )
+ attr-dict
+ }];
+}
+
+// Custom directive receiving required, optional and variadic result types.
+def FormatCustomDirectiveResults
+ : TEST_Op<"format_custom_directive_results", [AttrSizedResultSegments]> {
+ let results = (outs AnyType:$result, Optional<AnyType>:$optResult,
+ Variadic<AnyType>:$varResults);
+ let assemblyFormat = [{
+ custom<CustomDirectiveResults>(
+ type($result), type($optResult), type($varResults)
+ )
+ }];
+}
+
+// A second custom directive refers back, via ref(), to the result types
+// parsed by the first.
+def FormatCustomDirectiveResultsWithTypeRefs
+ : TEST_Op<"format_custom_directive_results_with_type_refs",
+ [AttrSizedResultSegments]> {
+ let results = (outs AnyType:$result, Optional<AnyType>:$optResult,
+ Variadic<AnyType>:$varResults);
+ let assemblyFormat = [{
+ custom<CustomDirectiveResults>(
+ type($result), type($optResult), type($varResults)
+ )
+ custom<CustomDirectiveWithTypeRefs>(
+ ref(type($result)), ref(type($optResult)), ref(type($varResults))
+ )
+ attr-dict
+ }];
+}
+
+// Custom directive referring, via ref(), to an optional operand that is
+// parsed earlier in the format.
+def FormatCustomDirectiveWithOptionalOperandRef
+ : TEST_Op<"format_custom_directive_with_optional_operand_ref"> {
+ let arguments = (ins Optional<I64>:$optOperand);
+ let assemblyFormat = [{
+ ($optOperand^)? `:`
+ custom<CustomDirectiveOptionalOperandRef>(ref($optOperand))
+ attr-dict
+ }];
+}
+
+// Terminator whose successors are handled by a custom directive.
+def FormatCustomDirectiveSuccessors
+ : TEST_Op<"format_custom_directive_successors", [Terminator]> {
+ let successors = (successor AnySuccessor:$successor,
+ VariadicSuccessor<AnySuccessor>:$successors);
+ let assemblyFormat = [{
+ custom<CustomDirectiveSuccessors>(
+ $successor, $successors
+ )
+ attr-dict
+ }];
+}
+
+// Custom directive receiving a required and an optional attribute.
+def FormatCustomDirectiveAttributes
+ : TEST_Op<"format_custom_directive_attributes"> {
+ let arguments = (ins I64Attr:$attr, OptionalAttr<I64Attr>:$optAttr);
+ let assemblyFormat = [{
+ custom<CustomDirectiveAttributes>(
+ $attr, $optAttr
+ )
+ attr-dict
+ }];
+}
+
+// The same custom directive used twice in a row, to test spacing between
+// consecutive custom directives.
+def FormatCustomDirectiveSpacing
+ : TEST_Op<"format_custom_directive_spacing"> {
+ let arguments = (ins StrAttr:$attr1, StrAttr:$attr2);
+ let assemblyFormat = [{
+ custom<CustomDirectiveSpacing>($attr1)
+ custom<CustomDirectiveSpacing>($attr2)
+ attr-dict
+ }];
+}
+
+// Custom directive that receives the attribute dictionary itself.
+def FormatCustomDirectiveAttrDict
+ : TEST_Op<"format_custom_directive_attrdict"> {
+ let arguments = (ins I64Attr:$attr, OptionalAttr<I64Attr>:$optAttr);
+ let assemblyFormat = [{
+ custom<CustomDirectiveAttrDict>( attr-dict )
+ }];
+}
+
+// A literal (`:`) directly follows an optional group.
+def FormatLiteralFollowingOptionalGroup
+ : TEST_Op<"format_literal_following_optional_group"> {
+ let arguments = (ins TypeAttr:$type, OptionalAttr<AnyAttr>:$value);
+ let assemblyFormat = "(`(` $value^ `)`)? `:` $type attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// AllTypesMatch type inference
+
+// Only $value1's type is written; AllTypesMatch supplies $value2 and $result.
+def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [
+ AllTypesMatch<["value1", "value2", "result"]>
+ ]> {
+ let arguments = (ins AnyType:$value1, AnyType:$value2);
+ let results = (outs AnyType:$result);
+ let assemblyFormat = "attr-dict $value1 `,` $value2 `:` type($value1)";
+}
+
+// No type is written; the types come from the typed attribute $value1.
+def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [
+ AllTypesMatch<["value1", "value2", "result"]>
+ ]> {
+ let arguments = (ins TypedAttrInterface:$value1, AnyType:$value2);
+ let results = (outs AnyType:$result);
+ let assemblyFormat = "attr-dict $value1 `,` $value2";
+}
+
+//===----------------------------------------------------------------------===//
+// TypesMatchWith type inference
+
+def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [
+ TypesMatchWith<"result type matches operand", "value", "result", "$_self">
+ ]> {
+ let arguments = (ins AnyType:$value);
+ let results = (outs AnyType:$result);
+ let assemblyFormat = "attr-dict $value `:` type($value)";
+}
+
+// Ranged variant: variadic result types are taken from the variadic operands.
+def FormatTypesMatchVariadicOp : TEST_Op<"format_types_match_variadic", [
+ RangedTypesMatchWith<"result type matches operand", "value", "result",
+ "llvm::make_range($_self.begin(), $_self.end())">
+ ]> {
+ let arguments = (ins Variadic<AnyType>:$value);
+ let results = (outs Variadic<AnyType>:$result);
+ let assemblyFormat = "attr-dict $value `:` type($value)";
+}
+
+def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [
+ TypesMatchWith<"result type matches constant", "value", "result", "$_self">
+ ]> {
+ let arguments = (ins TypedAttrInterface:$value);
+ let results = (outs AnyType:$result);
+ let assemblyFormat = "attr-dict $value";
+}
+
+// The transformer uses the MLIRContext ($_ctxt) to build a tuple type.
+def FormatTypesMatchContextOp : TEST_Op<"format_types_match_context", [
+ TypesMatchWith<"tuple result type matches operand type", "value", "result",
+ "::mlir::TupleType::get($_ctxt, $_self)">
+ ]> {
+ let arguments = (ins AnyType:$value);
+ let results = (outs AnyType:$result);
+ let assemblyFormat = "attr-dict $value `:` type($value)";
+}
+
+//===----------------------------------------------------------------------===//
+// InferTypeOpInterface type inference in assembly format
+
+// The result type is never written; inferReturnTypes always yields i16.
+def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> {
+ let results = (outs AnyType);
+ let assemblyFormat = "attr-dict";
+
+ let extraClassDeclaration = [{
+ static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+ ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
+ ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+ ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+ inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
+ return ::mlir::success();
+ }
+ }];
+}
+
+// Check that the assembly format generator supports DeclareOpInterfaceMethods
+// (inferReturnTypes is then defined out of line in C++).
+def FormatInferType2Op : TEST_Op<"format_infer_type2", [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let results = (outs AnyType);
+ let assemblyFormat = "attr-dict";
+}
+
+// Base class for testing mixing allOperandTypes, allOperands, and
+// inferResultTypes. inferReturnTypes copies the operand types verbatim into
+// the result types.
+class FormatInferAllTypesBaseOp<string mnemonic, list<Trait> traits = []>
+ : TEST_Op<mnemonic, [InferTypeOpInterface] # traits> {
+ let arguments = (ins Variadic<AnyType>:$args);
+ let results = (outs Variadic<AnyType>:$outs);
+ let extraClassDeclaration = [{
+ static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+ ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
+ ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+ ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+ ::mlir::TypeRange operandTypes = operands.getTypes();
+ inferredReturnTypes.assign(operandTypes.begin(), operandTypes.end());
+ return ::mlir::success();
+ }
+ }];
+}
+
+// Test inferReturnTypes is called when allOperandTypes and allOperands is true.
+def FormatInferTypeAllOperandsAndTypesOp
+ : FormatInferAllTypesBaseOp<"format_infer_type_all_operands_and_types"> {
+ let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)";
+}
+
+// Test inferReturnTypes is called when allOperandTypes is true and there is one
+// ODS operand.
+def FormatInferTypeAllOperandsAndTypesOneOperandOp
+ : FormatInferAllTypesBaseOp<"format_infer_type_all_types_one_operand"> {
+ let assemblyFormat = "`(` $args `)` attr-dict `:` type(operands)";
+}
+
+// Test inferReturnTypes is called when allOperandTypes is true and there are
+// more than one ODS operands.
+def FormatInferTypeAllOperandsAndTypesTwoOperandsOp
+ : FormatInferAllTypesBaseOp<"format_infer_type_all_types_two_operands",
+ [SameVariadicOperandSize]> {
+ let arguments = (ins Variadic<AnyType>:$args0, Variadic<AnyType>:$args1);
+ let assemblyFormat = "`(` $args0 `)` `(` $args1 `)` attr-dict `:` type(operands)";
+}
+
+// Test inferReturnTypes is called when allOperands is true and operand types
+// are separately specified.
+def FormatInferTypeAllTypesOp
+ : FormatInferAllTypesBaseOp<"format_infer_type_all_types"> {
+ let assemblyFormat = "`(` operands `)` attr-dict `:` type($args)";
+}
+
+// Test inferReturnTypes coupled with regions. The result types are the
+// argument types of the first region; inference fails if there are no regions.
+def FormatInferTypeRegionsOp
+ : TEST_Op<"format_infer_type_regions", [InferTypeOpInterface]> {
+ let results = (outs Variadic<AnyType>:$outs);
+ let regions = (region AnyRegion:$region);
+ let assemblyFormat = "$region attr-dict";
+ let extraClassDeclaration = [{
+ static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+ ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
+ ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+ ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+ if (regions.empty())
+ return ::mlir::failure();
+ auto types = regions.front()->getArgumentTypes();
+ inferredReturnTypes.assign(types.begin(), types.end());
+ return ::mlir::success();
+ }
+ }];
+}
+
+// Test inferReturnTypes coupled with variadic operands (operand_segment_sizes).
+// The result types are the types of $a followed by the types of $b, read
+// through the generated adaptor. NOTE(review): dereferences `properties`
+// unconditionally — confirm callers always supply them.
+def FormatInferTypeVariadicOperandsOp
+ : TEST_Op<"format_infer_type_variadic_operands",
+ [InferTypeOpInterface, AttrSizedOperandSegments]> {
+ let arguments = (ins Variadic<I32>:$a, Variadic<I64>:$b);
+ let results = (outs Variadic<AnyType>:$outs);
+ let assemblyFormat = "`(` $a `:` type($a) `)` `(` $b `:` type($b) `)` attr-dict";
+ let extraClassDeclaration = [{
+ static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+ ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
+ ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+ ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+ FormatInferTypeVariadicOperandsOpAdaptor adaptor(
+ operands, attributes, *properties.as<Properties *>(), {});
+ auto aTypes = adaptor.getA().getTypes();
+ auto bTypes = adaptor.getB().getTypes();
+ inferredReturnTypes.append(aTypes.begin(), aTypes.end());
+ inferredReturnTypes.append(bTypes.begin(), bTypes.end());
+ return ::mlir::success();
+ }
+ }];
+}
+
+#endif // TEST_OPS_SYNTAX
diff --git a/mlir/unittests/IR/AdaptorTest.cpp b/mlir/unittests/IR/AdaptorTest.cpp
index e4dc4cda7608ac..aae8c1ac44dbf7 100644
--- a/mlir/unittests/IR/AdaptorTest.cpp
+++ b/mlir/unittests/IR/AdaptorTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOpsSyntax.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index a51e5b63f258b8..a33e6d21cef5ea 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -122,6 +122,28 @@ td_library(
],
)
+# Runs mlir-tblgen on TestOpsSyntax.td to produce the op declarations
+# (TestOpsSyntax.h.inc) and definitions (TestOpsSyntax.cpp.inc).
+gentbl_cc_library(
+ name = "TestOpsSyntaxIncGen",
+ strip_include_prefix = "lib/Dialect/Test",
+ tbl_outs = [
+ (
+ ["-gen-op-decls"],
+ "lib/Dialect/Test/TestOpsSyntax.h.inc",
+ ),
+ (
+ ["-gen-op-defs"],
+ "lib/Dialect/Test/TestOpsSyntax.cpp.inc",
+ ),
+ ],
+ tblgen = "//mlir:mlir-tblgen",
+ td_file = "lib/Dialect/Test/TestOpsSyntax.td",
+ test = True,
+ deps = [
+ ":TestOpTdFiles",
+ ],
+)
+
+
gentbl_cc_library(
name = "TestOpsIncGen",
strip_include_prefix = "lib/Dialect/Test",
@@ -354,6 +376,7 @@ cc_library(
":TestEnumDefsIncGen",
":TestInterfacesIncGen",
":TestOpsIncGen",
+ ":TestOpsSyntaxIncGen",
":TestTypeDefsIncGen",
"//llvm:Support",
"//mlir:ArithDialect",
More information about the Mlir-commits
mailing list