[llvm-branch-commits] [llvm] [mlir] [mlir][test] Shard and reorganize the test dialect (PR #89424)
Jeff Niu via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Apr 19 10:52:14 PDT 2024
https://github.com/Mogball created https://github.com/llvm/llvm-project/pull/89424
Shard the test dialect's generated op definitions across multiple translation units (the CMake build uses `add_sharded_ops(TestOps 20)`, i.e. 20 shards). This patch also moves the manually written op hooks into `TestOpDefs.cpp` and the custom-directive parsers and printers used by the assembly formats into `TestFormatUtils`, adds missing comment blocks, and changes where the generated source files for types, attributes, enums, etc. are included.
In my local build, compiling the test dialect now takes about 10 s instead of more than 60 s.
>From 2def46c947500e47bd5303c2bf004161c4d04558 Mon Sep 17 00:00:00 2001
From: Mogball <jeffniu22 at gmail.com>
Date: Thu, 23 Jun 2022 20:45:26 +0000
Subject: [PATCH] [mlir][test] Shard and reorganize the test dialect
Shard the test dialect's generated op definitions across multiple translation
units (the CMake build uses `add_sharded_ops(TestOps 20)`, i.e. 20 shards).
This patch also moves the manually written op hooks into `TestOpDefs.cpp` and
the custom-directive parsers and printers used by the assembly formats into
`TestFormatUtils`, adds missing comment blocks, and changes where the
generated source files for types, attributes, enums, etc. are included.
In my local build, compiling the test dialect now takes about 10 s instead of
more than 60 s.
---
.../TestDenseBackwardDataFlowAnalysis.cpp | 1 +
.../TestDenseForwardDataFlowAnalysis.cpp | 1 +
.../FuncToLLVM/TestConvertCallOp.cpp | 1 +
.../TestOneToNTypeConversionPass.cpp | 1 +
.../Dialect/Affine/TestReifyValueBounds.cpp | 1 +
.../lib/Dialect/DLTI/TestDataLayoutQuery.cpp | 2 +-
.../Func/TestDecomposeCallGraphTypes.cpp | 1 +
mlir/test/lib/Dialect/Test/CMakeLists.txt | 9 +-
mlir/test/lib/Dialect/Test/TestAttributes.cpp | 3 +-
mlir/test/lib/Dialect/Test/TestDialect.cpp | 1451 ++---------------
mlir/test/lib/Dialect/Test/TestDialect.h | 50 +-
.../Dialect/Test/TestDialectInterfaces.cpp | 1 +
.../test/lib/Dialect/Test/TestFormatUtils.cpp | 377 +++++
mlir/test/lib/Dialect/Test/TestFormatUtils.h | 211 +++
.../Test/TestFromLLVMIRTranslation.cpp | 1 +
mlir/test/lib/Dialect/Test/TestInterfaces.cpp | 2 +
mlir/test/lib/Dialect/Test/TestInterfaces.h | 2 +
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 1161 +++++++++++++
mlir/test/lib/Dialect/Test/TestOps.cpp | 17 +
mlir/test/lib/Dialect/Test/TestOps.h | 149 ++
mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp | 1 +
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 1 +
.../Dialect/Test/TestToLLVMIRTranslation.cpp | 1 +
mlir/test/lib/Dialect/Test/TestTraits.cpp | 2 +-
mlir/test/lib/Dialect/Test/TestTypes.cpp | 1 +
mlir/test/lib/Dialect/Test/TestTypes.h | 6 +-
mlir/test/lib/IR/TestBytecodeRoundtrip.cpp | 1 +
mlir/test/lib/IR/TestClone.cpp | 2 +-
mlir/test/lib/IR/TestSideEffects.cpp | 2 +-
mlir/test/lib/IR/TestSymbolUses.cpp | 2 +-
mlir/test/lib/IR/TestTypes.cpp | 1 +
mlir/test/lib/IR/TestVisitorsGeneric.cpp | 2 +-
mlir/test/lib/Pass/TestPassManager.cpp | 1 +
mlir/test/lib/Transforms/TestInlining.cpp | 1 +
.../Transforms/TestMakeIsolatedFromAbove.cpp | 1 +
mlir/unittests/IR/AdaptorTest.cpp | 1 +
mlir/unittests/IR/IRMapping.cpp | 1 +
mlir/unittests/IR/InterfaceAttachmentTest.cpp | 1 +
mlir/unittests/IR/InterfaceTest.cpp | 1 +
mlir/unittests/IR/OperationSupportTest.cpp | 1 +
mlir/unittests/IR/PatternMatchTest.cpp | 1 +
mlir/unittests/TableGen/OpBuildGen.cpp | 1 +
.../mlir/test/BUILD.bazel | 27 +-
43 files changed, 2136 insertions(+), 1365 deletions(-)
create mode 100644 mlir/test/lib/Dialect/Test/TestFormatUtils.cpp
create mode 100644 mlir/test/lib/Dialect/Test/TestFormatUtils.h
create mode 100644 mlir/test/lib/Dialect/Test/TestOpDefs.cpp
create mode 100644 mlir/test/lib/Dialect/Test/TestOps.cpp
create mode 100644 mlir/test/lib/Dialect/Test/TestOps.h
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
index ca052392f2f5f2..65592a5c5d698b 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -12,6 +12,7 @@
#include "TestDenseDataFlowAnalysis.h"
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
index 29480f5ad63ee0..3f9ce2dc0bc50a 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
@@ -12,6 +12,7 @@
#include "TestDenseDataFlowAnalysis.h"
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp
index 5e17779660f392..f878a262512ee8 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "TestTypes.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 3c4067b35d8e5b..cc1af59c5e15bb 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index b098a5a23fd316..34513cd418e4c2 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
diff --git a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
index 84f45b31603192..56f309f150ca5d 100644
--- a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
+++ b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/IR/BuiltinAttributes.h"
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 10aba733bd5696..0d7dce2240f4cb 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
#include "mlir/IR/Builders.h"
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index d246c0492a3bd5..fab89378093326 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -31,8 +31,6 @@ mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRTestEnumDefIncGen)
set(LLVM_TARGET_DEFINITIONS TestOps.td)
-mlir_tablegen(TestOps.h.inc -gen-op-decls)
-mlir_tablegen(TestOps.cpp.inc -gen-op-defs)
mlir_tablegen(TestOpsDialect.h.inc -gen-dialect-decls -dialect=test)
mlir_tablegen(TestOpsDialect.cpp.inc -gen-dialect-defs -dialect=test)
mlir_tablegen(TestPatterns.inc -gen-rewriters)
@@ -43,16 +41,22 @@ mlir_tablegen(TestOpsSyntax.h.inc -gen-op-decls)
mlir_tablegen(TestOpsSyntax.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRTestOpsSyntaxIncGen)
+add_sharded_ops(TestOps 20)
+
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestDialect
TestAttributes.cpp
TestDialect.cpp
+ TestFormatUtils.cpp
TestInterfaces.cpp
+ TestOpDefs.cpp
+ TestOps.cpp
TestPatterns.cpp
TestTraits.cpp
TestTypes.cpp
TestOpsSyntax.cpp
TestDialectInterfaces.cpp
+ ${SHARDED_SRCS}
EXCLUDE_FROM_LIBMLIR
@@ -63,6 +67,7 @@ add_mlir_library(MLIRTestDialect
MLIRTestTypeDefIncGen
MLIRTestOpsIncGen
MLIRTestOpsSyntaxIncGen
+ MLIRTestOpsShardGen
LINK_LIBS PUBLIC
MLIRControlFlowInterfaces
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index d41d495c38e553..2cc051e664beec 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/ErrorHandling.h"
@@ -244,7 +245,7 @@ static void printConditionalAlias(AsmPrinter &p, StringAttr value) {
//===----------------------------------------------------------------------===//
#include "TestAttrInterfaces.cpp.inc"
-
+#include "TestOpEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "TestAttrDefs.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index a23ed89c4b04d1..21c46fc807aaa8 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -7,8 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
-#include "TestAttributes.h"
-#include "TestInterfaces.h"
+#include "TestOps.h"
#include "TestTypes.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -39,17 +38,85 @@
#include "llvm/Support/Base64.h"
#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Interfaces/FoldInterfaces.h"
+#include "mlir/Reducer/ReductionPatternInterface.h"
+#include "mlir/Transforms/InliningUtils.h"
#include <cstdint>
#include <numeric>
#include <optional>
-// Include this before the using namespace lines below to
-// test that we don't have namespace dependencies.
+// Include this before the using namespace lines below to test that we don't
+// have namespace dependencies.
#include "TestOpsDialect.cpp.inc"
using namespace mlir;
using namespace test;
+//===----------------------------------------------------------------------===//
+// PropertiesWithCustomPrint
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+test::setPropertiesFromAttribute(PropertiesWithCustomPrint &prop,
+ Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError) {
+ DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
+ if (!dict) {
+ emitError() << "expected DictionaryAttr to set TestProperties";
+ return failure();
+ }
+ auto label = dict.getAs<mlir::StringAttr>("label");
+ if (!label) {
+ emitError() << "expected StringAttr for key `label`";
+ return failure();
+ }
+ auto valueAttr = dict.getAs<IntegerAttr>("value");
+ if (!valueAttr) {
+ emitError() << "expected IntegerAttr for key `value`";
+ return failure();
+ }
+
+ prop.label = std::make_shared<std::string>(label.getValue());
+ prop.value = valueAttr.getValue().getSExtValue();
+ return success();
+}
+
+DictionaryAttr
+test::getPropertiesAsAttribute(MLIRContext *ctx,
+ const PropertiesWithCustomPrint &prop) {
+ SmallVector<NamedAttribute> attrs;
+ Builder b{ctx};
+ attrs.push_back(b.getNamedAttr("label", b.getStringAttr(*prop.label)));
+ attrs.push_back(b.getNamedAttr("value", b.getI32IntegerAttr(prop.value)));
+ return b.getDictionaryAttr(attrs);
+}
+
+llvm::hash_code test::computeHash(const PropertiesWithCustomPrint &prop) {
+ return llvm::hash_combine(prop.value, StringRef(*prop.label));
+}
+
+void test::customPrintProperties(OpAsmPrinter &p,
+ const PropertiesWithCustomPrint &prop) {
+ p.printKeywordOrString(*prop.label);
+ p << " is " << prop.value;
+}
+
+ParseResult test::customParseProperties(OpAsmParser &parser,
+ PropertiesWithCustomPrint &prop) {
+ std::string label;
+ if (parser.parseKeywordOrString(&label) || parser.parseKeyword("is") ||
+ parser.parseInteger(prop.value))
+ return failure();
+ prop.label = std::make_shared<std::string>(std::move(label));
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MyPropStruct
+//===----------------------------------------------------------------------===//
+
Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const {
return StringAttr::get(ctx, content);
}
@@ -70,8 +137,8 @@ llvm::hash_code MyPropStruct::hash() const {
return hash_value(StringRef(content));
}
-static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader,
- MyPropStruct &prop) {
+LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader,
+ MyPropStruct &prop) {
StringRef str;
if (failed(reader.readString(str)))
return failure();
@@ -79,13 +146,71 @@ static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader,
return success();
}
-static void writeToMlirBytecode(::mlir::DialectBytecodeWriter &writer,
- MyPropStruct &prop) {
+void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
+ MyPropStruct &prop) {
writer.writeOwnedString(prop.content);
}
-static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader,
- MutableArrayRef<int64_t> prop) {
+//===----------------------------------------------------------------------===//
+// VersionedProperties
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+test::setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError) {
+ DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
+ if (!dict) {
+ emitError() << "expected DictionaryAttr to set VersionedProperties";
+ return failure();
+ }
+ auto value1Attr = dict.getAs<IntegerAttr>("value1");
+ if (!value1Attr) {
+ emitError() << "expected IntegerAttr for key `value1`";
+ return failure();
+ }
+ auto value2Attr = dict.getAs<IntegerAttr>("value2");
+ if (!value2Attr) {
+ emitError() << "expected IntegerAttr for key `value2`";
+ return failure();
+ }
+
+ prop.value1 = value1Attr.getValue().getSExtValue();
+ prop.value2 = value2Attr.getValue().getSExtValue();
+ return success();
+}
+
+DictionaryAttr test::getPropertiesAsAttribute(MLIRContext *ctx,
+ const VersionedProperties &prop) {
+ SmallVector<NamedAttribute> attrs;
+ Builder b{ctx};
+ attrs.push_back(b.getNamedAttr("value1", b.getI32IntegerAttr(prop.value1)));
+ attrs.push_back(b.getNamedAttr("value2", b.getI32IntegerAttr(prop.value2)));
+ return b.getDictionaryAttr(attrs);
+}
+
+llvm::hash_code test::computeHash(const VersionedProperties &prop) {
+ return llvm::hash_combine(prop.value1, prop.value2);
+}
+
+void test::customPrintProperties(OpAsmPrinter &p,
+ const VersionedProperties &prop) {
+ p << prop.value1 << " | " << prop.value2;
+}
+
+ParseResult test::customParseProperties(OpAsmParser &parser,
+ VersionedProperties &prop) {
+ if (parser.parseInteger(prop.value1) || parser.parseVerticalBar() ||
+ parser.parseInteger(prop.value2))
+ return failure();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Bytecode Support
+//===----------------------------------------------------------------------===//
+
+LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader,
+ MutableArrayRef<int64_t> prop) {
uint64_t size;
if (failed(reader.readVarInt(size)))
return failure();
@@ -101,45 +226,13 @@ static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader,
return success();
}
-static void writeToMlirBytecode(::mlir::DialectBytecodeWriter &writer,
- ArrayRef<int64_t> prop) {
+void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
+ ArrayRef<int64_t> prop) {
writer.writeVarInt(prop.size());
for (auto elt : prop)
writer.writeVarInt(elt);
}
-static LogicalResult
-setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr,
- function_ref<InFlightDiagnostic()> emitError);
-static DictionaryAttr
-getPropertiesAsAttribute(MLIRContext *ctx,
- const PropertiesWithCustomPrint &prop);
-static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop);
-static void customPrintProperties(OpAsmPrinter &p,
- const PropertiesWithCustomPrint &prop);
-static ParseResult customParseProperties(OpAsmParser &parser,
- PropertiesWithCustomPrint &prop);
-static LogicalResult
-setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
- function_ref<InFlightDiagnostic()> emitError);
-static DictionaryAttr getPropertiesAsAttribute(MLIRContext *ctx,
- const VersionedProperties &prop);
-static llvm::hash_code computeHash(const VersionedProperties &prop);
-static void customPrintProperties(OpAsmPrinter &p,
- const VersionedProperties &prop);
-static ParseResult customParseProperties(OpAsmParser &parser,
- VersionedProperties &prop);
-static ParseResult
-parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
- SmallVectorImpl<std::unique_ptr<Region>> &caseRegions);
-
-static void printSwitchCases(OpAsmPrinter &p, Operation *op,
- DenseI64ArrayAttr cases, RegionRange caseRegions);
-
-void test::registerTestDialect(DialectRegistry ®istry) {
- registry.insert<TestDialect>();
-}
-
//===----------------------------------------------------------------------===//
// Dynamic operations
//===----------------------------------------------------------------------===//
@@ -196,9 +289,20 @@ getDynamicCustomParserPrinterOp(TestDialect *dialect) {
// TestDialect
//===----------------------------------------------------------------------===//
-static void testSideEffectOpGetEffect(
+void test::registerTestDialect(DialectRegistry ®istry) {
+ registry.insert<TestDialect>();
+}
+
+void test::testSideEffectOpGetEffect(
Operation *op,
- SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
+ SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
+ &effects) {
+ auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
+ if (!effectsAttr)
+ return;
+
+ effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
+}
// This is the implementation of a dialect fallback for `TestEffectOpInterface`.
struct TestOpEffectInterfaceFallback
@@ -228,6 +332,7 @@ void TestDialect::initialize() {
>();
registerOpsSyntax();
addOperations<ManualCppOpWithFold>();
+ registerTestDialectOperations(this);
registerDynamicOp(getDynamicGenericOp(this));
registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
registerDynamicOp(getDynamicCustomParserPrinterOp(this));
@@ -318,57 +423,6 @@ TestDialect::getOperationPrinter(Operation *op) const {
return {};
}
-//===----------------------------------------------------------------------===//
-// TypedAttrOp
-//===----------------------------------------------------------------------===//
-
-/// Parse an attribute with a given type.
-static ParseResult parseAttrElideType(AsmParser &parser, TypeAttr type,
- Attribute &attr) {
- return parser.parseAttribute(attr, type.getValue());
-}
-
-/// Print an attribute without its type.
-static void printAttrElideType(AsmPrinter &printer, Operation *op,
- TypeAttr type, Attribute attr) {
- printer.printAttributeWithoutType(attr);
-}
-
-//===----------------------------------------------------------------------===//
-// TestBranchOp
-//===----------------------------------------------------------------------===//
-
-SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
- assert(index == 0 && "invalid successor index");
- return SuccessorOperands(getTargetOperandsMutable());
-}
-
-//===----------------------------------------------------------------------===//
-// TestProducingBranchOp
-//===----------------------------------------------------------------------===//
-
-SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
- assert(index <= 1 && "invalid successor index");
- if (index == 1)
- return SuccessorOperands(getFirstOperandsMutable());
- return SuccessorOperands(getSecondOperandsMutable());
-}
-
-//===----------------------------------------------------------------------===//
-// TestProducingBranchOp
-//===----------------------------------------------------------------------===//
-
-SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
- assert(index <= 1 && "invalid successor index");
- if (index == 0)
- return SuccessorOperands(0, getSuccessOperandsMutable());
- return SuccessorOperands(1, getErrorOperandsMutable());
-}
-
-//===----------------------------------------------------------------------===//
-// TestDialectCanonicalizerOp
-//===----------------------------------------------------------------------===//
-
static LogicalResult
dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
PatternRewriter &rewriter) {
@@ -381,1206 +435,3 @@ void TestDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add(&dialectCanonicalizationPattern);
}
-
-//===----------------------------------------------------------------------===//
-// TestCallOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // Check that the callee attribute was specified.
- auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
- if (!fnAttr)
- return emitOpError("requires a 'callee' symbol reference attribute");
- if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
- return emitOpError() << "'" << fnAttr.getValue()
- << "' does not reference a valid function";
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// ConversionFuncOp
-//===----------------------------------------------------------------------===//
-
-ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
- OperationState &result) {
- auto buildFuncType =
- [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
- function_interface_impl::VariadicFlag,
- std::string &) { return builder.getFunctionType(argTypes, results); };
-
- return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType,
- getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
-}
-
-void ConversionFuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(
- p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
- getArgAttrsAttrName(), getResAttrsAttrName());
-}
-
-//===----------------------------------------------------------------------===//
-// TestFoldToCallOp
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
- using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(FoldToCallOp op,
- PatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
- op.getCalleeAttr(), ValueRange());
- return success();
- }
-};
-} // namespace
-
-void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<FoldToCallOpPattern>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// Test IsolatedRegionOp - parse passthrough region arguments.
-//===----------------------------------------------------------------------===//
-
-ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
- OperationState &result) {
- // Parse the input operand.
- OpAsmParser::Argument argInfo;
- argInfo.type = parser.getBuilder().getIndexType();
- if (parser.parseOperand(argInfo.ssaName) ||
- parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
- return failure();
-
- // Parse the body region, and reuse the operand info as the argument info.
- Region *body = result.addRegion();
- return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
-}
-
-void IsolatedRegionOp::print(OpAsmPrinter &p) {
- p << ' ';
- p.printOperand(getOperand());
- p.shadowRegionArgs(getRegion(), getOperand());
- p << ' ';
- p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
-}
-
-//===----------------------------------------------------------------------===//
-// Test SSACFGRegionOp
-//===----------------------------------------------------------------------===//
-
-RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
- return RegionKind::SSACFG;
-}
-
-//===----------------------------------------------------------------------===//
-// Test GraphRegionOp
-//===----------------------------------------------------------------------===//
-
-RegionKind GraphRegionOp::getRegionKind(unsigned index) {
- return RegionKind::Graph;
-}
-
-//===----------------------------------------------------------------------===//
-// Test AffineScopeOp
-//===----------------------------------------------------------------------===//
-
-ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
- // Parse the body region, and reuse the operand info as the argument info.
- Region *body = result.addRegion();
- return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
-}
-
-void AffineScopeOp::print(OpAsmPrinter &p) {
- p << " ";
- p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
-}
-
-//===----------------------------------------------------------------------===//
-// Test OptionalCustomAttrOp
-//===----------------------------------------------------------------------===//
-
-static OptionalParseResult parseOptionalCustomParser(AsmParser &p,
- IntegerAttr &result) {
- if (succeeded(p.parseOptionalKeyword("foo")))
- return p.parseAttribute(result);
- return {};
-}
-
-static void printOptionalCustomParser(AsmPrinter &p, Operation *,
- IntegerAttr result) {
- p << "foo ";
- p.printAttribute(result);
-}
-
-//===----------------------------------------------------------------------===//
-// ReifyBoundOp
-//===----------------------------------------------------------------------===//
-
-::mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
- if (getType() == "EQ")
- return ::mlir::presburger::BoundType::EQ;
- if (getType() == "LB")
- return ::mlir::presburger::BoundType::LB;
- if (getType() == "UB")
- return ::mlir::presburger::BoundType::UB;
- llvm_unreachable("invalid bound type");
-}
-
-LogicalResult ReifyBoundOp::verify() {
- if (isa<ShapedType>(getVar().getType())) {
- if (!getDim().has_value())
- return emitOpError("expected 'dim' attribute for shaped type variable");
- } else if (getVar().getType().isIndex()) {
- if (getDim().has_value())
- return emitOpError("unexpected 'dim' attribute for index variable");
- } else {
- return emitOpError("expected index-typed variable or shape type variable");
- }
- if (getConstant() && getScalable())
- return emitOpError("'scalable' and 'constant' are mutually exlusive");
- if (getScalable() != getVscaleMin().has_value())
- return emitOpError("expected 'vscale_min' if and only if 'scalable'");
- if (getScalable() != getVscaleMax().has_value())
- return emitOpError("expected 'vscale_min' if and only if 'scalable'");
- return success();
-}
-
-::mlir::ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
- if (getDim().has_value())
- return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
- return ValueBoundsConstraintSet::Variable(getVar());
-}
-
-::mlir::ValueBoundsConstraintSet::ComparisonOperator
-CompareOp::getComparisonOperator() {
- if (getCmp() == "EQ")
- return ValueBoundsConstraintSet::ComparisonOperator::EQ;
- if (getCmp() == "LT")
- return ValueBoundsConstraintSet::ComparisonOperator::LT;
- if (getCmp() == "LE")
- return ValueBoundsConstraintSet::ComparisonOperator::LE;
- if (getCmp() == "GT")
- return ValueBoundsConstraintSet::ComparisonOperator::GT;
- if (getCmp() == "GE")
- return ValueBoundsConstraintSet::ComparisonOperator::GE;
- llvm_unreachable("invalid comparison operator");
-}
-
-::mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
- if (!getLhsMap())
- return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
- SmallVector<Value> mapOperands(
- getVarOperands().slice(0, getLhsMap()->getNumInputs()));
- return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands);
-}
-
-::mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
- int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
- if (!getRhsMap())
- return ValueBoundsConstraintSet::Variable(
- getVarOperands()[rhsOperandsBegin]);
- SmallVector<Value> mapOperands(
- getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs()));
- return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands);
-}
-
-LogicalResult CompareOp::verify() {
- if (getCompose() && (getLhsMap() || getRhsMap()))
- return emitOpError(
- "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
- int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
- expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
- if (getVarOperands().size() != size_t(expectedNumOperands))
- return emitOpError("expected ")
- << expectedNumOperands << " operands, but got "
- << getVarOperands().size();
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Test removing op with inner ops.
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct TestRemoveOpWithInnerOps
- : public OpRewritePattern<TestOpWithRegionPattern> {
- using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
-
- void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
-
- LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
- PatternRewriter &rewriter) const override {
- rewriter.eraseOp(op);
- return success();
- }
-};
-} // namespace
-
-void TestOpWithRegionPattern::getCanonicalizationPatterns(
- RewritePatternSet &results, MLIRContext *context) {
- results.add<TestRemoveOpWithInnerOps>(context);
-}
-
-OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
- return getOperand();
-}
-
-OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); }
-
-LogicalResult TestOpWithVariadicResultsAndFolder::fold(
- FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
- for (Value input : this->getOperands()) {
- results.push_back(input);
- }
- return success();
-}
-
- /// In-place fold. When the operand is constant and "attr" is not set yet,
- /// record the constant in "attr" and return the op's own result. Otherwise,
- /// do not fold.
- OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
-   // Exercise the fact that an operation created with createOrFold should be
-   // allowed to access its parent block.
-   assert(getOperation()->getBlock() &&
-          "expected that operation is not unlinked");
-
-   // Nothing to do without a constant operand, or once "attr" exists.
-   if (!adaptor.getOp() || getProperties().attr)
-     return {};
-
-   // The folder adds "attr" if not present.
-   getProperties().attr = dyn_cast_or_null<IntegerAttr>(adaptor.getOp());
-   return getResult();
- }
-
- /// Folds to a weighted sum of the constant integer inputs:
- ///   - the single operand counts once;
- ///   - each variadic operand counts twice;
- ///   - each variadic-of-variadic operand counts three times;
- ///   - each block of the body region adds 4.
- /// Non-constant or non-integer inputs contribute 0.
- OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
-   // Sign-extended value of `attr` if it is an IntegerAttr, otherwise 0.
-   auto valueOf = [](Attribute attr) -> int64_t {
-     auto intAttr = dyn_cast_or_null<IntegerAttr>(attr);
-     return intAttr ? intAttr.getValue().getSExtValue() : 0;
-   };
-
-   int64_t total = valueOf(adaptor.getOp());
-   for (Attribute variadicAttr : adaptor.getVariadic())
-     total += 2 * valueOf(variadicAttr);
-   for (ArrayRef<Attribute> group : adaptor.getVarOfVar())
-     for (Attribute nestedAttr : group)
-       total += 3 * valueOf(nestedAttr);
-
-   int64_t numBlocks =
-       std::distance(adaptor.getBody().begin(), adaptor.getBody().end());
-   total += 4 * numBlocks;
-   return IntegerAttr::get(getType(), total);
- }
-
- /// Infers one result whose type is the shared type of the first two
- /// operands. Fails if those two operand types differ.
- LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
-     MLIRContext *, std::optional<Location> location, ValueRange operands,
-     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-     SmallVectorImpl<Type> &inferredReturnTypes) {
-   Type lhsType = operands[0].getType();
-   Type rhsType = operands[1].getType();
-   if (lhsType != rhsType)
-     return emitOptionalError(location, "operand type mismatch ", lhsType,
-                              " vs ", rhsType);
-   inferredReturnTypes.assign({lhsType});
-   return success();
- }
-
- /// Adaptor-based variant. Infers one result whose type is the shared type of
- /// operands `x` and `y`, and fails if they differ.
- LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
-     MLIRContext *, std::optional<Location> location,
-     OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
-     SmallVectorImpl<Type> &inferredReturnTypes) {
-   Type xType = adaptor.getX().getType();
-   Type yType = adaptor.getY().getType();
-   if (xType != yType)
-     return emitOptionalError(location, "operand type mismatch ", xType, " vs ",
-                              yType);
-   inferredReturnTypes.assign({xType});
-   return success();
- }
-
- // TODO: We should be able to only define either inferReturnType or
- // refineReturnType, currently only refineReturnType can be omitted.
- /// Inference delegates to refineReturnTypes, starting from an empty list.
- LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
-     MLIRContext *context, std::optional<Location> location, ValueRange operands,
-     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-     SmallVectorImpl<Type> &returnTypes) {
-   returnTypes.clear();
-   return refineReturnTypes(context, location, operands, attributes,
-                            properties, regions, returnTypes);
- }
-
- /// Fills in or checks the single result type.
- /// Both operands must share a type. A result type that is already present
- /// must equal it; otherwise it is set to it.
- LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
-     MLIRContext *, std::optional<Location> location, ValueRange operands,
-     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-     SmallVectorImpl<Type> &returnTypes) {
-   Type firstType = operands[0].getType();
-   Type secondType = operands[1].getType();
-   if (firstType != secondType)
-     return emitOptionalError(location, "operand type mismatch ", firstType,
-                              " vs ", secondType);
-
-   // TODO: Add helper to make this more concise to write.
-   if (returnTypes.empty())
-     returnTypes.resize(1, nullptr);
-   Type &resultType = returnTypes.front();
-   if (resultType && resultType != firstType)
-     return emitOptionalError(location,
-                              "required first operand and result to match");
-   resultType = firstType;
-   return success();
- }
-
- /// Infers a single 1-D shaped result with element type i17.
- /// - Its only dimension is the leading dimension of the first operand, or
- ///   dynamic when that operand is unranked.
- /// - It carries the operand's tensor encoding, if any.
- /// Rank-0 operands are rejected because they have no leading dimension.
- LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
-     MLIRContext *context, std::optional<Location> location,
-     ValueShapeRange operands, DictionaryAttr attributes,
-     OpaqueProperties properties, RegionRange regions,
-     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-   auto operandType = operands.front().getType();
-   auto sval = dyn_cast<ShapedType>(operandType);
-   if (!sval)
-     return emitOptionalError(location, "only shaped type operands allowed");
-   // getShape().front() on a rank-0 type would read from an empty shape.
-   if (sval.hasRank() && sval.getRank() == 0)
-     return emitOptionalError(location, "expected operand of non-zero rank");
-   int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
-   auto type = IntegerType::get(context, 17);
-
-   Attribute encoding;
-   if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
-     encoding = rankedTy.getEncoding();
-   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
-   return success();
- }
-
- /// Reifies the result shape as the size of dimension 0 of the first operand.
- LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
-     OpBuilder &builder, ValueRange operands,
-     llvm::SmallVectorImpl<Value> &shapes) {
-   Value leadingDim =
-       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0);
-   shapes = SmallVector<Value, 1>{leadingDim};
-   return success();
- }
-
- /// Adaptor-based variant. Infers a single 1-D shaped result with element
- /// type i17.
- /// - Its only dimension is the leading dimension of `operand1`, or dynamic
- ///   when that operand is unranked.
- /// - It carries the operand's tensor encoding, if any.
- /// Rank-0 operands are rejected because they have no leading dimension.
- LogicalResult
- OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
-     MLIRContext *context, std::optional<Location> location,
-     OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
-     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-   auto operandType = adaptor.getOperand1().getType();
-   auto sval = dyn_cast<ShapedType>(operandType);
-   if (!sval)
-     return emitOptionalError(location, "only shaped type operands allowed");
-   // getShape().front() on a rank-0 type would read from an empty shape.
-   if (sval.hasRank() && sval.getRank() == 0)
-     return emitOptionalError(location, "expected operand of non-zero rank");
-   int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
-   auto type = IntegerType::get(context, 17);
-
-   Attribute encoding;
-   if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
-     encoding = rankedTy.getEncoding();
-   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
-   return success();
- }
-
- /// Reifies the result shape as the size of dimension 0 of the first operand.
- LogicalResult
- OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
-     OpBuilder &builder, ValueRange operands,
-     llvm::SmallVectorImpl<Value> &shapes) {
-   Value leadingDim =
-       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0);
-   shapes = SmallVector<Value, 1>{leadingDim};
-   return success();
- }
-
- /// Reifies one shape value per operand, visiting the operands in reverse.
- /// Each value is a 1-D index tensor built from the operand's dimension sizes.
- /// The operands are assumed to be ranked tensors (cast asserts otherwise).
- LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
-     OpBuilder &builder, ValueRange operands,
-     llvm::SmallVectorImpl<Value> &shapes) {
-   Location loc = getLoc();
-   shapes.reserve(operands.size());
-   for (Value operand : llvm::reverse(operands)) {
-     int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
-     SmallVector<Value, 4> dimSizes;
-     for (int64_t dim = 0; dim < rank; ++dim)
-       dimSizes.push_back(
-           builder.createOrFold<tensor::DimOp>(loc, operand, dim));
-     auto shapeType = RankedTensorType::get({rank}, builder.getIndexType());
-     shapes.push_back(
-         builder.create<tensor::FromElementsOp>(loc, shapeType, dimSizes));
-   }
-   return success();
- }
-
- /// Reifies each operand's shape, visiting the operands in reverse.
- /// Static dimensions become index attributes; dynamic ones become
- /// tensor.dim values.
- LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
-     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
-   Location loc = getLoc();
-   shapes.reserve(getNumOperands());
-   for (Value operand : llvm::reverse(getOperands())) {
-     auto tensorType = cast<RankedTensorType>(operand.getType());
-     SmallVector<OpFoldResult, 4> dimSizes;
-     for (int64_t dim = 0, rank = tensorType.getRank(); dim < rank; ++dim) {
-       if (tensorType.isDynamicDim(dim))
-         dimSizes.push_back(
-             builder.createOrFold<tensor::DimOp>(loc, operand, dim));
-       else
-         dimSizes.push_back(builder.getIndexAttr(tensorType.getDimSize(dim)));
-     }
-     shapes.emplace_back(std::move(dimSizes));
-   }
-   return success();
- }
-
- /// Infers one integer result whose bitwidth is the op's lhs plus the rhs
- /// stored in its properties.
- LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes(
-     MLIRContext *context, std::optional<Location>, ValueRange operands,
-     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-     SmallVectorImpl<Type> &inferredReturnTypes) {
-   Adaptor adaptor(operands, attributes, properties, regions);
-   auto width = adaptor.getLhs() + adaptor.getProperties().rhs;
-   inferredReturnTypes.push_back(IntegerType::get(context, width));
-   return success();
- }
-
-//===----------------------------------------------------------------------===//
-// Test SideEffect interfaces
-//===----------------------------------------------------------------------===//
-
- namespace {
- /// A test resource for side effects. SideEffectOp::getEffects attributes an
- /// effect to this resource, instead of the default one, when the effect entry
- /// has a "test_resource" key.
- struct TestResource : public SideEffects::Resource::Base<TestResource> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
-
- // Display name of the resource.
- StringRef getName() final { return "<Test>"; }
- };
- } // namespace
-
- /// Appends one TestEffects::Concrete effect, parameterized by the op's
- /// "effect_parameter" AffineMapAttr, when that attribute is present.
- static void testSideEffectOpGetEffect(
-     Operation *op,
-     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
-         &effects) {
-   if (auto parameter = op->getAttrOfType<AffineMapAttr>("effect_parameter"))
-     effects.emplace_back(TestEffects::Concrete::get(), parameter);
- }
-
- /// Builds the memory effects of this op from its optional "effects" array.
- /// Each element is a dictionary with:
- ///   - "effect": one of "allocate", "free", "read", "write";
- ///   - "test_resource" (optional): use TestResource instead of the default;
- ///   - "on_result" (optional): the effect applies to the op's result;
- ///   - "on_reference" (optional): the effect applies to a symbol reference.
- void SideEffectOp::getEffects(
-     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
-   if (!effectsAttr)
-     return;
-
-   for (Attribute element : effectsAttr) {
-     auto entry = cast<DictionaryAttr>(element);
-
-     StringRef effectName = cast<StringAttr>(entry.get("effect")).getValue();
-     MemoryEffects::Effect *effect =
-         StringSwitch<MemoryEffects::Effect *>(effectName)
-             .Case("allocate", MemoryEffects::Allocate::get())
-             .Case("free", MemoryEffects::Free::get())
-             .Case("read", MemoryEffects::Read::get())
-             .Case("write", MemoryEffects::Write::get());
-
-     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
-     if (entry.get("test_resource"))
-       resource = TestResource::get();
-
-     if (entry.get("on_result")) {
-       effects.emplace_back(effect, getResult(), resource);
-     } else if (Attribute ref = entry.get("on_reference")) {
-       effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
-     } else {
-       effects.emplace_back(effect, resource);
-     }
-   }
- }
-
- /// Test-effect variant: forwards to testSideEffectOpGetEffect.
- void SideEffectOp::getEffects(
-     SmallVectorImpl<TestEffects::EffectInstance> &testEffects) {
-   Operation *op = getOperation();
-   testSideEffectOpGetEffect(op, testEffects);
- }
-
-//===----------------------------------------------------------------------===//
-// StringAttrPrettyNameOp
-//===----------------------------------------------------------------------===//
-
- // This op has fancy handling of its SSA result name.
- /// Custom parser.
- /// - Every result is typed i32.
- /// - An optional `attributes {...}` dictionary may follow.
- /// - When that dictionary has no "names" entry, "names" is inferred from the
- ///   SSA result names written in the source. A name that starts with a digit
- ///   (e.g. %0) maps to "".
- ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
-                                           OperationState &result) {
-   // Add the result types.
-   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
-     result.addTypes(parser.getBuilder().getIntegerType(32));
-
-   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
-     return failure();
-
-   // An explicit 'names' attribute takes precedence over the SSA names.
-   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
-     return attr.getName() == "names";
-   });
-   if (hadNames || parser.getNumResults() == 0)
-     return success();
-
-   // Otherwise infer 'names' from the SSA result names used in the asm.
-   SmallVector<StringRef, 4> names;
-   auto *context = result.getContext();
-   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
-     auto resultName = parser.getResultName(i);
-     StringRef nameStr;
-     // Cast to unsigned char: passing a negative char to isdigit is undefined.
-     if (!resultName.first.empty() &&
-         !isdigit(static_cast<unsigned char>(resultName.first[0])))
-       nameStr = resultName.first;
-     names.push_back(nameStr);
-   }
-
-   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
-   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
-   return success();
- }
-
- /// Prints the attribute dictionary. "names" is elided when every entry
- /// matches the SSA name the printer actually chose for that result.
- void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
-   // Note that we only need to print the "name" attribute if the asmprinter
-   // result name disagrees with it. This can happen in strange cases, e.g.
-   // when there are conflicts.
-   auto namesMatchPrinter = [&]() {
-     if (getNames().size() != getNumResults())
-       return false;
-     SmallString<32> printedName;
-     for (size_t i = 0, e = getNumResults(); i != e; ++i) {
-       printedName.clear();
-       llvm::raw_svector_ostream os(printedName);
-       p.printOperand(getResult(i), os);
-       auto expected = dyn_cast<StringAttr>(getNames()[i]);
-       // drop_front() strips the leading '%'.
-       if (!expected || os.str().drop_front() != expected.getValue())
-         return false;
-     }
-     return true;
-   };
-
-   if (namesMatchPrinter())
-     p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
-   else
-     p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
- }
-
- // We set the SSA name in the asm syntax to the contents of the name
- // attribute.
- /// Empty names are skipped.
- /// The "names" array may disagree in size with the result count (print()
- /// checks for this), so entries beyond the last result are ignored rather
- /// than indexing past the results.
- void StringAttrPrettyNameOp::getAsmResultNames(
-     function_ref<void(Value, StringRef)> setNameFn) {
-   auto value = getNames();
-   for (size_t i = 0, e = value.size(); i != e && i < getNumResults(); ++i)
-     if (auto str = dyn_cast<StringAttr>(value[i]))
-       if (!str.getValue().empty())
-         setNameFn(getResult(i), str.getValue());
- }
-
- /// Names result i after the i-th entry of "names", skipping entries that
- /// are not non-empty strings.
- void CustomResultsNameOp::getAsmResultNames(
-     function_ref<void(Value, StringRef)> setNameFn) {
-   ArrayAttr names = getNames();
-   for (size_t idx = 0, count = names.size(); idx < count; ++idx) {
-     auto nameAttr = dyn_cast<StringAttr>(names[idx]);
-     if (!nameAttr || nameAttr.empty())
-       continue;
-     setNameFn(getResult(idx), nameAttr.getValue());
-   }
- }
-
-//===----------------------------------------------------------------------===//
-// ResultTypeWithTraitOp
-//===----------------------------------------------------------------------===//
-
- /// Requires the first result type to carry TestTypeTrait.
- LogicalResult ResultTypeWithTraitOp::verify() {
-   Type resultType = (*this)->getResultTypes()[0];
-   if (!resultType.hasTrait<TypeTrait::TestTypeTrait>())
-     return emitError("result type should have trait 'TestTypeTrait'");
-   return success();
- }
-
-//===----------------------------------------------------------------------===//
-// AttrWithTraitOp
-//===----------------------------------------------------------------------===//
-
- /// Requires the 'attr' attribute to carry TestAttrTrait.
- LogicalResult AttrWithTraitOp::verify() {
-   auto attr = getAttr();
-   if (!attr.hasTrait<AttributeTrait::TestAttrTrait>())
-     return emitError("'attr' attribute should have trait 'TestAttrTrait'");
-   return success();
- }
-
-//===----------------------------------------------------------------------===//
-// RegionIfOp
-//===----------------------------------------------------------------------===//
-
- /// Prints, in order:
- /// - the operands and their types;
- /// - the result types;
- /// - the "then", "else" and "join" regions, each with entry block arguments
- ///   and terminators.
- void RegionIfOp::print(OpAsmPrinter &p) {
-   p << " ";
-   p.printOperands(getOperands());
-   p << ": " << getOperandTypes();
-   p.printArrowTypeList(getResultTypes());
-
-   auto printKeywordRegion = [&](StringRef keyword, Region &region) {
-     p << " " << keyword << " ";
-     p.printRegion(region, /*printEntryBlockArgs=*/true,
-                   /*printBlockTerminators=*/true);
-   };
-   printKeywordRegion("then", getThenRegion());
-   printKeywordRegion("else", getElseRegion());
-   printKeywordRegion("join", getJoinRegion());
- }
-
- /// Parses the form produced by RegionIfOp::print.
- ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
-   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
-   SmallVector<Type, 2> types;
-
-   result.regions.reserve(3);
-   Region *thenRegion = result.addRegion();
-   Region *elseRegion = result.addRegion();
-   Region *joinRegion = result.addRegion();
-
-   // Operand list, colon type list, then the arrow result type list.
-   if (parser.parseOperandList(operands) ||
-       parser.parseColonTypeList(types) ||
-       parser.parseArrowTypeList(result.types))
-     return failure();
-
-   // Each region is introduced by its keyword, in a fixed order.
-   if (parser.parseKeyword("then") ||
-       parser.parseRegion(*thenRegion, /*arguments=*/{},
-                          /*enableNameShadowing=*/false))
-     return failure();
-   if (parser.parseKeyword("else") ||
-       parser.parseRegion(*elseRegion, /*arguments=*/{},
-                          /*enableNameShadowing=*/false))
-     return failure();
-   if (parser.parseKeyword("join") ||
-       parser.parseRegion(*joinRegion, /*arguments=*/{},
-                          /*enableNameShadowing=*/false))
-     return failure();
-
-   auto loc = parser.getCurrentLocation();
-   return parser.resolveOperands(operands, types, loc, result.operands);
- }
-
- /// Both entry regions ("then" and "else") are passed all of the op's
- /// operands; any other branch point is rejected by the assertion.
- OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
- "invalid region index");
- return getOperands();
- }
-
- /// Control flow of the op:
- /// - from the parent op, control enters "then" or "else";
- /// - from either of those, it goes to "join";
- /// - from "join", it exits to the op's results.
- void RegionIfOp::getSuccessorRegions(
-     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
-   // We always branch to the join region.
-   if (!point.isParent()) {
-     if (point != getJoinRegion())
-       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
-     else
-       regions.push_back(RegionSuccessor(getResults()));
-     return;
-   }
-
-   // The then and else regions are the entry regions of this op.
-   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
-   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
- }
-
- /// Each of the three regions is invoked at most once.
- void RegionIfOp::getRegionInvocationBounds(
-     ArrayRef<Attribute> operands,
-     SmallVectorImpl<InvocationBounds> &invocationBounds) {
-   InvocationBounds atMostOnce(/*lb=*/0, /*ub=*/1);
-   invocationBounds.assign(/*NumElts=*/3, atMostOnce);
- }
-
-//===----------------------------------------------------------------------===//
-// AnyCondOp
-//===----------------------------------------------------------------------===//
-
- /// The parent op branches into the only region, and the region branches back
- /// to the parent op, producing the op's results.
- void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
-                                     SmallVectorImpl<RegionSuccessor> &regions) {
-   if (point.isParent())
-     regions.emplace_back(&getRegion());
-   else
-     regions.emplace_back(getResults());
- }
-
- /// The only region is invoked exactly once.
- void AnyCondOp::getRegionInvocationBounds(
-     ArrayRef<Attribute> operands,
-     SmallVectorImpl<InvocationBounds> &invocationBounds) {
-   invocationBounds.push_back(InvocationBounds(/*lb=*/1, /*ub=*/1));
- }
-
-//===----------------------------------------------------------------------===//
-// LoopBlockOp
-//===----------------------------------------------------------------------===//
-
- /// The body is always a possible successor, receiving its block arguments.
- /// From inside the body, control may also exit to the op's results.
- void LoopBlockOp::getSuccessorRegions(
-     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
-   regions.emplace_back(&getBody(), getBody().getArguments());
-   if (point.isParent())
-     return;
-
-   regions.emplace_back((*this)->getResults());
- }
-
- /// Entering the body forwards the `init` operand(s).
- OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
-   assert(point == getBody());
-   MutableOperandRange initOperands(getInitMutable());
-   return initOperands;
- }
-
-//===----------------------------------------------------------------------===//
-// LoopBlockTerminatorOp
-//===----------------------------------------------------------------------===//
-
- /// Branching back into the loop forwards the next iteration argument;
- /// exiting to the parent forwards the exit argument.
- MutableOperandRange
- LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
-   if (!point.isParent())
-     return getNextIterArgMutable();
-   return getExitArgMutable();
- }
-
-//===----------------------------------------------------------------------===//
-// SwitchWithNoBreakOp
-//===----------------------------------------------------------------------===//
-
- /// Reports no successor regions from any branch point.
- void TestNoTerminatorOp::getSuccessorRegions(
-     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}
-
-//===----------------------------------------------------------------------===//
-// SingleNoTerminatorCustomAsmOp
-//===----------------------------------------------------------------------===//
-
- /// Parses a single region with no entry block arguments.
- ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
-                                                  OperationState &state) {
-   Region *body = state.addRegion();
-   // The third parseRegion parameter is `enableNameShadowing`, not argument
-   // types, so label and pass it explicitly.
-   if (parser.parseRegion(*body, /*arguments=*/{},
-                          /*enableNameShadowing=*/false))
-     return failure();
-   return success();
- }
-
- /// Prints the body region without entry block arguments or terminators.
- void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
-   // This op has a single block without terminators. But explicitly mark
-   // as not printing block terminators for testing.
-   const bool printEntryBlockArgs = false;
-   const bool printBlockTerminators = false;
-   printer.printRegion(getRegion(), printEntryBlockArgs, printBlockTerminators);
- }
-
-//===----------------------------------------------------------------------===//
-// TestVerifiersOp
-//===----------------------------------------------------------------------===//
-
- /// Fails in either of two cases:
- /// - the region does not have exactly one block;
- /// - the input's defining op fails verification.
- /// Otherwise it emits a "success run of verifier" remark.
- LogicalResult TestVerifiersOp::verify() {
-   if (!getRegion().hasOneBlock())
-     return emitOpError("`hasOneBlock` trait hasn't been verified");
-
-   if (Operation *producer = getInput().getDefiningOp())
-     if (failed(mlir::verify(producer)))
-       return emitOpError("operand hasn't been verified");
-
-   // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
-   // loop.
-   mlir::emitRemark(getLoc(), "success run of verifier");
-   return success();
- }
-
- /// Fails in either of two cases:
- /// - the region does not have exactly one block;
- /// - any op nested directly in the region fails verification.
- /// Otherwise it emits a "success run of region verifier" remark.
- LogicalResult TestVerifiersOp::verifyRegions() {
-   if (!getRegion().hasOneBlock())
-     return emitOpError("`hasOneBlock` trait hasn't been verified");
-
-   for (Block &block : getRegion()) {
-     for (Operation &nested : block) {
-       if (failed(mlir::verify(&nested)))
-         return emitOpError("nested op hasn't been verified");
-     }
-   }
-
-   // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
-   // loop.
-   mlir::emitRemark(getLoc(), "success run of region verifier");
-   return success();
- }
-
-//===----------------------------------------------------------------------===//
-// Test InferIntRangeInterface
-//===----------------------------------------------------------------------===//
-
- /// The result takes exactly the range given by the op's umin/umax/smin/smax
- /// values; the operand ranges are ignored.
- void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
- }
-
- /// Parses an optional attribute dictionary, then a single `index`-typed
- /// argument, then the body region, which takes that argument as its entry
- /// argument.
- ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
-                                           OperationState &result) {
-   if (parser.parseOptionalAttrDict(result.attributes))
-     return failure();
-
-   // The region argument is always of index type.
-   OpAsmParser::Argument regionArg;
-   regionArg.type = parser.getBuilder().getIndexType();
-   if (failed(parser.parseArgument(regionArg)))
-     return failure();
-
-   Region *body = result.addRegion();
-   return parser.parseRegion(*body, regionArg, /*enableNameShadowing=*/false);
- }
-
- /// Prints the attribute dictionary, the region argument (type omitted), and
- /// the region without its entry block arguments.
- void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
-   p.printOptionalAttrDict((*this)->getAttrs());
-   BlockArgument regionArg = getRegion().getArgument(0);
-   p << ' ';
-   p.printRegionArgument(regionArg, /*argAttrs=*/{}, /*omitType=*/true);
-   p << ' ';
-   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
- }
-
- /// The region's entry argument takes the range given by the op's bounds.
- void TestWithBoundsRegionOp::inferResultRanges(
-     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
-   setResultRanges(getRegion().getArgument(0),
-                   {getUmin(), getUmax(), getSmin(), getSmax()});
- }
-
- /// The result range is the operand range shifted up by one. The unsigned
- /// bounds saturate at the unsigned maximum and the signed bounds at the
- /// signed maximum.
- void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                                         SetIntRangeFn setResultRanges) {
-   const ConstantIntRanges &in = argRanges[0];
-   APInt one(in.umin().getBitWidth(), 1);
-   APInt umin = in.umin().uadd_sat(one);
-   APInt umax = in.umax().uadd_sat(one);
-   APInt smin = in.smin().sadd_sat(one);
-   APInt smax = in.smax().sadd_sat(one);
-   setResultRanges(getResult(), {umin, umax, smin, smax});
- }
-
- /// Records the inferred operand range on the op's umin/umax/smin/smax
- /// attributes, then forwards the same range to the result.
- // NOTE(review): getZExtValue/getSExtValue assert when the value does not fit
- // in 64 bits, so this assumes operands of at most 64 bits.
- void TestReflectBoundsOp::inferResultRanges(
-     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
-   const ConstantIntRanges &range = argRanges[0];
-   Builder builder(getContext());
-   setUminAttr(builder.getIndexAttr(range.umin().getZExtValue()));
-   setUmaxAttr(builder.getIndexAttr(range.umax().getZExtValue()));
-   setSminAttr(builder.getIndexAttr(range.smin().getSExtValue()));
-   setSmaxAttr(builder.getIndexAttr(range.smax().getSExtValue()));
-   setResultRanges(getResult(), range);
- }
-
- /// Just a simple fold for testing purposes: folds to the first operand's
- /// constant value, or does not fold when there are no operands.
- OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
-   if (attributes.empty())
-     return nullptr;
-   return attributes.front();
- }
-
- /// Populates a PropertiesWithCustomPrint from a dictionary attribute that
- /// holds a StringAttr "label" and an IntegerAttr "value". Emits a diagnostic
- /// and fails if `attr` is not a dictionary or either key is missing or
- /// mistyped.
- static LogicalResult
- setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr,
-                            function_ref<InFlightDiagnostic()> emitError) {
-   auto fail = [&](StringRef message) {
-     emitError() << message;
-     return failure();
-   };
-
-   auto dict = dyn_cast<DictionaryAttr>(attr);
-   if (!dict)
-     return fail("expected DictionaryAttr to set TestProperties");
-   auto labelAttr = dict.getAs<mlir::StringAttr>("label");
-   if (!labelAttr)
-     return fail("expected StringAttr for key `label`");
-   auto valueAttr = dict.getAs<IntegerAttr>("value");
-   if (!valueAttr)
-     return fail("expected IntegerAttr for key `value`");
-
-   prop.label = std::make_shared<std::string>(labelAttr.getValue());
-   prop.value = valueAttr.getValue().getSExtValue();
-   return success();
- }
-
- /// Converts a PropertiesWithCustomPrint into a dictionary attribute holding
- /// a StringAttr "label" and an i32 IntegerAttr "value".
- static DictionaryAttr
- getPropertiesAsAttribute(MLIRContext *ctx,
-                          const PropertiesWithCustomPrint &prop) {
-   Builder builder(ctx);
-   NamedAttribute fields[] = {
-       builder.getNamedAttr("label", builder.getStringAttr(*prop.label)),
-       builder.getNamedAttr("value", builder.getI32IntegerAttr(prop.value))};
-   return builder.getDictionaryAttr(fields);
- }
-
- /// Hashes the integer value together with the label's contents.
- static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop) {
-   StringRef labelText(*prop.label);
-   return llvm::hash_combine(prop.value, labelText);
- }
-
- /// Prints the property as `<label> is <value>`. The label is printed as a
- /// keyword when possible, otherwise as a string.
- static void customPrintProperties(OpAsmPrinter &p,
-                                   const PropertiesWithCustomPrint &prop) {
-   p.printKeywordOrString(*prop.label);
-   p << " is ";
-   p << prop.value;
- }
-
- /// Parses the `<label> is <value>` form printed by customPrintProperties.
- static ParseResult customParseProperties(OpAsmParser &parser,
-                                          PropertiesWithCustomPrint &prop) {
-   std::string parsedLabel;
-   if (parser.parseKeywordOrString(&parsedLabel))
-     return failure();
-   if (parser.parseKeyword("is") || parser.parseInteger(prop.value))
-     return failure();
-   prop.label = std::make_shared<std::string>(std::move(parsedLabel));
-   return success();
- }
-
- /// Parses zero or more `case <int> <region>` entries. The case values go
- /// into the `cases` array attribute, and each case gets a new region.
- static ParseResult
- parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
-                  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
-   SmallVector<int64_t> caseValues;
-   while (succeeded(p.parseOptionalKeyword("case"))) {
-     int64_t value;
-     Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
-     if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
-       return failure();
-     caseValues.push_back(value);
-   }
-   cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
-   return success();
- }
-
- /// Prints each case as `case <int> <region>` on its own line.
- static void printSwitchCases(OpAsmPrinter &p, Operation *op,
-                              DenseI64ArrayAttr cases, RegionRange caseRegions) {
-   for (auto entry : llvm::zip(cases.asArrayRef(), caseRegions)) {
-     int64_t caseValue = std::get<0>(entry);
-     Region *caseRegion = std::get<1>(entry);
-     p.printNewline();
-     p << "case " << caseValue << ' ';
-     p.printRegion(*caseRegion, /*printEntryBlockArgs=*/false);
-   }
- }
-
- /// Populates a VersionedProperties from a dictionary attribute that holds
- /// the IntegerAttrs "value1" and "value2". Emits a diagnostic and fails if
- /// `attr` is not a dictionary or either key is missing or mistyped.
- static LogicalResult
- setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
-                            function_ref<InFlightDiagnostic()> emitError) {
-   auto fail = [&](StringRef message) {
-     emitError() << message;
-     return failure();
-   };
-
-   auto dict = dyn_cast<DictionaryAttr>(attr);
-   if (!dict)
-     return fail("expected DictionaryAttr to set VersionedProperties");
-   auto firstAttr = dict.getAs<IntegerAttr>("value1");
-   if (!firstAttr)
-     return fail("expected IntegerAttr for key `value1`");
-   auto secondAttr = dict.getAs<IntegerAttr>("value2");
-   if (!secondAttr)
-     return fail("expected IntegerAttr for key `value2`");
-
-   prop.value1 = firstAttr.getValue().getSExtValue();
-   prop.value2 = secondAttr.getValue().getSExtValue();
-   return success();
- }
-
- /// Converts a VersionedProperties into a dictionary attribute holding two
- /// i32 IntegerAttrs, "value1" and "value2".
- static DictionaryAttr
- getPropertiesAsAttribute(MLIRContext *ctx, const VersionedProperties &prop) {
-   Builder builder(ctx);
-   NamedAttribute fields[] = {
-       builder.getNamedAttr("value1", builder.getI32IntegerAttr(prop.value1)),
-       builder.getNamedAttr("value2", builder.getI32IntegerAttr(prop.value2))};
-   return builder.getDictionaryAttr(fields);
- }
-
- /// Hashes both integer fields of a VersionedProperties.
- static llvm::hash_code computeHash(const VersionedProperties &prop) {
- return llvm::hash_combine(prop.value1, prop.value2);
- }
-
- /// Prints the property as `<value1> | <value2>`.
- static void customPrintProperties(OpAsmPrinter &p,
-                                   const VersionedProperties &prop) {
-   p << prop.value1;
-   p << " | ";
-   p << prop.value2;
- }
-
- /// Parses the `<value1> | <value2>` form printed by customPrintProperties.
- static ParseResult customParseProperties(OpAsmParser &parser,
-                                          VersionedProperties &prop) {
-   bool parseError = parser.parseInteger(prop.value1) ||
-                     parser.parseVerticalBar() ||
-                     parser.parseInteger(prop.value2);
-   return parseError ? failure() : success();
- }
-
- /// Parses `[a, b, c]` into the three-element `value`. Returns true on
- /// failure.
- static bool parseUsingPropertyInCustom(OpAsmParser &parser, int64_t value[3]) {
-   if (parser.parseLSquare())
-     return true;
-   for (int i = 0; i < 3; ++i) {
-     if (i != 0 && parser.parseComma())
-       return true;
-     if (parser.parseInteger(value[i]))
-       return true;
-   }
-   return failed(parser.parseRSquare());
- }
-
- /// Printer counterpart of parseUsingPropertyInCustom: prints `value` wrapped
- /// in square brackets.
- static void printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op,
- ArrayRef<int64_t> value) {
- printer << '[' << value << ']';
- }
-
- /// Custom-directive parser for a single integer property. Returns true on
- /// failure.
- static bool parseIntProperty(OpAsmParser &parser, int64_t &value) {
- return failed(parser.parseInteger(value));
- }
-
- /// Prints the integer property as a bare number.
- static void printIntProperty(OpAsmPrinter &printer, Operation *op,
- int64_t value) {
- printer << value;
- }
-
- /// Parses `<second> = <sum>` and checks that sum == first + second, where
- /// `first` is supplied by the caller. Returns true on failure.
- static bool parseSumProperty(OpAsmParser &parser, int64_t &second,
-                              int64_t first) {
-   auto startLoc = parser.getCurrentLocation();
-   int64_t sum;
-   if (parser.parseInteger(second) || parser.parseEqual() ||
-       parser.parseInteger(sum))
-     return true;
-   if (sum == second + first)
-     return false;
-   parser.emitError(startLoc, "Expected sum to equal first + second");
-   return true;
- }
-
- /// Prints `<second> = <first + second>`.
- static void printSumProperty(OpAsmPrinter &printer, Operation *op,
-                              int64_t second, int64_t first) {
-   int64_t sum = second + first;
-   printer << second << " = " << sum;
- }
-
-//===----------------------------------------------------------------------===//
-// Tensor/Buffer Ops
-//===----------------------------------------------------------------------===//
-
- /// Reports two effects on the default resource:
- /// - a read of the buffer operand;
- /// - a write not tied to any value (the buffer contents are dumped).
- void ReadBufferOp::getEffects(
-     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
-         &effects) {
-   SideEffects::Resource *defaultResource = SideEffects::DefaultResource::get();
-   effects.emplace_back(MemoryEffects::Read::get(), getBuffer(),
-                        defaultResource);
-   effects.emplace_back(MemoryEffects::Write::get(), defaultResource);
- }
-
-//===----------------------------------------------------------------------===//
-// Test Dataflow
-//===----------------------------------------------------------------------===//
-
- /// CallOpInterface: the callee is the op's symbol reference.
- CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() {
-   CallInterfaceCallable callable = getCallee();
-   return callable;
- }
-
- /// Only symbol-reference callees are supported (get<> asserts otherwise).
- void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) {
-   auto symbol = callee.get<SymbolRefAttr>();
-   setCalleeAttr(symbol);
- }
-
- /// The call arguments are the op's callee operands.
- Operation::operand_range TestCallAndStoreOp::getArgOperands() {
-   return getCalleeOperands();
- }
-
- MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
-   return getCalleeOperandsMutable();
- }
-
- /// CallOpInterface: the callee is the op's symbol reference.
- CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() {
-   CallInterfaceCallable callable = getCallee();
-   return callable;
- }
-
- /// Only symbol-reference callees are supported (get<> asserts otherwise).
- void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) {
-   auto symbol = callee.get<SymbolRefAttr>();
-   setCalleeAttr(symbol);
- }
-
- /// The call arguments are the op's forwarded operands.
- Operation::operand_range TestCallOnDeviceOp::getArgOperands() {
-   return getForwardedOperands();
- }
-
- MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
-   return getForwardedOperandsMutable();
- }
-
- /// From the parent, control enters the body and forwards its entry block
- /// arguments. From the body, a default-constructed RegionSuccessor is
- /// reported (presumably the parent op with no forwarded values; confirm).
- void TestStoreWithARegion::getSuccessorRegions(
-     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
-   if (point.isParent())
-     regions.emplace_back(&getBody(), getBody().front().getArguments());
-   else
-     regions.emplace_back();
- }
-
- /// From any point, control may enter the body (forwarding its entry block
- /// arguments) or go back to the op, so the body may be skipped entirely.
- void TestStoreWithALoopRegion::getSuccessorRegions(
-     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
-   // Both the operation itself and the region may be branching into the body or
-   // back into the operation itself. It is possible for the operation not to
-   // enter the body.
-   regions.emplace_back(
-       RegionSuccessor(&getBody(), getBody().front().getArguments()));
-   regions.emplace_back();
- }
-
- /// Reads `dims`, then reads `modifier` unless the bytecode was written for a
- /// test dialect version older than 2.0. A missing version means the current
- /// version.
- LogicalResult
- TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader,
-                                  ::mlir::OperationState &state) {
-   auto &prop = state.getOrAddProperties<Properties>();
-   if (::mlir::failed(reader.readAttribute(prop.dims)))
-     return ::mlir::failure();
-
-   // If version is less than 2.0, there is no additional attribute to parse.
-   // We can materialize missing properties post parsing before verification.
-   bool predatesModifier = false;
-   auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
-   if (succeeded(maybeVersion)) {
-     const auto *version =
-         reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
-     predatesModifier = version->major_ < 2;
-   }
-   if (predatesModifier)
-     return success();
-
-   return reader.readAttribute(prop.modifier);
- }
-
- /// Writes `dims`, then writes `modifier` unless the target test dialect
- /// version is older than 2.0. In that case it logs a downgrade notice to
- /// stdout instead.
- void TestVersionedOpA::writeProperties(::mlir::DialectBytecodeWriter &writer) {
-   auto &prop = getProperties();
-   writer.writeAttribute(prop.dims);
-
-   // If version is less than 2.0, there is no additional attribute to write.
-   auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
-   if (succeeded(maybeVersion) &&
-       reinterpret_cast<const TestDialectVersion *>(*maybeVersion)->major_ < 2) {
-     llvm::outs() << "downgrading op properties...\n";
-     return;
-   }
-   writer.writeAttribute(prop.modifier);
- }
-
- /// Reads value1, then reads value2 unless the bytecode predates version 2.0.
- /// For pre-2.0 bytecode, value2 defaults to 0. A missing version means the
- /// current version.
- ::mlir::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
-     ::mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) {
-   uint64_t value1;
-   if (failed(reader.readVarInt(value1)))
-     return failure();
-
-   // Versions older than 2.0 did not store value2.
-   bool hasValue2 = true;
-   auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
-   if (succeeded(maybeVersion)) {
-     const auto *version =
-         reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
-     hasValue2 = version->major_ >= 2;
-   }
-   uint64_t value2 = 0;
-   if (hasValue2 && failed(reader.readVarInt(value2)))
-     return failure();
-
-   prop.value1 = value1;
-   prop.value2 = value2;
-   return success();
- }
-
- /// Writes value1 and value2 as two varints.
- // NOTE(review): unlike TestVersionedOpA::writeProperties, this does not check
- // the target dialect version. A downgrade to <2.0 still writes value2, which
- // readFromMlirBytecode does not read for such versions.
- void TestOpWithVersionedProperties::writeToMlirBytecode(
- ::mlir::DialectBytecodeWriter &writer,
- const test::VersionedProperties &prop) {
- writer.writeVarInt(prop.value1);
- writer.writeVarInt(prop.value2);
- }
-
-#include "TestOpEnums.cpp.inc"
-#include "TestOpInterfaces.cpp.inc"
-#include "TestTypeInterfaces.cpp.inc"
-
-#define GET_OP_CLASSES
-#include "TestOps.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index d5b2fbeafc4104..c05e15fc642a25 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -43,19 +43,18 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/SetVector.h"
#include <memory>
namespace mlir {
-class DLTIDialect;
class RewritePatternSet;
} // namespace mlir
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
-#include "TestOpInterfaces.h.inc"
#include "TestOpsDialect.h.inc"
namespace test {
@@ -75,49 +74,8 @@ struct TestDialectVersion : public mlir::DialectVersion {
uint32_t minor_ = 0;
};
-// Define some classes to exercises the Properties feature.
-
-struct PropertiesWithCustomPrint {
- /// A shared_ptr to a const object is safe: it is equivalent to a value-based
- /// member. Here the label will be deallocated when the last operation
- /// refering to it is destroyed. However there is no pool-allocation: this is
- /// offloaded to the client.
- std::shared_ptr<const std::string> label;
- int value;
- bool operator==(const PropertiesWithCustomPrint &rhs) const {
- return value == rhs.value && *label == *rhs.label;
- }
-};
-class MyPropStruct {
-public:
- std::string content;
- // These three methods are invoked through the `MyStructProperty` wrapper
- // defined in TestOps.td
- mlir::Attribute asAttribute(mlir::MLIRContext *ctx) const;
- static mlir::LogicalResult
- setFromAttr(MyPropStruct &prop, mlir::Attribute attr,
- llvm::function_ref<mlir::InFlightDiagnostic()> emitError);
- llvm::hash_code hash() const;
- bool operator==(const MyPropStruct &rhs) const {
- return content == rhs.content;
- }
-};
-struct VersionedProperties {
- // For the sake of testing, assume that this object was associated to version
- // 1.2 of the test dialect when having only one int value. In the current
- // version 2.0, the property has two values. We also assume that the class is
- // upgrade-able if value2 = 0.
- int value1;
- int value2;
- bool operator==(const VersionedProperties &rhs) const {
- return value1 == rhs.value1 && value2 == rhs.value2;
- }
-};
} // namespace test
-#define GET_OP_CLASSES
-#include "TestOps.h.inc"
-
namespace test {
// Op deliberately defined in C++ code rather than ODS to test that C++
@@ -138,6 +96,10 @@ class ManualCppOpWithFold
void registerTestDialect(::mlir::DialectRegistry &registry);
void populateTestReductionPatterns(::mlir::RewritePatternSet &patterns);
+void testSideEffectOpGetEffect(
+    mlir::Operation *op,
+    llvm::SmallVectorImpl<
+        mlir::SideEffects::EffectInstance<mlir::TestEffects::Effect>> &effects);
} // namespace test
#endif // MLIR_TESTDIALECT_H
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 66578b246afab1..a3a8913d5964c6 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/InliningUtils.h"
diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp
new file mode 100644
index 00000000000000..6e75dd39322810
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp
@@ -0,0 +1,377 @@
+//===- TestFormatUtils.cpp - MLIR Test Dialect Assembly Format Utilities --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestFormatUtils.h"
+#include "mlir/IR/Builders.h"
+
+using namespace mlir;
+using namespace test;
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOperands: `operand [, optOperand] -> (varOperands)`
+//===----------------------------------------------------------------------===//
+
+ParseResult test::parseCustomDirectiveOperands(
+    OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
+  if (parser.parseOperand(operand))
+    return failure();
+  if (succeeded(parser.parseOptionalComma())) { // comma => optOperand present
+    optOperand.emplace();
+    if (parser.parseOperand(*optOperand))
+      return failure();
+  }
+  if (parser.parseArrow() || parser.parseLParen() ||
+      parser.parseOperandList(varOperands) || parser.parseRParen())
+    return failure();
+  return success();
+}
+
+void test::printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
+                                        Value operand, Value optOperand,
+                                        OperandRange varOperands) {
+  printer << operand;
+  if (optOperand)
+    printer << ", " << optOperand;
+  printer << " -> (" << varOperands << ")";
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveResults: `: type [, optType] -> (varTypes)`
+//===----------------------------------------------------------------------===//
+
+ParseResult
+test::parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
+                                  Type &optOperandType,
+                                  SmallVectorImpl<Type> &varOperandTypes) {
+  if (parser.parseColon())
+    return failure();
+
+  if (parser.parseType(operandType))
+    return failure();
+  if (succeeded(parser.parseOptionalComma())) // comma => optional type follows
+    if (parser.parseType(optOperandType))
+      return failure();
+  if (parser.parseArrow() || parser.parseLParen() ||
+      parser.parseTypeList(varOperandTypes) || parser.parseRParen())
+    return failure();
+  return success();
+}
+
+void test::printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
+                                       Type operandType, Type optOperandType,
+                                       TypeRange varOperandTypes) {
+  printer << " : " << operandType;
+  if (optOperandType)
+    printer << ", " << optOperandType;
+  printer << " -> (" << varOperandTypes << ")";
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveWithTypeRefs: `type_refs_capture` + re-checked result types
+//===----------------------------------------------------------------------===//
+
+ParseResult test::parseCustomDirectiveWithTypeRefs(
+    OpAsmParser &parser, Type operandType, Type optOperandType,
+    const SmallVectorImpl<Type> &varOperandTypes) {
+  if (parser.parseKeyword("type_refs_capture"))
+    return failure();
+
+  Type operandType2, optOperandType2;
+  SmallVector<Type, 1> varOperandTypes2;
+  if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
+                                  varOperandTypes2))
+    return failure();
+
+  if (operandType != operandType2 || optOperandType != optOperandType2 ||
+      varOperandTypes != varOperandTypes2)
+    return failure(); // type mismatch; no diagnostic is emitted
+
+  return success();
+}
+
+void test::printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
+                                            Operation *op, Type operandType,
+                                            Type optOperandType,
+                                            TypeRange varOperandTypes) {
+  printer << " type_refs_capture ";
+  printCustomDirectiveResults(printer, op, operandType, optOperandType,
+                              varOperandTypes);
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOperandsAndTypes: the Operands directive, then Results
+//===----------------------------------------------------------------------===//
+
+ParseResult test::parseCustomDirectiveOperandsAndTypes(
+    OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
+    Type &operandType, Type &optOperandType,
+    SmallVectorImpl<Type> &varOperandTypes) {
+  if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
+      parseCustomDirectiveResults(parser, operandType, optOperandType,
+                                  varOperandTypes))
+    return failure();
+  return success();
+}
+
+void test::printCustomDirectiveOperandsAndTypes(
+    OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
+    OperandRange varOperands, Type operandType, Type optOperandType,
+    TypeRange varOperandTypes) {
+  printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
+  printCustomDirectiveResults(printer, op, operandType, optOperandType,
+                              varOperandTypes);
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveRegions: `region [, varRegion]` (at most one varRegion)
+//===----------------------------------------------------------------------===//
+
+ParseResult test::parseCustomDirectiveRegions(
+    OpAsmParser &parser, Region &region,
+    SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
+  if (parser.parseRegion(region))
+    return failure();
+  if (failed(parser.parseOptionalComma()))
+    return success();
+  std::unique_ptr<Region> varRegion = std::make_unique<Region>();
+  if (parser.parseRegion(*varRegion))
+    return failure();
+  varRegions.emplace_back(std::move(varRegion));
+  return success();
+}
+
+void test::printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
+                                       Region &region,
+                                       MutableArrayRef<Region> varRegions) {
+  printer.printRegion(region);
+  if (!varRegions.empty()) {
+    printer << ", ";
+    for (Region &varRegion : varRegions) // don't shadow the `region` param
+      printer.printRegion(varRegion);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSuccessors: `succ [, varSucc]`; varSucc is appended twice
+//===----------------------------------------------------------------------===//
+
+ParseResult
+test::parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
+                                     SmallVectorImpl<Block *> &varSuccessors) {
+  if (parser.parseSuccessor(successor))
+    return failure();
+  if (failed(parser.parseOptionalComma()))
+    return success();
+  Block *varSuccessor;
+  if (parser.parseSuccessor(varSuccessor))
+    return failure();
+  varSuccessors.append(2, varSuccessor); // printer emits only front()
+  return success();
+}
+
+void test::printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
+                                          Block *successor,
+                                          SuccessorRange varSuccessors) {
+  printer << successor;
+  if (!varSuccessors.empty())
+    printer << ", " << varSuccessors.front();
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttributes: `attr [, optAttr]`
+//===----------------------------------------------------------------------===//
+
+ParseResult test::parseCustomDirectiveAttributes(OpAsmParser &parser,
+                                                 IntegerAttr &attr,
+                                                 IntegerAttr &optAttr) {
+  if (parser.parseAttribute(attr))
+    return failure();
+  if (succeeded(parser.parseOptionalComma())) { // comma => optAttr present
+    if (parser.parseAttribute(optAttr))
+      return failure();
+  }
+  return success();
+}
+
+void test::printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
+                                          Attribute attribute,
+                                          Attribute optAttribute) {
+  printer << attribute;
+  if (optAttribute)
+    printer << ", " << optAttribute;
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttrDict: an optional `{...}` attribute dictionary
+//===----------------------------------------------------------------------===//
+
+ParseResult test::parseCustomDirectiveAttrDict(OpAsmParser &parser,
+                                               NamedAttrList &attrs) {
+  return parser.parseOptionalAttrDict(attrs);
+}
+
+void test::printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
+                                        DictionaryAttr attrs) {
+  printer.printOptionalAttrDict(attrs.getValue());
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalOperandRef: `0`/`1`; nonzero iff optOperand is set
+//===----------------------------------------------------------------------===//
+
+ParseResult test::parseCustomDirectiveOptionalOperandRef(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
+  int64_t operandCount = 0;
+  if (parser.parseInteger(operandCount))
+    return failure();
+  bool expectedOptionalOperand = operandCount == 0;
+  return success(expectedOptionalOperand != !!optOperand); // no diagnostic
+}
+
+void test::printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
+                                                  Operation *op,
+                                                  Value optOperand) {
+  printer << (optOperand ? "1" : "0");
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalOperand: optional `(operand)`; printer adds a space
+//===----------------------------------------------------------------------===//
+
+ParseResult test::parseCustomOptionalOperand(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
+  if (succeeded(parser.parseOptionalLParen())) {
+    optOperand.emplace();
+    if (parser.parseOperand(*optOperand) || parser.parseRParen())
+      return failure();
+  }
+  return success();
+}
+
+void test::printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
+                                      Value optOperand) {
+  if (optOperand)
+    printer << "(" << optOperand << ") ";
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSwitchCases: zero or more `case <integer> <region>` entries
+//===----------------------------------------------------------------------===//
+
+ParseResult
+test::parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
+                       SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
+  SmallVector<int64_t> caseValues;
+  while (succeeded(p.parseOptionalKeyword("case"))) {
+    int64_t value;
+    Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
+    if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
+      return failure();
+    caseValues.push_back(value);
+  }
+  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
+  return success();
+}
+
+void test::printSwitchCases(OpAsmPrinter &p, Operation *op,
+                            DenseI64ArrayAttr cases, RegionRange caseRegions) {
+  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
+    p.printNewline();
+    p << "case " << value << ' ';
+    p.printRegion(*region, /*printEntryBlockArgs=*/false);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// CustomUsingPropertyInCustom: `[a, b, c]`; parser returns true on failure
+//===----------------------------------------------------------------------===//
+
+bool test::parseUsingPropertyInCustom(OpAsmParser &parser, int64_t value[3]) {
+  return parser.parseLSquare() || parser.parseInteger(value[0]) ||
+         parser.parseComma() || parser.parseInteger(value[1]) ||
+         parser.parseComma() || parser.parseInteger(value[2]) ||
+         parser.parseRSquare();
+}
+
+void test::printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op,
+                                      ArrayRef<int64_t> value) {
+  printer << '[' << value << ']';
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveIntProperty: one integer; parser returns true on failure
+//===----------------------------------------------------------------------===//
+
+bool test::parseIntProperty(OpAsmParser &parser, int64_t &value) {
+  return failed(parser.parseInteger(value));
+}
+
+void test::printIntProperty(OpAsmPrinter &printer, Operation *op,
+                            int64_t value) {
+  printer << value;
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSumProperty: `second = sum`, sum must equal first + second
+//===----------------------------------------------------------------------===//
+
+bool test::parseSumProperty(OpAsmParser &parser, int64_t &second,
+                            int64_t first) {
+  int64_t sum;
+  auto loc = parser.getCurrentLocation(); // reported if the sum check fails
+  if (parser.parseInteger(second) || parser.parseEqual() ||
+      parser.parseInteger(sum))
+    return true;
+  if (sum != second + first) {
+    parser.emitError(loc, "Expected sum to equal first + second");
+    return true;
+  }
+  return false;
+}
+
+void test::printSumProperty(OpAsmPrinter &printer, Operation *op,
+                            int64_t second, int64_t first) {
+  printer << second << " = " << (second + first);
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalCustomParser: `foo <attr>`, empty result if absent
+//===----------------------------------------------------------------------===//
+
+OptionalParseResult test::parseOptionalCustomParser(AsmParser &p,
+                                                    IntegerAttr &result) {
+  if (succeeded(p.parseOptionalKeyword("foo")))
+    return p.parseAttribute(result);
+  return {};
+}
+
+void test::printOptionalCustomParser(AsmPrinter &p, Operation *,
+                                     IntegerAttr result) {
+  p << "foo ";
+  p.printAttribute(result);
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttrElideType: attribute whose type is taken from `type`
+//===----------------------------------------------------------------------===//
+
+ParseResult test::parseAttrElideType(AsmParser &parser, TypeAttr type,
+                                     Attribute &attr) {
+  return parser.parseAttribute(attr, type.getValue());
+}
+
+void test::printAttrElideType(AsmPrinter &printer, Operation *op, TypeAttr type,
+                              Attribute attr) {
+  printer.printAttributeWithoutType(attr);
+}
diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.h b/mlir/test/lib/Dialect/Test/TestFormatUtils.h
new file mode 100644
index 00000000000000..7e9cd834278e34
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.h
@@ -0,0 +1,211 @@
+//===- TestFormatUtils.h - MLIR Test Dialect Assembly Format Utilities ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTFORMATUTILS_H
+#define MLIR_TESTFORMATUTILS_H
+
+#include "mlir/IR/OpImplementation.h"
+
+namespace test {
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOperands: `operand [, optOperand] -> (varOperands)`
+//===----------------------------------------------------------------------===//
+
+mlir::ParseResult parseCustomDirectiveOperands(
+    mlir::OpAsmParser &parser, mlir::OpAsmParser::UnresolvedOperand &operand,
+    std::optional<mlir::OpAsmParser::UnresolvedOperand> &optOperand,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &varOperands);
+
+void printCustomDirectiveOperands(mlir::OpAsmPrinter &printer,
+                                  mlir::Operation *, mlir::Value operand,
+                                  mlir::Value optOperand,
+                                  mlir::OperandRange varOperands);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveResults: `: type [, optType] -> (varTypes)`
+//===----------------------------------------------------------------------===//
+
+mlir::ParseResult
+parseCustomDirectiveResults(mlir::OpAsmParser &parser, mlir::Type &operandType,
+                            mlir::Type &optOperandType,
+                            llvm::SmallVectorImpl<mlir::Type> &varOperandTypes);
+
+void printCustomDirectiveResults(mlir::OpAsmPrinter &printer, mlir::Operation *,
+                                 mlir::Type operandType,
+                                 mlir::Type optOperandType,
+                                 mlir::TypeRange varOperandTypes);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveWithTypeRefs: `type_refs_capture` + re-checked result types
+//===----------------------------------------------------------------------===//
+
+mlir::ParseResult parseCustomDirectiveWithTypeRefs(
+    mlir::OpAsmParser &parser, mlir::Type operandType,
+    mlir::Type optOperandType,
+    const llvm::SmallVectorImpl<mlir::Type> &varOperandTypes);
+
+void printCustomDirectiveWithTypeRefs(mlir::OpAsmPrinter &printer,
+                                      mlir::Operation *op,
+                                      mlir::Type operandType,
+                                      mlir::Type optOperandType,
+                                      mlir::TypeRange varOperandTypes);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOperandsAndTypes: the Operands directive, then Results
+//===----------------------------------------------------------------------===//
+
+mlir::ParseResult parseCustomDirectiveOperandsAndTypes(
+    mlir::OpAsmParser &parser, mlir::OpAsmParser::UnresolvedOperand &operand,
+    std::optional<mlir::OpAsmParser::UnresolvedOperand> &optOperand,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &varOperands,
+    mlir::Type &operandType, mlir::Type &optOperandType,
+    llvm::SmallVectorImpl<mlir::Type> &varOperandTypes);
+
+void printCustomDirectiveOperandsAndTypes(
+    mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::Value operand,
+    mlir::Value optOperand, mlir::OperandRange varOperands,
+    mlir::Type operandType, mlir::Type optOperandType,
+    mlir::TypeRange varOperandTypes);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveRegions: `region [, varRegion]` (at most one varRegion)
+//===----------------------------------------------------------------------===//
+
+mlir::ParseResult parseCustomDirectiveRegions(
+    mlir::OpAsmParser &parser, mlir::Region &region,
+    llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &varRegions);
+
+void printCustomDirectiveRegions(
+    mlir::OpAsmPrinter &printer, mlir::Operation *, mlir::Region &region,
+    llvm::MutableArrayRef<mlir::Region> varRegions);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSuccessors: `succ [, varSucc]`; varSucc is appended twice
+//===----------------------------------------------------------------------===//
+
+mlir::ParseResult parseCustomDirectiveSuccessors(
+    mlir::OpAsmParser &parser, mlir::Block *&successor,
+    llvm::SmallVectorImpl<mlir::Block *> &varSuccessors);
+
+void printCustomDirectiveSuccessors(mlir::OpAsmPrinter &printer,
+                                    mlir::Operation *, mlir::Block *successor,
+                                    mlir::SuccessorRange varSuccessors);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttributes: `attr [, optAttr]`
+//===----------------------------------------------------------------------===//
+
+mlir::ParseResult parseCustomDirectiveAttributes(mlir::OpAsmParser &parser,
+                                                 mlir::IntegerAttr &attr,
+                                                 mlir::IntegerAttr &optAttr);
+
+void printCustomDirectiveAttributes(mlir::OpAsmPrinter &printer,
+                                    mlir::Operation *,
+                                    mlir::Attribute attribute,
+                                    mlir::Attribute optAttribute);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttrDict: an optional `{...}` attribute dictionary
+//===----------------------------------------------------------------------===//
+
+mlir::ParseResult parseCustomDirectiveAttrDict(mlir::OpAsmParser &parser,
+                                               mlir::NamedAttrList &attrs);
+
+void printCustomDirectiveAttrDict(mlir::OpAsmPrinter &printer,
+                                  mlir::Operation *op,
+                                  mlir::DictionaryAttr attrs);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalOperandRef: `0`/`1`; nonzero iff optOperand is set
+//===----------------------------------------------------------------------===//
+
+mlir::ParseResult parseCustomDirectiveOptionalOperandRef(
+    mlir::OpAsmParser &parser,
+    std::optional<mlir::OpAsmParser::UnresolvedOperand> &optOperand);
+
+void printCustomDirectiveOptionalOperandRef(mlir::OpAsmPrinter &printer,
+                                            mlir::Operation *op,
+                                            mlir::Value optOperand);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalOperand: optional `(operand)`
+//===----------------------------------------------------------------------===//
+
+mlir::ParseResult parseCustomOptionalOperand(
+    mlir::OpAsmParser &parser,
+    std::optional<mlir::OpAsmParser::UnresolvedOperand> &optOperand);
+
+void printCustomOptionalOperand(mlir::OpAsmPrinter &printer, mlir::Operation *,
+                                mlir::Value optOperand);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSwitchCases: zero or more `case <integer> <region>` entries
+//===----------------------------------------------------------------------===//
+
+mlir::ParseResult parseSwitchCases(
+    mlir::OpAsmParser &p, mlir::DenseI64ArrayAttr &cases,
+    llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &caseRegions);
+
+void printSwitchCases(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                      mlir::DenseI64ArrayAttr cases,
+                      mlir::RegionRange caseRegions);
+
+//===----------------------------------------------------------------------===//
+// CustomUsingPropertyInCustom: `[a, b, c]`; parser returns true on failure
+//===----------------------------------------------------------------------===//
+
+bool parseUsingPropertyInCustom(mlir::OpAsmParser &parser, int64_t value[3]);
+
+void printUsingPropertyInCustom(mlir::OpAsmPrinter &printer,
+                                mlir::Operation *op,
+                                llvm::ArrayRef<int64_t> value);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveIntProperty: one integer; parser returns true on failure
+//===----------------------------------------------------------------------===//
+
+bool parseIntProperty(mlir::OpAsmParser &parser, int64_t &value);
+
+void printIntProperty(mlir::OpAsmPrinter &printer, mlir::Operation *op,
+                      int64_t value);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSumProperty: `second = sum`, sum must equal first + second
+//===----------------------------------------------------------------------===//
+
+bool parseSumProperty(mlir::OpAsmParser &parser, int64_t &second,
+                      int64_t first);
+
+void printSumProperty(mlir::OpAsmPrinter &printer, mlir::Operation *op,
+                      int64_t second, int64_t first);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalCustomParser: `foo <attr>`, empty result if absent
+//===----------------------------------------------------------------------===//
+
+mlir::OptionalParseResult parseOptionalCustomParser(mlir::AsmParser &p,
+                                                    mlir::IntegerAttr &result);
+
+void printOptionalCustomParser(mlir::AsmPrinter &p, mlir::Operation *,
+                               mlir::IntegerAttr result);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttrElideType: attribute whose type is taken from `type`
+//===----------------------------------------------------------------------===//
+
+mlir::ParseResult parseAttrElideType(mlir::AsmParser &parser,
+                                     mlir::TypeAttr type,
+                                     mlir::Attribute &attr);
+
+void printAttrElideType(mlir::AsmPrinter &printer, mlir::Operation *op,
+                        mlir::TypeAttr type, mlir::Attribute attr);
+
+} // namespace test
+
+#endif // MLIR_TESTFORMATUTILS_H
diff --git a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
index 3673d62bea2c94..dc6413b25707e3 100644
--- a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
+++ b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestInterfaces.cpp
index 64ec82ecb24ff8..14099bb4bb16ba 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.cpp
@@ -6,3 +6,5 @@ bool mlir::TestEffects::Effect::classof(
const mlir::SideEffects::Effect *effect) {
return isa<mlir::TestEffects::Concrete>(effect);
}
+
+#include "TestOpInterfaces.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.h b/mlir/test/lib/Dialect/Test/TestInterfaces.h
index 3239584a93326d..d58d1aafbe66c2 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.h
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.h
@@ -34,4 +34,6 @@ struct Concrete : public Effect::Base<Concrete> {};
} // namespace TestEffects
} // namespace mlir
+#include "TestOpInterfaces.h.inc"
+
#endif // MLIR_TEST_LIB_DIALECT_TEST_TESTINTERFACES_H
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
new file mode 100644
index 00000000000000..7263774ca158eb
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -0,0 +1,1161 @@
+//===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "TestOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+
+using namespace mlir;
+using namespace test;
+
+//===----------------------------------------------------------------------===//
+// TestBranchOp
+//===----------------------------------------------------------------------===//
+
+/// The op has exactly one successor (index 0), which receives the op's
+/// target operands.
+SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  MutableOperandRange forwarded = getTargetOperandsMutable();
+  return SuccessorOperands(forwarded);
+}
+
+//===----------------------------------------------------------------------===//
+// TestProducingBranchOp
+//===----------------------------------------------------------------------===//
+
+/// Successor 0 receives `secondOperands`; successor 1 receives
+/// `firstOperands`.
+/// NOTE(review): this index-to-operand-group mapping looks inverted relative
+/// to the operand names; confirm against the op definition and its tests
+/// before changing it.
+SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
+  assert(index <= 1 && "invalid successor index");
+  return index == 1 ? SuccessorOperands(getFirstOperandsMutable())
+                    : SuccessorOperands(getSecondOperandsMutable());
+}
+
+//===----------------------------------------------------------------------===//
+// TestInternalBranchOp
+//===----------------------------------------------------------------------===//
+
+/// Successor 0 forwards the success operands and produces no values itself.
+/// Successor 1 forwards the error operands and additionally produces one
+/// value.
+SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
+  assert(index <= 1 && "invalid successor index");
+  if (index != 0)
+    return SuccessorOperands(/*producedOperandCount=*/1,
+                             getErrorOperandsMutable());
+  return SuccessorOperands(/*producedOperandCount=*/0,
+                           getSuccessOperandsMutable());
+}
+
+//===----------------------------------------------------------------------===//
+// TestCallOp
+//===----------------------------------------------------------------------===//
+
+/// Requires a `callee` flat symbol reference and checks that it resolves, from
+/// the nearest symbol table, to an op implementing FunctionOpInterface.
+LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto calleeRef = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!calleeRef)
+    return emitOpError("requires a 'callee' symbol reference attribute");
+
+  FunctionOpInterface callee =
+      symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this,
+                                                               calleeRef);
+  if (callee)
+    return success();
+  return emitOpError() << "'" << calleeRef.getValue()
+                       << "' does not reference a valid function";
+}
+
+//===----------------------------------------------------------------------===//
+// FoldToCallOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Canonicalizes a FoldToCallOp into a `func.call` of the same callee, passing
+/// no operands and producing no results.
+struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
+  using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FoldToCallOp op,
+                                PatternRewriter &rewriter) const override {
+    auto callee = op.getCalleeAttr();
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, /*results=*/TypeRange(),
+                                              callee,
+                                              /*operands=*/ValueRange());
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the rewrite above as FoldToCallOp's only canonicalization.
+void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<FoldToCallOpPattern>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// IsolatedRegionOp - test parsing passthrough operands
+//===----------------------------------------------------------------------===//
+
+/// Parses `<operand> <region>`. The operand is resolved as `index`, and its
+/// name is then reused for the region's sole entry argument (name shadowing
+/// is enabled for this).
+ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  OpAsmParser::Argument regionArg;
+  regionArg.type = parser.getBuilder().getIndexType();
+  if (failed(parser.parseOperand(regionArg.ssaName)))
+    return failure();
+  if (failed(parser.resolveOperand(regionArg.ssaName, regionArg.type,
+                                   result.operands)))
+    return failure();
+
+  Region &body = *result.addRegion();
+  return parser.parseRegion(body, regionArg, /*enableNameShadowing=*/true);
+}
+
+/// Prints the operand, then the region without its entry-block header; the
+/// region's argument is printed under the operand's name.
+void IsolatedRegionOp::print(OpAsmPrinter &p) {
+  Value input = getOperand();
+  Region &body = getRegion();
+  p << ' ';
+  p.printOperand(input);
+  p.shadowRegionArgs(body, input);
+  p << ' ';
+  p.printRegion(body, /*printEntryBlockArgs=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// SSACFGRegionOp
+//===----------------------------------------------------------------------===//
+
+/// Every region of this op is an SSA CFG region, regardless of index.
+RegionKind SSACFGRegionOp::getRegionKind(unsigned /*index*/) {
+  return RegionKind::SSACFG;
+}
+
+//===----------------------------------------------------------------------===//
+// GraphRegionOp
+//===----------------------------------------------------------------------===//
+
+/// Every region of this op is a graph region, regardless of index.
+RegionKind GraphRegionOp::getRegionKind(unsigned /*index*/) {
+  return RegionKind::Graph;
+}
+
+//===----------------------------------------------------------------------===//
+// AffineScopeOp
+//===----------------------------------------------------------------------===//
+
+/// Parses a single region that declares no entry-block arguments.
+ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
+  Region &body = *result.addRegion();
+  return parser.parseRegion(body, /*arguments=*/{},
+                            /*enableNameShadowing=*/false);
+}
+
+/// Prints the region without its entry-block argument list.
+void AffineScopeOp::print(OpAsmPrinter &p) {
+  p << ' ';
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// TestRemoveOpWithInnerOps
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pattern that unconditionally erases a TestOpWithRegionPattern op.
+struct TestRemoveOpWithInnerOps
+    : public OpRewritePattern<TestOpWithRegionPattern> {
+  using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
+
+  /// Gives the pattern a stable debug name.
+  void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
+
+  LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// TestOpWithRegionPattern
+//===----------------------------------------------------------------------===//
+
+/// The only canonicalization is the erasing pattern above.
+void TestOpWithRegionPattern::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<TestRemoveOpWithInnerOps>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpWithRegionFold
+//===----------------------------------------------------------------------===//
+
+/// Always folds to its operand.
+OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
+  Value operand = getOperand();
+  return operand;
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpConstant
+//===----------------------------------------------------------------------===//
+
+/// Folds to the op's `value` attribute.
+OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) {
+  return getValue();
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpWithVariadicResultsAndFolder
+//===----------------------------------------------------------------------===//
+
+/// Folds each result to the operand at the same position.
+LogicalResult TestOpWithVariadicResultsAndFolder::fold(
+    FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
+  for (Value operand : this->getOperands())
+    results.push_back(operand);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpInPlaceFold
+//===----------------------------------------------------------------------===//
+
+/// Folds in place. When `attr` is unset and the operand folds to an integer
+/// constant, it records that constant in the `attr` property and returns the
+/// op's own result to signal an in-place update.
+OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
+  // Exercise the fact that an operation created with createOrFold should be
+  // allowed to access its parent block.
+  assert(getOperation()->getBlock() &&
+         "expected that operation is not unlinked");
+
+  // Already folded: nothing left to record.
+  if (getProperties().attr)
+    return {};
+
+  // Only report an in-place fold when the property actually changes. A
+  // non-integer constant would otherwise leave `attr` null while the folder
+  // still claimed success.
+  auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp());
+  if (!value)
+    return {};
+  getProperties().attr = value;
+  return getResult();
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithInferTypeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+/// Infers one result with the type of operand 0. Fails when the types of
+/// operands 0 and 1 differ.
+LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
+    MLIRContext *, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Type lhsType = operands[0].getType();
+  Type rhsType = operands[1].getType();
+  if (lhsType == rhsType) {
+    inferredReturnTypes.assign({lhsType});
+    return success();
+  }
+  return emitOptionalError(location, "operand type mismatch ", lhsType, " vs ",
+                           rhsType);
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithShapedTypeInferTypeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+/// Infers one 1-D shaped result with i17 elements. Its single dimension is the
+/// leading dimension of the first operand, or dynamic if that operand is
+/// unranked. A ranked tensor operand's encoding is propagated.
+LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
+    MLIRContext *context, std::optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes,
+    OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  // Guard `front()` calls below against an empty operand list.
+  if (operands.empty())
+    return emitOptionalError(location, "expected at least one operand");
+  auto sval = dyn_cast<ShapedType>(operands.front().getType());
+  if (!sval)
+    return emitOptionalError(location, "only shaped type operands allowed");
+  // A rank-0 shape has no leading dimension to read.
+  if (sval.hasRank() && sval.getRank() == 0)
+    return emitOptionalError(location,
+                             "expected shaped operand of non-zero rank");
+  int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
+  auto type = IntegerType::get(context, 17);
+
+  Attribute encoding;
+  if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
+    encoding = rankedTy.getEncoding();
+  inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
+  return success();
+}
+
+/// Reifies the result shape as the size of dimension 0 of the first operand.
+LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
+  shapes = SmallVector<Value, 1>{
+      builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithResultShapeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+/// Visits the operands in reverse order. For each one, builds a 1-D index
+/// tensor (via `tensor.from_elements`) holding that operand's dimension sizes.
+LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
+  Location loc = getLoc();
+  shapes.reserve(operands.size());
+  for (Value operand : llvm::reverse(operands)) {
+    int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
+    SmallVector<Value, 4> dimSizes;
+    dimSizes.reserve(rank);
+    for (int64_t d = 0; d < rank; ++d)
+      dimSizes.push_back(builder.createOrFold<tensor::DimOp>(loc, operand, d));
+    auto shapeType = RankedTensorType::get({rank}, builder.getIndexType());
+    shapes.push_back(
+        builder.create<tensor::FromElementsOp>(loc, shapeType, dimSizes));
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithResultShapePerDimInterfaceOp
+//===----------------------------------------------------------------------===//
+
+/// Visits the operands in reverse order and records one size per dimension.
+/// Static sizes become index attributes; dynamic sizes become `tensor.dim`
+/// values.
+LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+  Location loc = getLoc();
+  shapes.reserve(getNumOperands());
+  for (Value operand : llvm::reverse(getOperands())) {
+    auto tensorType = cast<RankedTensorType>(operand.getType());
+    SmallVector<OpFoldResult, 4> dimSizes;
+    for (int64_t d = 0, rank = tensorType.getRank(); d < rank; ++d) {
+      if (tensorType.isDynamicDim(d))
+        dimSizes.push_back(
+            builder.createOrFold<tensor::DimOp>(loc, operand, d));
+      else
+        dimSizes.push_back(builder.getIndexAttr(tensorType.getDimSize(d)));
+    }
+    shapes.emplace_back(std::move(dimSizes));
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// SideEffectOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Non-default side-effect resource. SideEffectOp attaches it to an effect
+/// whose dictionary contains a `test_resource` entry.
+struct TestResource : public SideEffects::Resource::Base<TestResource> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
+
+  StringRef getName() final {
+    return "<Test>";
+  }
+};
+} // namespace
+
+/// Reports the memory effects listed in the optional `effects` array
+/// attribute. Each entry is a dictionary with:
+///   - `effect`: one of "allocate", "free", "read", "write";
+///   - optional `test_resource`: use TestResource instead of the default;
+///   - optional `on_result`: the effect applies to the op's result; or
+///   - optional `on_reference`: the effect applies to that symbol reference.
+void SideEffectOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  auto effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
+  if (!effectsAttr)
+    return;
+
+  for (Attribute entry : effectsAttr) {
+    auto dict = cast<DictionaryAttr>(entry);
+
+    StringRef kind = cast<StringAttr>(dict.get("effect")).getValue();
+    MemoryEffects::Effect *effect =
+        StringSwitch<MemoryEffects::Effect *>(kind)
+            .Case("allocate", MemoryEffects::Allocate::get())
+            .Case("free", MemoryEffects::Free::get())
+            .Case("read", MemoryEffects::Read::get())
+            .Case("write", MemoryEffects::Write::get());
+
+    SideEffects::Resource *resource = SideEffects::DefaultResource::get();
+    if (dict.get("test_resource"))
+      resource = TestResource::get();
+
+    if (dict.get("on_result")) {
+      effects.emplace_back(effect, getResult(), resource);
+      continue;
+    }
+    if (Attribute ref = dict.get("on_reference")) {
+      effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
+      continue;
+    }
+    effects.emplace_back(effect, resource);
+  }
+}
+
+/// Test-effect variant: delegates to testSideEffectOpGetEffect.
+void SideEffectOp::getEffects(
+    SmallVectorImpl<TestEffects::EffectInstance> &effects) {
+  Operation *op = getOperation();
+  testSideEffectOpGetEffect(op, effects);
+}
+
+//===----------------------------------------------------------------------===//
+// StringAttrPrettyNameOp
+//===----------------------------------------------------------------------===//
+
+// This op has fancy handling of its SSA result name. Every result is i32.
+// Unless the attribute dictionary already provides `names`, the SSA result
+// names written in the source are recorded in a `names` string-array
+// attribute; names starting with a digit are recorded as empty strings.
+ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
+                                          OperationState &result) {
+  // Add the result types.
+  for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
+    result.addTypes(parser.getBuilder().getIntegerType(32));
+
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+    return failure();
+
+  // An explicit 'names' attribute takes precedence over the SSA names.
+  bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
+    return attr.getName() == "names";
+  });
+
+  // Otherwise infer 'names' from the SSA result names, if there are results.
+  if (hadNames || parser.getNumResults() == 0)
+    return success();
+
+  SmallVector<StringRef, 4> names;
+  auto *context = result.getContext();
+
+  for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
+    auto resultName = parser.getResultName(i);
+    StringRef nameStr;
+    // Numeric names carry no information. Cast to unsigned char first:
+    // passing a negative `char` to isdigit is undefined behavior.
+    if (!resultName.first.empty() &&
+        !isdigit(static_cast<unsigned char>(resultName.first[0])))
+      nameStr = resultName.first;
+
+    names.push_back(nameStr);
+  }
+
+  auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
+  result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
+  return success();
+}
+
+/// Prints the attribute dictionary. The "names" entry is elided unless it
+/// disagrees with the result names the printer chose (which can happen, e.g.,
+/// on name conflicts).
+void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
+  auto names = getNames();
+  bool namesDisagree = names.size() != getNumResults();
+
+  SmallString<32> printedName;
+  for (size_t i = 0, e = getNumResults(); !namesDisagree && i != e; ++i) {
+    printedName.clear();
+    llvm::raw_svector_ostream os(printedName);
+    p.printOperand(getResult(i), os);
+
+    // Skip the leading sigil of the printed name before comparing.
+    auto expected = dyn_cast<StringAttr>(names[i]);
+    namesDisagree = !expected || os.str().drop_front() != expected.getValue();
+  }
+
+  if (namesDisagree) {
+    p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
+    return;
+  }
+  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
+}
+
+// Each non-empty string in the `names` attribute becomes the SSA name of the
+// result at the same position; other entries are skipped.
+void StringAttrPrettyNameOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  auto names = getNames();
+  for (size_t idx = 0, count = names.size(); idx < count; ++idx) {
+    auto name = dyn_cast<StringAttr>(names[idx]);
+    if (!name || name.getValue().empty())
+      continue;
+    setNameFn(getResult(idx), name.getValue());
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// CustomResultsNameOp
+//===----------------------------------------------------------------------===//
+
+/// Uses the non-empty string entries of `names` as result SSA names.
+void CustomResultsNameOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  ArrayAttr names = getNames();
+  for (size_t idx = 0, count = names.size(); idx < count; ++idx) {
+    auto name = dyn_cast<StringAttr>(names[idx]);
+    if (name && !name.empty())
+      setNameFn(getResult(idx), name.getValue());
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// ResultTypeWithTraitOp
+//===----------------------------------------------------------------------===//
+
+/// The first result type must carry TestTypeTrait.
+LogicalResult ResultTypeWithTraitOp::verify() {
+  Type resultType = (*this)->getResultTypes()[0];
+  if (!resultType.hasTrait<TypeTrait::TestTypeTrait>())
+    return emitError("result type should have trait 'TestTypeTrait'");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AttrWithTraitOp
+//===----------------------------------------------------------------------===//
+
+/// The `attr` attribute must carry TestAttrTrait.
+LogicalResult AttrWithTraitOp::verify() {
+  if (!getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
+    return emitError("'attr' attribute should have trait 'TestAttrTrait'");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// RegionIfOp
+//===----------------------------------------------------------------------===//
+
+/// Prints `operands : types -> results then {..} else {..} join {..}`; each
+/// region is printed with its entry arguments and terminators.
+void RegionIfOp::print(OpAsmPrinter &p) {
+  p << " ";
+  p.printOperands(getOperands());
+  p << ": " << getOperandTypes();
+  p.printArrowTypeList(getResultTypes());
+  auto printNamedRegion = [&](StringRef keyword, Region &region) {
+    p << ' ' << keyword << ' ';
+    p.printRegion(region, /*printEntryBlockArgs=*/true,
+                  /*printBlockTerminators=*/true);
+  };
+  printNamedRegion("then", getThenRegion());
+  printNamedRegion("else", getElseRegion());
+  printNamedRegion("join", getJoinRegion());
+}
+
+/// Parses the form produced by print(). The three regions are created in
+/// then/else/join order before anything is parsed.
+ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
+  SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
+  SmallVector<Type, 2> operandTypes;
+
+  result.regions.reserve(3);
+  Region *thenRegion = result.addRegion();
+  Region *elseRegion = result.addRegion();
+  Region *joinRegion = result.addRegion();
+
+  if (parser.parseOperandList(operands) ||
+      parser.parseColonTypeList(operandTypes) ||
+      parser.parseArrowTypeList(result.types))
+    return failure();
+
+  std::pair<StringRef, Region *> regionSpecs[] = {
+      {"then", thenRegion}, {"else", elseRegion}, {"join", joinRegion}};
+  for (auto &[keyword, region] : regionSpecs)
+    if (parser.parseKeyword(keyword) ||
+        parser.parseRegion(*region, /*arguments=*/{},
+                           /*enableNameShadowing=*/false))
+      return failure();
+
+  return parser.resolveOperands(operands, operandTypes,
+                                parser.getCurrentLocation(), result.operands);
+}
+
+/// The then/else entry regions receive all of the op's operands.
+OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
+         "invalid region index");
+  return getOperands();
+}
+
+/// Control flow: the parent op enters the then or else region. Both of those
+/// branch to the join region, and the join region returns to the parent op's
+/// results.
+void RegionIfOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (!point.isParent()) {
+    // then/else -> join; join -> parent results.
+    if (point != getJoinRegion())
+      regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
+    else
+      regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  // The then and else regions are the entry regions of this op.
+  regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
+  regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
+}
+
+void RegionIfOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands,
+    SmallVectorImpl<InvocationBounds> &invocationBounds) {
+  // Each region is invoked at most once.
+  invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
+}
+
+//===----------------------------------------------------------------------===//
+// AnyCondOp
+//===----------------------------------------------------------------------===//
+
+/// The parent op branches into the only region, and the region branches back
+/// to the parent op's results.
+void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  if (point.isParent())
+    regions.emplace_back(&getRegion());
+  else
+    regions.emplace_back(getResults());
+}
+
+/// The single region is invoked exactly once.
+void AnyCondOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands,
+    SmallVectorImpl<InvocationBounds> &invocationBounds) {
+  invocationBounds.emplace_back(1, 1);
+}
+
+//===----------------------------------------------------------------------===//
+// SingleBlockImplicitTerminatorOp
+//===----------------------------------------------------------------------===//
+
+/// Compile-time checks that both trait-detection utilities
+/// (`has_implicit_terminator_t` via `llvm::is_detected`, and
+/// `hasSingleBlockImplicitTerminator`) report SingleBlockImplicitTerminatorOp
+/// as having a single-block implicit terminator. These emit no code; a
+/// mismatch is a build error.
+static_assert(
+ llvm::is_detected<OpTrait::has_implicit_terminator_t,
+ SingleBlockImplicitTerminatorOp>::value,
+ "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
+static_assert(OpTrait::hasSingleBlockImplicitTerminator<
+ SingleBlockImplicitTerminatorOp>::value,
+ "hasSingleBlockImplicitTerminator does not match "
+ "SingleBlockImplicitTerminatorOp");
+
+//===----------------------------------------------------------------------===//
+// SingleNoTerminatorCustomAsmOp
+//===----------------------------------------------------------------------===//
+
+ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
+ OperationState &state) {
+ Region *body = state.addRegion();
+ if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
+ return failure();
+ return success();
+}
+
+void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
+ printer.printRegion(
+ getRegion(), /*printEntryBlockArgs=*/false,
+ // This op has a single block without terminators. But explicitly mark
+ // as not printing block terminators for testing.
+ /*printBlockTerminators=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// TestVerifiersOp
+//===----------------------------------------------------------------------===//
+
+/// Confirms that the single-block trait and the operand's defining op were
+/// already verified, then emits a "success" remark.
+LogicalResult TestVerifiersOp::verify() {
+  if (!getRegion().hasOneBlock())
+    return emitOpError("`hasOneBlock` trait hasn't been verified");
+
+  if (Operation *producer = getInput().getDefiningOp())
+    if (failed(mlir::verify(producer)))
+      return emitOpError("operand hasn't been verified");
+
+  // Use the free function: the member `emitRemark(msg)` would trigger an
+  // infinite verifier loop.
+  mlir::emitRemark(getLoc(), "success run of verifier");
+  return success();
+}
+
+/// Confirms the single-block trait and that every nested op verifies, then
+/// emits a "success" remark.
+LogicalResult TestVerifiersOp::verifyRegions() {
+  if (!getRegion().hasOneBlock())
+    return emitOpError("`hasOneBlock` trait hasn't been verified");
+
+  for (Operation &nested : getRegion().getOps())
+    if (failed(mlir::verify(&nested)))
+      return emitOpError("nested op hasn't been verified");
+
+  // Use the free function: the member `emitRemark(msg)` would trigger an
+  // infinite verifier loop.
+  mlir::emitRemark(getLoc(), "success run of region verifier");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Test InferIntRangeInterface
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// TestWithBoundsOp
+
+/// The result's range is exactly the umin/umax/smin/smax attribute bounds;
+/// argument ranges are ignored.
+void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                         SetIntRangeFn setResultRanges) {
+  ConstantIntRanges bounds(getUmin(), getUmax(), getSmin(), getSmax());
+  setResultRanges(getResult(), bounds);
+}
+
+//===----------------------------------------------------------------------===//
+// TestWithBoundsRegionOp
+
+/// Parses `{attrs} <arg> <region>`. The region's single entry argument is
+/// written before the region and always has index type.
+ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
+                                          OperationState &result) {
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  OpAsmParser::Argument blockArg;
+  blockArg.type = parser.getBuilder().getIndexType();
+  if (failed(parser.parseArgument(blockArg)))
+    return failure();
+
+  return parser.parseRegion(*result.addRegion(), blockArg,
+                            /*enableNameShadowing=*/false);
+}
+
+/// Prints the form accepted by parse(); the argument's type is omitted.
+void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
+  Region &body = getRegion();
+  p.printOptionalAttrDict((*this)->getAttrs());
+  p << ' ';
+  p.printRegionArgument(body.getArgument(0), /*argAttrs=*/{},
+                        /*omitType=*/true);
+  p << ' ';
+  p.printRegion(body, /*printEntryBlockArgs=*/false);
+}
+
+/// Constrains the region's entry argument (not a result) to the attribute
+/// bounds.
+void TestWithBoundsRegionOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+  ConstantIntRanges bounds(getUmin(), getUmax(), getSmin(), getSmax());
+  setResultRanges(getRegion().getArgument(0), bounds);
+}
+
+//===----------------------------------------------------------------------===//
+// TestIncrementOp
+
+/// The result range is the operand range shifted up by one. The unsigned
+/// bounds saturate with uadd_sat and the signed bounds with sadd_sat.
+void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                        SetIntRangeFn setResultRanges) {
+  const ConstantIntRanges &in = argRanges[0];
+  APInt one(in.umin().getBitWidth(), 1);
+  APInt umin = in.umin().uadd_sat(one);
+  APInt umax = in.umax().uadd_sat(one);
+  APInt smin = in.smin().sadd_sat(one);
+  APInt smax = in.smax().sadd_sat(one);
+  setResultRanges(getResult(), {umin, umax, smin, smax});
+}
+
+//===----------------------------------------------------------------------===//
+// TestReflectBoundsOp
+
+/// Writes the operand's inferred bounds back onto the op as index attributes
+/// and passes the operand range through unchanged to the result.
+void TestReflectBoundsOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+  const ConstantIntRanges &in = argRanges[0];
+  Builder b(getContext());
+  setUminAttr(b.getIndexAttr(in.umin().getZExtValue()));
+  setUmaxAttr(b.getIndexAttr(in.umax().getZExtValue()));
+  setSminAttr(b.getIndexAttr(in.smin().getSExtValue()));
+  setSmaxAttr(b.getIndexAttr(in.smax().getSExtValue()));
+  setResultRanges(getResult(), in);
+}
+
+//===----------------------------------------------------------------------===//
+// ConversionFuncOp
+//===----------------------------------------------------------------------===//
+
+/// Parses a non-variadic function signature using the shared
+/// function-interface helper.
+ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  // Variadic signatures are rejected, so the variadic flag and error string
+  // are unused and a plain FunctionType is built.
+  auto buildFuncType = [](Builder &builder, ArrayRef<Type> inputs,
+                          ArrayRef<Type> outputs,
+                          function_interface_impl::VariadicFlag,
+                          std::string &) {
+    return builder.getFunctionType(inputs, outputs);
+  };
+
+  auto typeAttrName = getFunctionTypeAttrName(result.name);
+  auto argAttrsName = getArgAttrsAttrName(result.name);
+  auto resAttrsName = getResAttrsAttrName(result.name);
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, typeAttrName, buildFuncType,
+      argAttrsName, resAttrsName);
+}
+
+/// Prints the op with the shared function-interface printer.
+void ConversionFuncOp::print(OpAsmPrinter &p) {
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
+}
+
+//===----------------------------------------------------------------------===//
+// ReifyBoundOp
+//===----------------------------------------------------------------------===//
+
+/// Maps the `type` string ("EQ", "LB", "UB") to a presburger bound type.
+mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
+  if (getType() == "EQ")
+    return mlir::presburger::BoundType::EQ;
+  if (getType() == "LB")
+    return mlir::presburger::BoundType::LB;
+  if (getType() == "UB")
+    return mlir::presburger::BoundType::UB;
+  llvm_unreachable("invalid bound type");
+}
+
+/// Checks the following:
+///   - a shaped variable needs `dim`, an index variable must not have it, and
+///     any other type is rejected;
+///   - `scalable` and `constant` are mutually exclusive;
+///   - `vscale_min` and `vscale_max` are each present iff `scalable` is set.
+LogicalResult ReifyBoundOp::verify() {
+  if (isa<ShapedType>(getVar().getType())) {
+    if (!getDim().has_value())
+      return emitOpError("expected 'dim' attribute for shaped type variable");
+  } else if (getVar().getType().isIndex()) {
+    if (getDim().has_value())
+      return emitOpError("unexpected 'dim' attribute for index variable");
+  } else {
+    return emitOpError("expected index-typed variable or shape type variable");
+  }
+  if (getConstant() && getScalable())
+    return emitOpError("'scalable' and 'constant' are mutually exclusive");
+  if (getScalable() != getVscaleMin().has_value())
+    return emitOpError("expected 'vscale_min' if and only if 'scalable'");
+  // Previously this reported 'vscale_min' by copy-paste mistake.
+  if (getScalable() != getVscaleMax().has_value())
+    return emitOpError("expected 'vscale_max' if and only if 'scalable'");
+  return success();
+}
+
+/// Builds the constraint-set variable, including `dim` when present.
+ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
+  if (getDim().has_value())
+    return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
+  return ValueBoundsConstraintSet::Variable(getVar());
+}
+
+//===----------------------------------------------------------------------===//
+// CompareOp
+//===----------------------------------------------------------------------===//
+
+/// Maps the `cmp` string (EQ/LT/LE/GT/GE) to a comparison operator.
+ValueBoundsConstraintSet::ComparisonOperator
+CompareOp::getComparisonOperator() {
+  using Cmp = ValueBoundsConstraintSet::ComparisonOperator;
+  StringRef cmp = getCmp();
+  if (cmp == "EQ")
+    return Cmp::EQ;
+  if (cmp == "LT")
+    return Cmp::LT;
+  if (cmp == "LE")
+    return Cmp::LE;
+  if (cmp == "GT")
+    return Cmp::GT;
+  if (cmp == "GE")
+    return Cmp::GE;
+  llvm_unreachable("invalid comparison operator");
+}
+
+/// The LHS is either the first operand or `lhs_map` applied to the leading
+/// operands.
+mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
+  auto operands = getVarOperands();
+  auto lhsMap = getLhsMap();
+  if (!lhsMap)
+    return ValueBoundsConstraintSet::Variable(operands[0]);
+  SmallVector<Value> mapOperands(operands.slice(0, lhsMap->getNumInputs()));
+  return ValueBoundsConstraintSet::Variable(*lhsMap, mapOperands);
+}
+
+/// The RHS starts right after the LHS operands. It is either a single operand
+/// or `rhs_map` applied to the following operands.
+mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
+  auto operands = getVarOperands();
+  auto lhsMap = getLhsMap();
+  int64_t rhsBegin = lhsMap ? lhsMap->getNumInputs() : 1;
+  auto rhsMap = getRhsMap();
+  if (!rhsMap)
+    return ValueBoundsConstraintSet::Variable(operands[rhsBegin]);
+  SmallVector<Value> mapOperands(
+      operands.slice(rhsBegin, rhsMap->getNumInputs()));
+  return ValueBoundsConstraintSet::Variable(*rhsMap, mapOperands);
+}
+
+/// Rejects `compose` together with maps, and checks that the operand count
+/// matches what the LHS and RHS consume.
+LogicalResult CompareOp::verify() {
+  auto lhsMap = getLhsMap();
+  auto rhsMap = getRhsMap();
+  if (getCompose() && (lhsMap || rhsMap))
+    return emitOpError(
+        "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
+  int64_t expectedNumOperands = lhsMap ? lhsMap->getNumInputs() : 1;
+  expectedNumOperands += rhsMap ? rhsMap->getNumInputs() : 1;
+  size_t actualNumOperands = getVarOperands().size();
+  if (actualNumOperands != size_t(expectedNumOperands))
+    return emitOpError("expected ")
+           << expectedNumOperands << " operands, but got "
+           << actualNumOperands;
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpFoldWithFoldAdaptor
+//===----------------------------------------------------------------------===//
+
+/// Folds to a weighted sum of the integer-constant inputs:
+///   - `op` counts once;
+///   - each `variadic` value counts twice;
+///   - each `var_of_var` value counts three times;
+///   - 4 is added per element of the adaptor's `body` range.
+/// Non-integer or unknown inputs contribute nothing.
+OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
+  auto weighted = [](Attribute attr, int64_t weight) -> int64_t {
+    auto intAttr = dyn_cast_or_null<IntegerAttr>(attr);
+    return intAttr ? weight * intAttr.getValue().getSExtValue() : 0;
+  };
+
+  int64_t sum = weighted(adaptor.getOp(), 1);
+  for (Attribute attr : adaptor.getVariadic())
+    sum += weighted(attr, 2);
+  for (ArrayRef<Attribute> group : adaptor.getVarOfVar())
+    for (Attribute attr : group)
+      sum += weighted(attr, 3);
+  sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end());
+
+  return IntegerAttr::get(getType(), sum);
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithInferTypeAdaptorInterfaceOp
+//===----------------------------------------------------------------------===//
+
+/// Infers a single result whose type is the common type of `x` and `y`;
+/// reports an error if the two operand types differ.
+LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
+    MLIRContext *, std::optional<Location> location,
+    OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto xType = adaptor.getX().getType();
+  auto yType = adaptor.getY().getType();
+  if (xType == yType) {
+    inferredReturnTypes.assign({xType});
+    return success();
+  }
+  return emitOptionalError(location, "operand type mismatch ", xType, " vs ",
+                           yType);
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithRefineTypeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+// TODO: We should be able to only define either inferReturnType or
+// refineReturnType, currently only refineReturnType can be omitted.
+/// Inference is implemented as refinement that starts from an empty list of
+/// result types.
+LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &returnTypes) {
+  returnTypes.clear();
+  LogicalResult refined = OpWithRefineTypeInterfaceOp::refineReturnTypes(
+      context, location, operands, attributes, properties, regions,
+      returnTypes);
+  return refined;
+}
+
+/// Refines the single result type to the type of the first operand.
+/// Requires both operands to share a type. Requires any pre-existing result
+/// type to already equal it.
+LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
+    MLIRContext *, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &returnTypes) {
+  Type firstType = operands[0].getType();
+  Type secondType = operands[1].getType();
+  if (firstType != secondType)
+    return emitOptionalError(location, "operand type mismatch ", firstType,
+                             " vs ", secondType);
+
+  // TODO: Add helper to make this more concise to write.
+  if (returnTypes.empty())
+    returnTypes.resize(1, nullptr);
+  Type &resultType = returnTypes.front();
+  if (resultType && resultType != firstType)
+    return emitOptionalError(location,
+                             "required first operand and result to match");
+  resultType = firstType;
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithShapedTypeInferTypeAdaptorInterfaceOp
+//===----------------------------------------------------------------------===//
+
+/// Infers a 1-D result with an i17 element type:
+///  - the single dimension is the first dimension of the first operand, or
+///    dynamic if that operand is unranked;
+///  - the encoding is taken from the operand when it is a ranked tensor.
+LogicalResult
+OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
+    MLIRContext *context, std::optional<Location> location,
+    OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  auto operandType = adaptor.getOperand1().getType();
+  auto sval = dyn_cast<ShapedType>(operandType);
+  if (!sval)
+    return emitOptionalError(location, "only shaped type operands allowed");
+  // A ranked operand must have a first dimension to read; calling
+  // `getShape().front()` on a rank-0 shape would access an empty ArrayRef.
+  if (sval.hasRank() && sval.getRank() == 0)
+    return emitOptionalError(location, "expected operand of rank at least 1");
+  int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
+  auto type = IntegerType::get(context, 17);
+
+  Attribute encoding;
+  if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
+    encoding = rankedTy.getEncoding();
+  inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
+  return success();
+}
+
+/// Reifies the single result extent as dimension 0 of the first operand.
+LogicalResult
+OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
+  Value extent =
+      builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0);
+  shapes.clear();
+  shapes.push_back(extent);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpWithPropertiesAndInferredType
+//===----------------------------------------------------------------------===//
+
+/// Infers one integer result whose bit width is `getLhs()` plus the `rhs`
+/// property.
+LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes(
+    MLIRContext *context, std::optional<Location>, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Adaptor adaptor(operands, attributes, properties, regions);
+  auto width = adaptor.getLhs() + adaptor.getProperties().rhs;
+  inferredReturnTypes.push_back(IntegerType::get(context, width));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LoopBlockOp
+//===----------------------------------------------------------------------===//
+
+/// The body region is a successor of every branch point. From inside the
+/// body, control may additionally exit to the op's results.
+void LoopBlockOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  regions.emplace_back(&getBody(), getBody().getArguments());
+  if (point.isParent())
+    return;
+
+  regions.emplace_back((*this)->getResults());
+}
+
+/// The body is the only region entered from the op; it receives the `init`
+/// operands.
+OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  assert(point == getBody());
+  MutableOperandRange initOperands = getInitMutable();
+  return initOperands;
+}
+
+//===----------------------------------------------------------------------===//
+// LoopBlockTerminatorOp
+//===----------------------------------------------------------------------===//
+
+/// Exiting to the parent forwards the exit-arg operands; looping back into
+/// the body forwards the next-iter-arg operands.
+MutableOperandRange
+LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+  return point.isParent() ? getExitArgMutable() : getNextIterArgMutable();
+}
+
+//===----------------------------------------------------------------------===//
+// TestNoTerminatorOp
+//===----------------------------------------------------------------------===//
+
+/// Reports no successor regions for any branch point.
+void TestNoTerminatorOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}
+
+//===----------------------------------------------------------------------===//
+// ManualCppOpWithFold
+//===----------------------------------------------------------------------===//
+
+/// A simple fold for testing purposes. It folds to the first operand's
+/// constant attribute (which may itself be null), or to nothing when the op
+/// has no operands.
+OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
+  return attributes.empty() ? OpFoldResult()
+                            : OpFoldResult(attributes.front());
+}
+
+//===----------------------------------------------------------------------===//
+// Tensor/Buffer Ops
+//===----------------------------------------------------------------------===//
+
+/// Reports a read of the buffer operand. Also reports a write to the default
+/// resource, which models dumping the buffer contents.
+void ReadBufferOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  auto *defaultResource = SideEffects::DefaultResource::get();
+  effects.emplace_back(MemoryEffects::Read::get(), getBuffer(),
+                       defaultResource);
+  effects.emplace_back(MemoryEffects::Write::get(), defaultResource);
+}
+
+//===----------------------------------------------------------------------===//
+// Test Dataflow
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// TestCallAndStoreOp
+
+/// Exposes the `callee` attribute as the call's callable.
+CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() {
+  CallInterfaceCallable callable = getCallee();
+  return callable;
+}
+
+/// Replaces the `callee` attribute. `callee` is expected to hold a
+/// SymbolRefAttr.
+void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  SymbolRefAttr symbol = callee.get<SymbolRefAttr>();
+  setCalleeAttr(symbol);
+}
+
+/// The call arguments are the callee operands.
+Operation::operand_range TestCallAndStoreOp::getArgOperands() {
+  Operation::operand_range args = getCalleeOperands();
+  return args;
+}
+
+/// Mutable view of the callee operands used as call arguments.
+MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
+  MutableOperandRange args = getCalleeOperandsMutable();
+  return args;
+}
+
+//===----------------------------------------------------------------------===//
+// TestCallOnDeviceOp
+
+/// Exposes the `callee` attribute as the call's callable.
+CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() {
+  CallInterfaceCallable callable = getCallee();
+  return callable;
+}
+
+/// Replaces the `callee` attribute. `callee` is expected to hold a
+/// SymbolRefAttr.
+void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  SymbolRefAttr symbol = callee.get<SymbolRefAttr>();
+  setCalleeAttr(symbol);
+}
+
+/// The call arguments are the forwarded operands.
+Operation::operand_range TestCallOnDeviceOp::getArgOperands() {
+  Operation::operand_range args = getForwardedOperands();
+  return args;
+}
+
+/// Mutable view of the forwarded operands used as call arguments.
+MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
+  MutableOperandRange args = getForwardedOperandsMutable();
+  return args;
+}
+
+//===----------------------------------------------------------------------===//
+// TestStoreWithARegion
+
+/// From the parent op, control enters the body with the entry block's
+/// arguments. From the body, the successor is a default-constructed
+/// RegionSuccessor.
+void TestStoreWithARegion::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (point.isParent())
+    regions.emplace_back(&getBody(), getBody().front().getArguments());
+  else
+    regions.emplace_back();
+}
+
+//===----------------------------------------------------------------------===//
+// TestStoreWithALoopRegion
+
+/// Reports the body and a default-constructed RegionSuccessor for every
+/// branch point.
+void TestStoreWithALoopRegion::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  // Both the operation itself and the region may be branching into the body or
+  // back into the operation itself. It is possible for the operation not to
+  // enter the body.
+  regions.emplace_back(
+      RegionSuccessor(&getBody(), getBody().front().getArguments()));
+  regions.emplace_back();
+}
+
+//===----------------------------------------------------------------------===//
+// TestVersionedOpA
+//===----------------------------------------------------------------------===//
+
+/// Reads `dims` and then, unless the bytecode targets a dialect version
+/// older than 2.0, also reads `modifier`.
+LogicalResult
+TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader,
+                                 mlir::OperationState &state) {
+  auto &prop = state.getOrAddProperties<Properties>();
+  if (mlir::failed(reader.readAttribute(prop.dims)))
+    return mlir::failure();
+
+  // A missing dialect version means the current version. Versions before 2.0
+  // carry no `modifier`; it can be materialized after parsing, before
+  // verification.
+  bool hasModifier = true;
+  auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
+  if (succeeded(maybeVersion)) {
+    const auto *version =
+        reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
+    hasModifier = !(version->major_ < 2);
+  }
+  if (!hasModifier)
+    return success();
+  return reader.readAttribute(prop.modifier);
+}
+
+/// Writes `dims` and, unless targeting a dialect version older than 2.0,
+/// `modifier`. When downgrading, it prints a notice to stdout.
+void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
+  auto &prop = getProperties();
+  writer.writeAttribute(prop.dims);
+
+  auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
+  bool downgrading =
+      succeeded(maybeVersion) &&
+      reinterpret_cast<const TestDialectVersion *>(*maybeVersion)->major_ < 2;
+  if (downgrading) {
+    llvm::outs() << "downgrading op properties...\n";
+    return;
+  }
+  writer.writeAttribute(prop.modifier);
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpWithVersionedProperties
+//===----------------------------------------------------------------------===//
+
+/// Reads `value1` and, unless the bytecode targets a dialect version older
+/// than 2.0, `value2`. Otherwise `value2` is set to 0.
+mlir::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
+    mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) {
+  uint64_t first = 0;
+  if (failed(reader.readVarInt(first)))
+    return failure();
+
+  // A missing dialect version means the current version.
+  auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
+  bool isPre2 =
+      succeeded(maybeVersion) &&
+      reinterpret_cast<const TestDialectVersion *>(*maybeVersion)->major_ < 2;
+  uint64_t second = 0;
+  if (!isPre2 && failed(reader.readVarInt(second)))
+    return failure();
+
+  prop.value1 = first;
+  prop.value2 = second;
+  return success();
+}
+
+/// Writes `value1` and, unless targeting a dialect version older than 2.0,
+/// `value2`.
+void TestOpWithVersionedProperties::writeToMlirBytecode(
+    mlir::DialectBytecodeWriter &writer,
+    const test::VersionedProperties &prop) {
+  writer.writeVarInt(prop.value1);
+  // Mirror readFromMlirBytecode: when the writer targets a version before
+  // 2.0, the reader does not consume a second varint. Emitting `value2` there
+  // would desynchronize the stream. Dropping it is only lossless when
+  // `value2` is 0.
+  auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
+  if (succeeded(maybeVersion) &&
+      reinterpret_cast<const TestDialectVersion *>(*maybeVersion)->major_ < 2)
+    return;
+  writer.writeVarInt(prop.value2);
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.cpp b/mlir/test/lib/Dialect/Test/TestOps.cpp
new file mode 100644
index 00000000000000..47d5b1b19121ef
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestOps.cpp
@@ -0,0 +1,17 @@
+//===- TestOps.cpp - MLIR Test Dialect Operations ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestOps.h"
+#include "TestDialect.h"
+#include "TestFormatUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+
+using namespace mlir;
+using namespace test;
+
+#include "TestOps.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
new file mode 100644
index 00000000000000..f9925855bb9db6
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -0,0 +1,149 @@
+//===- TestOps.h - MLIR Test Dialect Operations ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTOPS_H
+#define MLIR_TESTOPS_H
+
+#include "TestAttributes.h"
+#include "TestInterfaces.h"
+#include "TestTypes.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/DLTI/Traits.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace test {
+class TestDialect;
+
+//===----------------------------------------------------------------------===//
+// TestResource
+//===----------------------------------------------------------------------===//
+
+/// A test resource for side effects.
+struct TestResource : public mlir::SideEffects::Resource::Base<TestResource> {
+  llvm::StringRef getName() final { return "<Test>"; }
+};
+
+//===----------------------------------------------------------------------===//
+// PropertiesWithCustomPrint
+//===----------------------------------------------------------------------===//
+
+/// Op properties holding a shared string label and an integer value.
+/// They are paired with the custom print/parse hooks declared below.
+struct PropertiesWithCustomPrint {
+  /// A shared_ptr to a const object is safe: it is equivalent to a value-based
+  /// member. Here the label will be deallocated when the last operation
+  /// referring to it is destroyed. However there is no pool-allocation: this is
+  /// offloaded to the client.
+  std::shared_ptr<const std::string> label;
+  int value;
+
+  /// Equal when the values match and the labels have the same contents. A null
+  /// `label` is only equal to another null `label` and is never dereferenced.
+  bool operator==(const PropertiesWithCustomPrint &rhs) const {
+    if (value != rhs.value)
+      return false;
+    if (!label || !rhs.label)
+      return label == rhs.label;
+    return *label == *rhs.label;
+  }
+};
+
+// Attribute conversion, hashing, and custom assembly hooks for
+// PropertiesWithCustomPrint (defined out of line).
+mlir::LogicalResult setPropertiesFromAttribute(
+    PropertiesWithCustomPrint &prop, mlir::Attribute attr,
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError);
+mlir::DictionaryAttr
+getPropertiesAsAttribute(mlir::MLIRContext *ctx,
+                         const PropertiesWithCustomPrint &prop);
+llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop);
+void customPrintProperties(mlir::OpAsmPrinter &p,
+                           const PropertiesWithCustomPrint &prop);
+mlir::ParseResult customParseProperties(mlir::OpAsmParser &parser,
+                                        PropertiesWithCustomPrint &prop);
+
+//===----------------------------------------------------------------------===//
+// MyPropStruct
+//===----------------------------------------------------------------------===//
+
+/// A property wrapping a single string.
+class MyPropStruct {
+public:
+  std::string content;
+  // These three methods are invoked through the `MyStructProperty` wrapper
+  // defined in TestOps.td
+  mlir::Attribute asAttribute(mlir::MLIRContext *ctx) const;
+  static mlir::LogicalResult
+  setFromAttr(MyPropStruct &prop, mlir::Attribute attr,
+              llvm::function_ref<mlir::InFlightDiagnostic()> emitError);
+  llvm::hash_code hash() const;
+  bool operator==(const MyPropStruct &rhs) const {
+    return content == rhs.content;
+  }
+};
+
+mlir::LogicalResult readFromMlirBytecode(mlir::DialectBytecodeReader &reader,
+                                         MyPropStruct &prop);
+// NOTE(review): writing presumably does not need to mutate `prop`; consider
+// taking it by const reference (the out-of-line definition must change too).
+void writeToMlirBytecode(mlir::DialectBytecodeWriter &writer,
+                         MyPropStruct &prop);
+
+//===----------------------------------------------------------------------===//
+// VersionedProperties
+//===----------------------------------------------------------------------===//
+
+struct VersionedProperties {
+  // For the sake of testing, assume that this object was associated to version
+  // 1.2 of the test dialect when having only one int value. In the current
+  // version 2.0, the property has two values. We also assume that the class is
+  // upgrade-able if value2 = 0.
+  int value1;
+  int value2;
+  bool operator==(const VersionedProperties &rhs) const {
+    return value1 == rhs.value1 && value2 == rhs.value2;
+  }
+};
+
+// Attribute conversion, hashing, and custom assembly hooks for
+// VersionedProperties (defined out of line).
+mlir::LogicalResult setPropertiesFromAttribute(
+    VersionedProperties &prop, mlir::Attribute attr,
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError);
+mlir::DictionaryAttr getPropertiesAsAttribute(mlir::MLIRContext *ctx,
+                                              const VersionedProperties &prop);
+llvm::hash_code computeHash(const VersionedProperties &prop);
+void customPrintProperties(mlir::OpAsmPrinter &p,
+                           const VersionedProperties &prop);
+mlir::ParseResult customParseProperties(mlir::OpAsmParser &parser,
+                                        VersionedProperties &prop);
+
+//===----------------------------------------------------------------------===//
+// Bytecode Support
+//===----------------------------------------------------------------------===//
+
+// Bytecode hooks for int64_t array properties; presumably the reader fills a
+// preallocated array in place (it takes a MutableArrayRef) — confirm against
+// the definition.
+mlir::LogicalResult readFromMlirBytecode(mlir::DialectBytecodeReader &reader,
+                                         llvm::MutableArrayRef<int64_t> prop);
+void writeToMlirBytecode(mlir::DialectBytecodeWriter &writer,
+                         llvm::ArrayRef<int64_t> prop);
+
+} // namespace test
+
+#define GET_OP_CLASSES
+#include "TestOps.h.inc"
+
+#endif // MLIR_TESTOPS_H
diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
index 84e6a43655cacd..c376d6c73c6452 100644
--- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
@@ -8,6 +8,7 @@
#include "TestOpsSyntax.h"
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/Support/Base64.h"
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 76dc825fe44515..0c1731ba5f07c8 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "TestTypes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
diff --git a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
index fa093cafcb0dc3..57e7d658fb501f 100644
--- a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
+++ b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp
index d9b67ef95ace83..031e1062dac76d 100644
--- a/mlir/test/lib/Dialect/Test/TestTraits.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 7a195eb25a3ba1..1593b6d7d7534b 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -139,6 +139,7 @@ static void printBarString(AsmPrinter &printer, StringRef foo) {
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//
+#include "TestTypeInterfaces.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index b1b5921d8faddd..da5604944d5a3b 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -31,11 +31,11 @@ class TestAttrWithFormatAttr;
/// FieldInfo represents a field in the StructType data type. It is used as a
/// parameter in TestTypeDefs.td.
struct FieldInfo {
- ::llvm::StringRef name;
- ::mlir::Type type;
+ llvm::StringRef name;
+ mlir::Type type;
// Custom allocation called from generated constructor code
+ /// Returns a copy whose `name` is copied into storage obtained from `alloc`;
+ /// `type` is copied as-is.
- FieldInfo allocateInto(::mlir::TypeStorageAllocator &alloc) const {
+ FieldInfo allocateInto(mlir::TypeStorageAllocator &alloc) const {
return FieldInfo{alloc.copyInto(name), type};
}
};
diff --git a/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp b/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
index e668224d343234..4894ad5294990a 100644
--- a/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
+++ b/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/BuiltinOps.h"
diff --git a/mlir/test/lib/IR/TestClone.cpp b/mlir/test/lib/IR/TestClone.cpp
index 7b18f219b915f4..b742b316c77126 100644
--- a/mlir/test/lib/IR/TestClone.cpp
+++ b/mlir/test/lib/IR/TestClone.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/test/lib/IR/TestSideEffects.cpp b/mlir/test/lib/IR/TestSideEffects.cpp
index 09ad1363228243..8e13dd9751398c 100644
--- a/mlir/test/lib/IR/TestSideEffects.cpp
+++ b/mlir/test/lib/IR/TestSideEffects.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp
index 0e1368f2e0ecaf..b470b15c533b57 100644
--- a/mlir/test/lib/IR/TestSymbolUses.cpp
+++ b/mlir/test/lib/IR/TestSymbolUses.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/test/lib/IR/TestTypes.cpp b/mlir/test/lib/IR/TestTypes.cpp
index 2bd63a48f77d1d..c6bce111d3ea7f 100644
--- a/mlir/test/lib/IR/TestTypes.cpp
+++ b/mlir/test/lib/IR/TestTypes.cpp
@@ -8,6 +8,7 @@
#include "TestTypes.h"
#include "TestDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
diff --git a/mlir/test/lib/IR/TestVisitorsGeneric.cpp b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
index 00148df26e3512..4556671df0ba0b 100644
--- a/mlir/test/lib/IR/TestVisitorsGeneric.cpp
+++ b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
index 477b75916f80c8..2762e254903245 100644
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/test/lib/Transforms/TestInlining.cpp b/mlir/test/lib/Transforms/TestInlining.cpp
index 9821179d05e891..223cc78dd1e21d 100644
--- a/mlir/test/lib/Transforms/TestInlining.cpp
+++ b/mlir/test/lib/Transforms/TestInlining.cpp
@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
diff --git a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
index 61e1fbcf3feaf3..82fa6cdb68d23c 100644
--- a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
+++ b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/unittests/IR/AdaptorTest.cpp b/mlir/unittests/IR/AdaptorTest.cpp
index 66ce53bbbadec9..0a5fa8d3c475c3 100644
--- a/mlir/unittests/IR/AdaptorTest.cpp
+++ b/mlir/unittests/IR/AdaptorTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
#include "../../test/lib/Dialect/Test/TestOpsSyntax.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
diff --git a/mlir/unittests/IR/IRMapping.cpp b/mlir/unittests/IR/IRMapping.cpp
index 83627975006ee8..b88009d1e3c361 100644
--- a/mlir/unittests/IR/IRMapping.cpp
+++ b/mlir/unittests/IR/IRMapping.cpp
@@ -11,6 +11,7 @@
#include "gtest/gtest.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
using namespace mlir;
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index 58049a9969e3ab..b6066dd5685dc6 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -19,6 +19,7 @@
#include "../../test/lib/Dialect/Test/TestAttributes.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
#include "../../test/lib/Dialect/Test/TestTypes.h"
#include "mlir/IR/OwningOpRef.h"
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index 5ab4d9a106231a..42196b003e7dad 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -15,6 +15,7 @@
#include "../../test/lib/Dialect/Test/TestAttributes.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
#include "../../test/lib/Dialect/Test/TestTypes.h"
using namespace mlir;
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index 9d75615b39c0c1..f94dc784458077 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -8,6 +8,7 @@
#include "mlir/IR/OperationSupport.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/BitVector.h"
diff --git a/mlir/unittests/IR/PatternMatchTest.cpp b/mlir/unittests/IR/PatternMatchTest.cpp
index 30b72618e45f0b..75d5228c82d99b 100644
--- a/mlir/unittests/IR/PatternMatchTest.cpp
+++ b/mlir/unittests/IR/PatternMatchTest.cpp
@@ -10,6 +10,7 @@
#include "gtest/gtest.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
using namespace mlir;
diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp
index 52347dcabe0381..c83ac9088114ce 100644
--- a/mlir/unittests/TableGen/OpBuildGen.cpp
+++ b/mlir/unittests/TableGen/OpBuildGen.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index dc5f4047c286db..489aaebd0453a5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -4,7 +4,7 @@
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
load("//llvm:lit_test.bzl", "package_path")
-load("//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+load("//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_sharded_ops", "td_library")
package(
default_visibility = ["//visibility:public"],
@@ -151,14 +151,6 @@ gentbl_cc_library(
name = "TestOpsIncGen",
strip_include_prefix = "lib/Dialect/Test",
tbl_outs = [
- (
- ["-gen-op-decls"],
- "lib/Dialect/Test/TestOps.h.inc",
- ),
- (
- ["-gen-op-defs"],
- "lib/Dialect/Test/TestOps.cpp.inc",
- ),
(
[
"-gen-dialect-decls",
@@ -370,12 +362,29 @@ cc_library(
],
)
+# Produces TestOps.h.inc (`hdr_out`) and TestOps.cpp.inc (`src_out`) from
+# TestOps.td; presumably the op definitions are split into `shard_count`
+# shards via `sharder` and compiled through TestOps.cpp (`src_file`).
+# NOTE(review): the commit message describes sharding the test dialect by 4;
+# confirm that 20 shards is intended for the Bazel build.
+gentbl_sharded_ops(
+ name = "TestDialectOpSrcs",
+ hdr_out = "lib/Dialect/Test/TestOps.h.inc",
+ shard_count = 20,
+ sharder = "//mlir:mlir-src-sharder",
+ src_file = "lib/Dialect/Test/TestOps.cpp",
+ src_out = "lib/Dialect/Test/TestOps.cpp.inc",
+ tblgen = "//mlir:mlir-tblgen",
+ td_file = "lib/Dialect/Test/TestOps.td",
+ test = True,
+ deps = [":TestOpTdFiles"],
+)
+
cc_library(
name = "TestDialect",
+<<<<<<< HEAD
srcs = glob(
["lib/Dialect/Test/*.cpp"],
exclude = ["lib/Dialect/Test/TestToLLVMIRTranslation.cpp"],
),
+=======
+ srcs = glob(["lib/Dialect/Test/*.cpp"]) + [":TestDialectOpSrcs"],
+>>>>>>> 95300a676d75 ([mlir][test] Shard and reorganize the test dialect)
hdrs = glob(["lib/Dialect/Test/*.h"]),
includes = [
"lib/Dialect/Test",
More information about the llvm-branch-commits
mailing list