[Mlir-commits] [mlir] e95e94a - [mlir][test] Reorganize the test dialect (#89424)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 22 13:42:09 PDT 2024
Author: Jeff Niu
Date: 2024-04-22T13:42:05-07:00
New Revision: e95e94adc6bb748de015ac3053e7f0786b65f351
URL: https://github.com/llvm/llvm-project/commit/e95e94adc6bb748de015ac3053e7f0786b65f351
DIFF: https://github.com/llvm/llvm-project/commit/e95e94adc6bb748de015ac3053e7f0786b65f351.diff
LOG: [mlir][test] Reorganize the test dialect (#89424)
This PR massively reorganizes the Test dialect's source files. It moves
manually written op hooks into `TestOpDefs.cpp`, moves custom format
directive parsers and printers into `TestFormatUtils`, adds missing
comment blocks, and moves the includes of generated source files for
types, attributes, enums, etc. into their own source files.
This should make the test dialect source code easier to navigate, and it
also speeds up compilation of the test dialect by putting generated source
files into separate compilation units.
This also prepares the test dialect for sharding its op definitions, which
is done in the next PR.
Added:
mlir/test/lib/Dialect/Test/TestFormatUtils.cpp
mlir/test/lib/Dialect/Test/TestFormatUtils.h
mlir/test/lib/Dialect/Test/TestOpDefs.cpp
mlir/test/lib/Dialect/Test/TestOps.cpp
mlir/test/lib/Dialect/Test/TestOps.h
Modified:
mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp
mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
mlir/test/lib/Dialect/Test/CMakeLists.txt
mlir/test/lib/Dialect/Test/TestAttributes.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestDialect.h
mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
mlir/test/lib/Dialect/Test/TestInterfaces.cpp
mlir/test/lib/Dialect/Test/TestInterfaces.h
mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
mlir/test/lib/Dialect/Test/TestTraits.cpp
mlir/test/lib/Dialect/Test/TestTypes.cpp
mlir/test/lib/Dialect/Test/TestTypes.h
mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
mlir/test/lib/IR/TestClone.cpp
mlir/test/lib/IR/TestSideEffects.cpp
mlir/test/lib/IR/TestSymbolUses.cpp
mlir/test/lib/IR/TestTypes.cpp
mlir/test/lib/IR/TestVisitorsGeneric.cpp
mlir/test/lib/Pass/TestPassManager.cpp
mlir/test/lib/Transforms/TestInlining.cpp
mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
mlir/unittests/IR/AdaptorTest.cpp
mlir/unittests/IR/IRMapping.cpp
mlir/unittests/IR/InterfaceAttachmentTest.cpp
mlir/unittests/IR/InterfaceTest.cpp
mlir/unittests/IR/OperationSupportTest.cpp
mlir/unittests/IR/PatternMatchTest.cpp
mlir/unittests/TableGen/OpBuildGen.cpp
Removed:
################################################################################
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
index ca052392f2f5f2..65592a5c5d698b 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -12,6 +12,7 @@
#include "TestDenseDataFlowAnalysis.h"
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
index 29480f5ad63ee0..3f9ce2dc0bc50a 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
@@ -12,6 +12,7 @@
#include "TestDenseDataFlowAnalysis.h"
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp
index 5e17779660f392..f878a262512ee8 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "TestTypes.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 3c4067b35d8e5b..cc1af59c5e15bb 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index b098a5a23fd316..34513cd418e4c2 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
diff --git a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
index 84f45b31603192..56f309f150ca5d 100644
--- a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
+++ b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/IR/BuiltinAttributes.h"
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 10aba733bd5696..0d7dce2240f4cb 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
#include "mlir/IR/Builders.h"
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index d246c0492a3bd5..f63e4d330e6ac1 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -47,7 +47,10 @@ add_public_tablegen_target(MLIRTestOpsSyntaxIncGen)
add_mlir_library(MLIRTestDialect
TestAttributes.cpp
TestDialect.cpp
+ TestFormatUtils.cpp
TestInterfaces.cpp
+ TestOpDefs.cpp
+ TestOps.cpp
TestPatterns.cpp
TestTraits.cpp
TestTypes.cpp
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index d41d495c38e553..2cc051e664beec 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/ErrorHandling.h"
@@ -244,7 +245,7 @@ static void printConditionalAlias(AsmPrinter &p, StringAttr value) {
//===----------------------------------------------------------------------===//
#include "TestAttrInterfaces.cpp.inc"
-
+#include "TestOpEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "TestAttrDefs.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index a23ed89c4b04d1..77fd7e61bd3a06 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -7,8 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
-#include "TestAttributes.h"
-#include "TestInterfaces.h"
+#include "TestOps.h"
#include "TestTypes.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -39,17 +38,85 @@
#include "llvm/Support/Base64.h"
#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Interfaces/FoldInterfaces.h"
+#include "mlir/Reducer/ReductionPatternInterface.h"
+#include "mlir/Transforms/InliningUtils.h"
#include <cstdint>
#include <numeric>
#include <optional>
-// Include this before the using namespace lines below to
-// test that we don't have namespace dependencies.
+// Include this before the using namespace lines below to test that we don't
+// have namespace dependencies.
#include "TestOpsDialect.cpp.inc"
using namespace mlir;
using namespace test;
+//===----------------------------------------------------------------------===//
+// PropertiesWithCustomPrint
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+test::setPropertiesFromAttribute(PropertiesWithCustomPrint &prop,
+ Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError) {
+ DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
+ if (!dict) {
+ emitError() << "expected DictionaryAttr to set TestProperties";
+ return failure();
+ }
+ auto label = dict.getAs<mlir::StringAttr>("label");
+ if (!label) {
+ emitError() << "expected StringAttr for key `label`";
+ return failure();
+ }
+ auto valueAttr = dict.getAs<IntegerAttr>("value");
+ if (!valueAttr) {
+ emitError() << "expected IntegerAttr for key `value`";
+ return failure();
+ }
+
+ prop.label = std::make_shared<std::string>(label.getValue());
+ prop.value = valueAttr.getValue().getSExtValue();
+ return success();
+}
+
+DictionaryAttr
+test::getPropertiesAsAttribute(MLIRContext *ctx,
+ const PropertiesWithCustomPrint &prop) {
+ SmallVector<NamedAttribute> attrs;
+ Builder b{ctx};
+ attrs.push_back(b.getNamedAttr("label", b.getStringAttr(*prop.label)));
+ attrs.push_back(b.getNamedAttr("value", b.getI32IntegerAttr(prop.value)));
+ return b.getDictionaryAttr(attrs);
+}
+
+llvm::hash_code test::computeHash(const PropertiesWithCustomPrint &prop) {
+ return llvm::hash_combine(prop.value, StringRef(*prop.label));
+}
+
+void test::customPrintProperties(OpAsmPrinter &p,
+ const PropertiesWithCustomPrint &prop) {
+ p.printKeywordOrString(*prop.label);
+ p << " is " << prop.value;
+}
+
+ParseResult test::customParseProperties(OpAsmParser &parser,
+ PropertiesWithCustomPrint &prop) {
+ std::string label;
+ if (parser.parseKeywordOrString(&label) || parser.parseKeyword("is") ||
+ parser.parseInteger(prop.value))
+ return failure();
+ prop.label = std::make_shared<std::string>(std::move(label));
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MyPropStruct
+//===----------------------------------------------------------------------===//
+
Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const {
return StringAttr::get(ctx, content);
}
@@ -70,8 +137,8 @@ llvm::hash_code MyPropStruct::hash() const {
return hash_value(StringRef(content));
}
-static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader,
- MyPropStruct &prop) {
+LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader,
+ MyPropStruct &prop) {
StringRef str;
if (failed(reader.readString(str)))
return failure();
@@ -79,13 +146,71 @@ static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader,
return success();
}
-static void writeToMlirBytecode(::mlir::DialectBytecodeWriter &writer,
- MyPropStruct &prop) {
+void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
+ MyPropStruct &prop) {
writer.writeOwnedString(prop.content);
}
-static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader,
- MutableArrayRef<int64_t> prop) {
+//===----------------------------------------------------------------------===//
+// VersionedProperties
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+test::setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError) {
+ DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
+ if (!dict) {
+ emitError() << "expected DictionaryAttr to set VersionedProperties";
+ return failure();
+ }
+ auto value1Attr = dict.getAs<IntegerAttr>("value1");
+ if (!value1Attr) {
+ emitError() << "expected IntegerAttr for key `value1`";
+ return failure();
+ }
+ auto value2Attr = dict.getAs<IntegerAttr>("value2");
+ if (!value2Attr) {
+ emitError() << "expected IntegerAttr for key `value2`";
+ return failure();
+ }
+
+ prop.value1 = value1Attr.getValue().getSExtValue();
+ prop.value2 = value2Attr.getValue().getSExtValue();
+ return success();
+}
+
+DictionaryAttr test::getPropertiesAsAttribute(MLIRContext *ctx,
+ const VersionedProperties &prop) {
+ SmallVector<NamedAttribute> attrs;
+ Builder b{ctx};
+ attrs.push_back(b.getNamedAttr("value1", b.getI32IntegerAttr(prop.value1)));
+ attrs.push_back(b.getNamedAttr("value2", b.getI32IntegerAttr(prop.value2)));
+ return b.getDictionaryAttr(attrs);
+}
+
+llvm::hash_code test::computeHash(const VersionedProperties &prop) {
+ return llvm::hash_combine(prop.value1, prop.value2);
+}
+
+void test::customPrintProperties(OpAsmPrinter &p,
+ const VersionedProperties &prop) {
+ p << prop.value1 << " | " << prop.value2;
+}
+
+ParseResult test::customParseProperties(OpAsmParser &parser,
+ VersionedProperties &prop) {
+ if (parser.parseInteger(prop.value1) || parser.parseVerticalBar() ||
+ parser.parseInteger(prop.value2))
+ return failure();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Bytecode Support
+//===----------------------------------------------------------------------===//
+
+LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader,
+ MutableArrayRef<int64_t> prop) {
uint64_t size;
if (failed(reader.readVarInt(size)))
return failure();
@@ -101,45 +226,13 @@ static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader,
return success();
}
-static void writeToMlirBytecode(::mlir::DialectBytecodeWriter &writer,
- ArrayRef<int64_t> prop) {
+void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
+ ArrayRef<int64_t> prop) {
writer.writeVarInt(prop.size());
for (auto elt : prop)
writer.writeVarInt(elt);
}
-static LogicalResult
-setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr,
- function_ref<InFlightDiagnostic()> emitError);
-static DictionaryAttr
-getPropertiesAsAttribute(MLIRContext *ctx,
- const PropertiesWithCustomPrint &prop);
-static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop);
-static void customPrintProperties(OpAsmPrinter &p,
- const PropertiesWithCustomPrint &prop);
-static ParseResult customParseProperties(OpAsmParser &parser,
- PropertiesWithCustomPrint &prop);
-static LogicalResult
-setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
- function_ref<InFlightDiagnostic()> emitError);
-static DictionaryAttr getPropertiesAsAttribute(MLIRContext *ctx,
- const VersionedProperties &prop);
-static llvm::hash_code computeHash(const VersionedProperties &prop);
-static void customPrintProperties(OpAsmPrinter &p,
- const VersionedProperties &prop);
-static ParseResult customParseProperties(OpAsmParser &parser,
- VersionedProperties &prop);
-static ParseResult
-parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
- SmallVectorImpl<std::unique_ptr<Region>> &caseRegions);
-
-static void printSwitchCases(OpAsmPrinter &p, Operation *op,
- DenseI64ArrayAttr cases, RegionRange caseRegions);
-
-void test::registerTestDialect(DialectRegistry ®istry) {
- registry.insert<TestDialect>();
-}
-
//===----------------------------------------------------------------------===//
// Dynamic operations
//===----------------------------------------------------------------------===//
@@ -196,9 +289,20 @@ getDynamicCustomParserPrinterOp(TestDialect *dialect) {
// TestDialect
//===----------------------------------------------------------------------===//
-static void testSideEffectOpGetEffect(
+void test::registerTestDialect(DialectRegistry ®istry) {
+ registry.insert<TestDialect>();
+}
+
+void test::testSideEffectOpGetEffect(
Operation *op,
- SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
+ SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
+ &effects) {
+ auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
+ if (!effectsAttr)
+ return;
+
+ effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
+}
// This is the implementation of a dialect fallback for `TestEffectOpInterface`.
struct TestOpEffectInterfaceFallback
@@ -318,57 +422,6 @@ TestDialect::getOperationPrinter(Operation *op) const {
return {};
}
-//===----------------------------------------------------------------------===//
-// TypedAttrOp
-//===----------------------------------------------------------------------===//
-
-/// Parse an attribute with a given type.
-static ParseResult parseAttrElideType(AsmParser &parser, TypeAttr type,
- Attribute &attr) {
- return parser.parseAttribute(attr, type.getValue());
-}
-
-/// Print an attribute without its type.
-static void printAttrElideType(AsmPrinter &printer, Operation *op,
- TypeAttr type, Attribute attr) {
- printer.printAttributeWithoutType(attr);
-}
-
-//===----------------------------------------------------------------------===//
-// TestBranchOp
-//===----------------------------------------------------------------------===//
-
-SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
- assert(index == 0 && "invalid successor index");
- return SuccessorOperands(getTargetOperandsMutable());
-}
-
-//===----------------------------------------------------------------------===//
-// TestProducingBranchOp
-//===----------------------------------------------------------------------===//
-
-SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
- assert(index <= 1 && "invalid successor index");
- if (index == 1)
- return SuccessorOperands(getFirstOperandsMutable());
- return SuccessorOperands(getSecondOperandsMutable());
-}
-
-//===----------------------------------------------------------------------===//
-// TestProducingBranchOp
-//===----------------------------------------------------------------------===//
-
-SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
- assert(index <= 1 && "invalid successor index");
- if (index == 0)
- return SuccessorOperands(0, getSuccessOperandsMutable());
- return SuccessorOperands(1, getErrorOperandsMutable());
-}
-
-//===----------------------------------------------------------------------===//
-// TestDialectCanonicalizerOp
-//===----------------------------------------------------------------------===//
-
static LogicalResult
dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
PatternRewriter &rewriter) {
@@ -381,1206 +434,3 @@ void TestDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add(&dialectCanonicalizationPattern);
}
-
-//===----------------------------------------------------------------------===//
-// TestCallOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // Check that the callee attribute was specified.
- auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
- if (!fnAttr)
- return emitOpError("requires a 'callee' symbol reference attribute");
- if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
- return emitOpError() << "'" << fnAttr.getValue()
- << "' does not reference a valid function";
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// ConversionFuncOp
-//===----------------------------------------------------------------------===//
-
-ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
- OperationState &result) {
- auto buildFuncType =
- [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
- function_interface_impl::VariadicFlag,
- std::string &) { return builder.getFunctionType(argTypes, results); };
-
- return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType,
- getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
-}
-
-void ConversionFuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(
- p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
- getArgAttrsAttrName(), getResAttrsAttrName());
-}
-
-//===----------------------------------------------------------------------===//
-// TestFoldToCallOp
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
- using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(FoldToCallOp op,
- PatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
- op.getCalleeAttr(), ValueRange());
- return success();
- }
-};
-} // namespace
-
-void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<FoldToCallOpPattern>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// Test IsolatedRegionOp - parse passthrough region arguments.
-//===----------------------------------------------------------------------===//
-
-ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
- OperationState &result) {
- // Parse the input operand.
- OpAsmParser::Argument argInfo;
- argInfo.type = parser.getBuilder().getIndexType();
- if (parser.parseOperand(argInfo.ssaName) ||
- parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
- return failure();
-
- // Parse the body region, and reuse the operand info as the argument info.
- Region *body = result.addRegion();
- return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
-}
-
-void IsolatedRegionOp::print(OpAsmPrinter &p) {
- p << ' ';
- p.printOperand(getOperand());
- p.shadowRegionArgs(getRegion(), getOperand());
- p << ' ';
- p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
-}
-
-//===----------------------------------------------------------------------===//
-// Test SSACFGRegionOp
-//===----------------------------------------------------------------------===//
-
-RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
- return RegionKind::SSACFG;
-}
-
-//===----------------------------------------------------------------------===//
-// Test GraphRegionOp
-//===----------------------------------------------------------------------===//
-
-RegionKind GraphRegionOp::getRegionKind(unsigned index) {
- return RegionKind::Graph;
-}
-
-//===----------------------------------------------------------------------===//
-// Test AffineScopeOp
-//===----------------------------------------------------------------------===//
-
-ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
- // Parse the body region, and reuse the operand info as the argument info.
- Region *body = result.addRegion();
- return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
-}
-
-void AffineScopeOp::print(OpAsmPrinter &p) {
- p << " ";
- p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
-}
-
-//===----------------------------------------------------------------------===//
-// Test OptionalCustomAttrOp
-//===----------------------------------------------------------------------===//
-
-static OptionalParseResult parseOptionalCustomParser(AsmParser &p,
- IntegerAttr &result) {
- if (succeeded(p.parseOptionalKeyword("foo")))
- return p.parseAttribute(result);
- return {};
-}
-
-static void printOptionalCustomParser(AsmPrinter &p, Operation *,
- IntegerAttr result) {
- p << "foo ";
- p.printAttribute(result);
-}
-
-//===----------------------------------------------------------------------===//
-// ReifyBoundOp
-//===----------------------------------------------------------------------===//
-
-::mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
- if (getType() == "EQ")
- return ::mlir::presburger::BoundType::EQ;
- if (getType() == "LB")
- return ::mlir::presburger::BoundType::LB;
- if (getType() == "UB")
- return ::mlir::presburger::BoundType::UB;
- llvm_unreachable("invalid bound type");
-}
-
-LogicalResult ReifyBoundOp::verify() {
- if (isa<ShapedType>(getVar().getType())) {
- if (!getDim().has_value())
- return emitOpError("expected 'dim' attribute for shaped type variable");
- } else if (getVar().getType().isIndex()) {
- if (getDim().has_value())
- return emitOpError("unexpected 'dim' attribute for index variable");
- } else {
- return emitOpError("expected index-typed variable or shape type variable");
- }
- if (getConstant() && getScalable())
- return emitOpError("'scalable' and 'constant' are mutually exlusive");
- if (getScalable() != getVscaleMin().has_value())
- return emitOpError("expected 'vscale_min' if and only if 'scalable'");
- if (getScalable() != getVscaleMax().has_value())
- return emitOpError("expected 'vscale_min' if and only if 'scalable'");
- return success();
-}
-
-::mlir::ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
- if (getDim().has_value())
- return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
- return ValueBoundsConstraintSet::Variable(getVar());
-}
-
-::mlir::ValueBoundsConstraintSet::ComparisonOperator
-CompareOp::getComparisonOperator() {
- if (getCmp() == "EQ")
- return ValueBoundsConstraintSet::ComparisonOperator::EQ;
- if (getCmp() == "LT")
- return ValueBoundsConstraintSet::ComparisonOperator::LT;
- if (getCmp() == "LE")
- return ValueBoundsConstraintSet::ComparisonOperator::LE;
- if (getCmp() == "GT")
- return ValueBoundsConstraintSet::ComparisonOperator::GT;
- if (getCmp() == "GE")
- return ValueBoundsConstraintSet::ComparisonOperator::GE;
- llvm_unreachable("invalid comparison operator");
-}
-
-::mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
- if (!getLhsMap())
- return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
- SmallVector<Value> mapOperands(
- getVarOperands().slice(0, getLhsMap()->getNumInputs()));
- return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands);
-}
-
-::mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
- int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
- if (!getRhsMap())
- return ValueBoundsConstraintSet::Variable(
- getVarOperands()[rhsOperandsBegin]);
- SmallVector<Value> mapOperands(
- getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs()));
- return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands);
-}
-
-LogicalResult CompareOp::verify() {
- if (getCompose() && (getLhsMap() || getRhsMap()))
- return emitOpError(
- "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
- int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
- expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
- if (getVarOperands().size() != size_t(expectedNumOperands))
- return emitOpError("expected ")
- << expectedNumOperands << " operands, but got "
- << getVarOperands().size();
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Test removing op with inner ops.
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct TestRemoveOpWithInnerOps
- : public OpRewritePattern<TestOpWithRegionPattern> {
- using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
-
- void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
-
- LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
- PatternRewriter &rewriter) const override {
- rewriter.eraseOp(op);
- return success();
- }
-};
-} // namespace
-
-void TestOpWithRegionPattern::getCanonicalizationPatterns(
- RewritePatternSet &results, MLIRContext *context) {
- results.add<TestRemoveOpWithInnerOps>(context);
-}
-
-OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
- return getOperand();
-}
-
-OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); }
-
-LogicalResult TestOpWithVariadicResultsAndFolder::fold(
- FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
- for (Value input : this->getOperands()) {
- results.push_back(input);
- }
- return success();
-}
-
-OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
- // Exercise the fact that an operation created with createOrFold should be
- // allowed to access its parent block.
- assert(getOperation()->getBlock() &&
- "expected that operation is not unlinked");
-
- if (adaptor.getOp() && !getProperties().attr) {
- // The folder adds "attr" if not present.
- getProperties().attr = dyn_cast_or_null<IntegerAttr>(adaptor.getOp());
- return getResult();
- }
- return {};
-}
-
-OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
- int64_t sum = 0;
- if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
- sum += value.getValue().getSExtValue();
-
- for (Attribute attr : adaptor.getVariadic())
- if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
- sum += 2 * value.getValue().getSExtValue();
-
- for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar())
- for (Attribute attr : attrs)
- if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
- sum += 3 * value.getValue().getSExtValue();
-
- sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end());
-
- return IntegerAttr::get(getType(), sum);
-}
-
-LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
- MLIRContext *, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- if (operands[0].getType() != operands[1].getType()) {
- return emitOptionalError(location, "operand type mismatch ",
- operands[0].getType(), " vs ",
- operands[1].getType());
- }
- inferredReturnTypes.assign({operands[0].getType()});
- return success();
-}
-
-LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
- MLIRContext *, std::optional<Location> location,
- OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- if (adaptor.getX().getType() != adaptor.getY().getType()) {
- return emitOptionalError(location, "operand type mismatch ",
- adaptor.getX().getType(), " vs ",
- adaptor.getY().getType());
- }
- inferredReturnTypes.assign({adaptor.getX().getType()});
- return success();
-}
-
-// TODO: We should be able to only define either inferReturnType or
-// refineReturnType, currently only refineReturnType can be omitted.
-LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
- MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &returnTypes) {
- returnTypes.clear();
- return OpWithRefineTypeInterfaceOp::refineReturnTypes(
- context, location, operands, attributes, properties, regions,
- returnTypes);
-}
-
-LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
- MLIRContext *, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &returnTypes) {
- if (operands[0].getType() != operands[1].getType()) {
- return emitOptionalError(location, "operand type mismatch ",
- operands[0].getType(), " vs ",
- operands[1].getType());
- }
- // TODO: Add helper to make this more concise to write.
- if (returnTypes.empty())
- returnTypes.resize(1, nullptr);
- if (returnTypes[0] && returnTypes[0] != operands[0].getType())
- return emitOptionalError(location,
- "required first operand and result to match");
- returnTypes[0] = operands[0].getType();
- return success();
-}
-
-LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
- MLIRContext *context, std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- // Create return type consisting of the last element of the first operand.
- auto operandType = operands.front().getType();
- auto sval = dyn_cast<ShapedType>(operandType);
- if (!sval)
- return emitOptionalError(location, "only shaped type operands allowed");
- int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
- auto type = IntegerType::get(context, 17);
-
- Attribute encoding;
- if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
- encoding = rankedTy.getEncoding();
- inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
- return success();
-}
-
-LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
- OpBuilder &builder, ValueRange operands,
- llvm::SmallVectorImpl<Value> &shapes) {
- shapes = SmallVector<Value, 1>{
- builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
- return success();
-}
-
-LogicalResult
-OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
- MLIRContext *context, std::optional<Location> location,
- OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
- SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- // Create return type consisting of the last element of the first operand.
- auto operandType = adaptor.getOperand1().getType();
- auto sval = dyn_cast<ShapedType>(operandType);
- if (!sval)
- return emitOptionalError(location, "only shaped type operands allowed");
- int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
- auto type = IntegerType::get(context, 17);
-
- Attribute encoding;
- if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
- encoding = rankedTy.getEncoding();
- inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
- return success();
-}
-
-LogicalResult
-OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
- OpBuilder &builder, ValueRange operands,
- llvm::SmallVectorImpl<Value> &shapes) {
- shapes = SmallVector<Value, 1>{
- builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
- return success();
-}
-
-LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
- OpBuilder &builder, ValueRange operands,
- llvm::SmallVectorImpl<Value> &shapes) {
- Location loc = getLoc();
- shapes.reserve(operands.size());
- for (Value operand : llvm::reverse(operands)) {
- auto rank = cast<RankedTensorType>(operand.getType()).getRank();
- auto currShape = llvm::to_vector<4>(
- llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
- return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
- }));
- shapes.push_back(builder.create<tensor::FromElementsOp>(
- getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
- currShape));
- }
- return success();
-}
-
-LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
- OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
- Location loc = getLoc();
- shapes.reserve(getNumOperands());
- for (Value operand : llvm::reverse(getOperands())) {
- auto tensorType = cast<RankedTensorType>(operand.getType());
- auto currShape = llvm::to_vector<4>(llvm::map_range(
- llvm::seq<int64_t>(0, tensorType.getRank()),
- [&](int64_t dim) -> OpFoldResult {
- return tensorType.isDynamicDim(dim)
- ? static_cast<OpFoldResult>(
- builder.createOrFold<tensor::DimOp>(loc, operand,
- dim))
- : static_cast<OpFoldResult>(
- builder.getIndexAttr(tensorType.getDimSize(dim)));
- }));
- shapes.emplace_back(std::move(currShape));
- }
- return success();
-}
-
-LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes(
- MLIRContext *context, std::optional<Location>, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
-
- Adaptor adaptor(operands, attributes, properties, regions);
- inferredReturnTypes.push_back(IntegerType::get(
- context, adaptor.getLhs() + adaptor.getProperties().rhs));
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Test SideEffect interfaces
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// A test resource for side effects.
-struct TestResource : public SideEffects::Resource::Base<TestResource> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
-
- StringRef getName() final { return "<Test>"; }
-};
-} // namespace
-
-static void testSideEffectOpGetEffect(
- Operation *op,
- SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
- &effects) {
- auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
- if (!effectsAttr)
- return;
-
- effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
-}
-
-void SideEffectOp::getEffects(
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- // Check for an effects attribute on the op instance.
- ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
- if (!effectsAttr)
- return;
-
- // If there is one, it is an array of dictionary attributes that hold
- // information on the effects of this operation.
- for (Attribute element : effectsAttr) {
- DictionaryAttr effectElement = cast<DictionaryAttr>(element);
-
- // Get the specific memory effect.
- MemoryEffects::Effect *effect =
- StringSwitch<MemoryEffects::Effect *>(
- cast<StringAttr>(effectElement.get("effect")).getValue())
- .Case("allocate", MemoryEffects::Allocate::get())
- .Case("free", MemoryEffects::Free::get())
- .Case("read", MemoryEffects::Read::get())
- .Case("write", MemoryEffects::Write::get());
-
- // Check for a non-default resource to use.
- SideEffects::Resource *resource = SideEffects::DefaultResource::get();
- if (effectElement.get("test_resource"))
- resource = TestResource::get();
-
- // Check for a result to affect.
- if (effectElement.get("on_result"))
- effects.emplace_back(effect, getResult(), resource);
- else if (Attribute ref = effectElement.get("on_reference"))
- effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
- else
- effects.emplace_back(effect, resource);
- }
-}
-
-void SideEffectOp::getEffects(
- SmallVectorImpl<TestEffects::EffectInstance> &effects) {
- testSideEffectOpGetEffect(getOperation(), effects);
-}
-
-//===----------------------------------------------------------------------===//
-// StringAttrPrettyNameOp
-//===----------------------------------------------------------------------===//
-
-// This op has fancy handling of its SSA result name.
-ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
- OperationState &result) {
- // Add the result types.
- for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
- result.addTypes(parser.getBuilder().getIntegerType(32));
-
- if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
- return failure();
-
- // If the attribute dictionary contains no 'names' attribute, infer it from
- // the SSA name (if specified).
- bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
- return attr.getName() == "names";
- });
-
- // If there was no name specified, check to see if there was a useful name
- // specified in the asm file.
- if (hadNames || parser.getNumResults() == 0)
- return success();
-
- SmallVector<StringRef, 4> names;
- auto *context = result.getContext();
-
- for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
- auto resultName = parser.getResultName(i);
- StringRef nameStr;
- if (!resultName.first.empty() && !isdigit(resultName.first[0]))
- nameStr = resultName.first;
-
- names.push_back(nameStr);
- }
-
- auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
- result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
- return success();
-}
-
-void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
- // Note that we only need to print the "name" attribute if the asmprinter
- // result name disagrees with it. This can happen in strange cases, e.g.
- // when there are conflicts.
- bool namesDisagree = getNames().size() != getNumResults();
-
- SmallString<32> resultNameStr;
- for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
- resultNameStr.clear();
- llvm::raw_svector_ostream tmpStream(resultNameStr);
- p.printOperand(getResult(i), tmpStream);
-
- auto expectedName = dyn_cast<StringAttr>(getNames()[i]);
- if (!expectedName ||
- tmpStream.str().drop_front() != expectedName.getValue()) {
- namesDisagree = true;
- }
- }
-
- if (namesDisagree)
- p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
- else
- p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
-}
-
-// We set the SSA name in the asm syntax to the contents of the name
-// attribute.
-void StringAttrPrettyNameOp::getAsmResultNames(
- function_ref<void(Value, StringRef)> setNameFn) {
-
- auto value = getNames();
- for (size_t i = 0, e = value.size(); i != e; ++i)
- if (auto str = dyn_cast<StringAttr>(value[i]))
- if (!str.getValue().empty())
- setNameFn(getResult(i), str.getValue());
-}
-
-void CustomResultsNameOp::getAsmResultNames(
- function_ref<void(Value, StringRef)> setNameFn) {
- ArrayAttr value = getNames();
- for (size_t i = 0, e = value.size(); i != e; ++i)
- if (auto str = dyn_cast<StringAttr>(value[i]))
- if (!str.empty())
- setNameFn(getResult(i), str.getValue());
-}
-
-//===----------------------------------------------------------------------===//
-// ResultTypeWithTraitOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult ResultTypeWithTraitOp::verify() {
- if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
- return success();
- return emitError("result type should have trait 'TestTypeTrait'");
-}
-
-//===----------------------------------------------------------------------===//
-// AttrWithTraitOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult AttrWithTraitOp::verify() {
- if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
- return success();
- return emitError("'attr' attribute should have trait 'TestAttrTrait'");
-}
-
-//===----------------------------------------------------------------------===//
-// RegionIfOp
-//===----------------------------------------------------------------------===//
-
-void RegionIfOp::print(OpAsmPrinter &p) {
- p << " ";
- p.printOperands(getOperands());
- p << ": " << getOperandTypes();
- p.printArrowTypeList(getResultTypes());
- p << " then ";
- p.printRegion(getThenRegion(),
- /*printEntryBlockArgs=*/true,
- /*printBlockTerminators=*/true);
- p << " else ";
- p.printRegion(getElseRegion(),
- /*printEntryBlockArgs=*/true,
- /*printBlockTerminators=*/true);
- p << " join ";
- p.printRegion(getJoinRegion(),
- /*printEntryBlockArgs=*/true,
- /*printBlockTerminators=*/true);
-}
-
-ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
- SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
- SmallVector<Type, 2> operandTypes;
-
- result.regions.reserve(3);
- Region *thenRegion = result.addRegion();
- Region *elseRegion = result.addRegion();
- Region *joinRegion = result.addRegion();
-
- // Parse operand, type and arrow type lists.
- if (parser.parseOperandList(operandInfos) ||
- parser.parseColonTypeList(operandTypes) ||
- parser.parseArrowTypeList(result.types))
- return failure();
-
- // Parse all attached regions.
- if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
- parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
- parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
- return failure();
-
- return parser.resolveOperands(operandInfos, operandTypes,
- parser.getCurrentLocation(), result.operands);
-}
-
-OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
- "invalid region index");
- return getOperands();
-}
-
-void RegionIfOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
- // We always branch to the join region.
- if (!point.isParent()) {
- if (point != getJoinRegion())
- regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
- else
- regions.push_back(RegionSuccessor(getResults()));
- return;
- }
-
- // The then and else regions are the entry regions of this op.
- regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
- regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
-}
-
-void RegionIfOp::getRegionInvocationBounds(
- ArrayRef<Attribute> operands,
- SmallVectorImpl<InvocationBounds> &invocationBounds) {
- // Each region is invoked at most once.
- invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
-}
-
-//===----------------------------------------------------------------------===//
-// AnyCondOp
-//===----------------------------------------------------------------------===//
-
-void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
- SmallVectorImpl<RegionSuccessor> ®ions) {
- // The parent op branches into the only region, and the region branches back
- // to the parent op.
- if (point.isParent())
- regions.emplace_back(&getRegion());
- else
- regions.emplace_back(getResults());
-}
-
-void AnyCondOp::getRegionInvocationBounds(
- ArrayRef<Attribute> operands,
- SmallVectorImpl<InvocationBounds> &invocationBounds) {
- invocationBounds.emplace_back(1, 1);
-}
-
-//===----------------------------------------------------------------------===//
-// LoopBlockOp
-//===----------------------------------------------------------------------===//
-
-void LoopBlockOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
- regions.emplace_back(&getBody(), getBody().getArguments());
- if (point.isParent())
- return;
-
- regions.emplace_back((*this)->getResults());
-}
-
-OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBody());
- return MutableOperandRange(getInitMutable());
-}
-
-//===----------------------------------------------------------------------===//
-// LoopBlockTerminatorOp
-//===----------------------------------------------------------------------===//
-
-MutableOperandRange
-LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
- if (point.isParent())
- return getExitArgMutable();
- return getNextIterArgMutable();
-}
-
-//===----------------------------------------------------------------------===//
-// SwitchWithNoBreakOp
-//===----------------------------------------------------------------------===//
-
-void TestNoTerminatorOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {}
-
-//===----------------------------------------------------------------------===//
-// SingleNoTerminatorCustomAsmOp
-//===----------------------------------------------------------------------===//
-
-ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
- OperationState &state) {
- Region *body = state.addRegion();
- if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
- return failure();
- return success();
-}
-
-void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
- printer.printRegion(
- getRegion(), /*printEntryBlockArgs=*/false,
- // This op has a single block without terminators. But explicitly mark
- // as not printing block terminators for testing.
- /*printBlockTerminators=*/false);
-}
-
-//===----------------------------------------------------------------------===//
-// TestVerifiersOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult TestVerifiersOp::verify() {
- if (!getRegion().hasOneBlock())
- return emitOpError("`hasOneBlock` trait hasn't been verified");
-
- Operation *definingOp = getInput().getDefiningOp();
- if (definingOp && failed(mlir::verify(definingOp)))
- return emitOpError("operand hasn't been verified");
-
- // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
- // loop.
- mlir::emitRemark(getLoc(), "success run of verifier");
-
- return success();
-}
-
-LogicalResult TestVerifiersOp::verifyRegions() {
- if (!getRegion().hasOneBlock())
- return emitOpError("`hasOneBlock` trait hasn't been verified");
-
- for (Block &block : getRegion())
- for (Operation &op : block)
- if (failed(mlir::verify(&op)))
- return emitOpError("nested op hasn't been verified");
-
- // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
- // loop.
- mlir::emitRemark(getLoc(), "success run of region verifier");
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Test InferIntRangeInterface
-//===----------------------------------------------------------------------===//
-
-void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
-}
-
-ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
- OperationState &result) {
- if (parser.parseOptionalAttrDict(result.attributes))
- return failure();
-
- // Parse the input argument
- OpAsmParser::Argument argInfo;
- argInfo.type = parser.getBuilder().getIndexType();
- if (failed(parser.parseArgument(argInfo)))
- return failure();
-
- // Parse the body region, and reuse the operand info as the argument info.
- Region *body = result.addRegion();
- return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
-}
-
-void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
- p.printOptionalAttrDict((*this)->getAttrs());
- p << ' ';
- p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
- /*omitType=*/true);
- p << ' ';
- p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
-}
-
-void TestWithBoundsRegionOp::inferResultRanges(
- ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
- Value arg = getRegion().getArgument(0);
- setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
-}
-
-void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- const ConstantIntRanges &range = argRanges[0];
- APInt one(range.umin().getBitWidth(), 1);
- setResultRanges(getResult(),
- {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
- range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
-}
-
-void TestReflectBoundsOp::inferResultRanges(
- ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
- const ConstantIntRanges &range = argRanges[0];
- MLIRContext *ctx = getContext();
- Builder b(ctx);
- setUminAttr(b.getIndexAttr(range.umin().getZExtValue()));
- setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue()));
- setSminAttr(b.getIndexAttr(range.smin().getSExtValue()));
- setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue()));
- setResultRanges(getResult(), range);
-}
-
-OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
- // Just a simple fold for testing purposes that reads an operands constant
- // value and returns it.
- if (!attributes.empty())
- return attributes.front();
- return nullptr;
-}
-
-static LogicalResult
-setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr,
- function_ref<InFlightDiagnostic()> emitError) {
- DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
- if (!dict) {
- emitError() << "expected DictionaryAttr to set TestProperties";
- return failure();
- }
- auto label = dict.getAs<mlir::StringAttr>("label");
- if (!label) {
- emitError() << "expected StringAttr for key `label`";
- return failure();
- }
- auto valueAttr = dict.getAs<IntegerAttr>("value");
- if (!valueAttr) {
- emitError() << "expected IntegerAttr for key `value`";
- return failure();
- }
-
- prop.label = std::make_shared<std::string>(label.getValue());
- prop.value = valueAttr.getValue().getSExtValue();
- return success();
-}
-
-static DictionaryAttr
-getPropertiesAsAttribute(MLIRContext *ctx,
- const PropertiesWithCustomPrint &prop) {
- SmallVector<NamedAttribute> attrs;
- Builder b{ctx};
- attrs.push_back(b.getNamedAttr("label", b.getStringAttr(*prop.label)));
- attrs.push_back(b.getNamedAttr("value", b.getI32IntegerAttr(prop.value)));
- return b.getDictionaryAttr(attrs);
-}
-
-static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop) {
- return llvm::hash_combine(prop.value, StringRef(*prop.label));
-}
-
-static void customPrintProperties(OpAsmPrinter &p,
- const PropertiesWithCustomPrint &prop) {
- p.printKeywordOrString(*prop.label);
- p << " is " << prop.value;
-}
-
-static ParseResult customParseProperties(OpAsmParser &parser,
- PropertiesWithCustomPrint &prop) {
- std::string label;
- if (parser.parseKeywordOrString(&label) || parser.parseKeyword("is") ||
- parser.parseInteger(prop.value))
- return failure();
- prop.label = std::make_shared<std::string>(std::move(label));
- return success();
-}
-
-static ParseResult
-parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
- SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
- SmallVector<int64_t> caseValues;
- while (succeeded(p.parseOptionalKeyword("case"))) {
- int64_t value;
- Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
- if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
- return failure();
- caseValues.push_back(value);
- }
- cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
- return success();
-}
-
-static void printSwitchCases(OpAsmPrinter &p, Operation *op,
- DenseI64ArrayAttr cases, RegionRange caseRegions) {
- for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
- p.printNewline();
- p << "case " << value << ' ';
- p.printRegion(*region, /*printEntryBlockArgs=*/false);
- }
-}
-
-static LogicalResult
-setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
- function_ref<InFlightDiagnostic()> emitError) {
- DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
- if (!dict) {
- emitError() << "expected DictionaryAttr to set VersionedProperties";
- return failure();
- }
- auto value1Attr = dict.getAs<IntegerAttr>("value1");
- if (!value1Attr) {
- emitError() << "expected IntegerAttr for key `value1`";
- return failure();
- }
- auto value2Attr = dict.getAs<IntegerAttr>("value2");
- if (!value2Attr) {
- emitError() << "expected IntegerAttr for key `value2`";
- return failure();
- }
-
- prop.value1 = value1Attr.getValue().getSExtValue();
- prop.value2 = value2Attr.getValue().getSExtValue();
- return success();
-}
-
-static DictionaryAttr
-getPropertiesAsAttribute(MLIRContext *ctx, const VersionedProperties &prop) {
- SmallVector<NamedAttribute> attrs;
- Builder b{ctx};
- attrs.push_back(b.getNamedAttr("value1", b.getI32IntegerAttr(prop.value1)));
- attrs.push_back(b.getNamedAttr("value2", b.getI32IntegerAttr(prop.value2)));
- return b.getDictionaryAttr(attrs);
-}
-
-static llvm::hash_code computeHash(const VersionedProperties &prop) {
- return llvm::hash_combine(prop.value1, prop.value2);
-}
-
-static void customPrintProperties(OpAsmPrinter &p,
- const VersionedProperties &prop) {
- p << prop.value1 << " | " << prop.value2;
-}
-
-static ParseResult customParseProperties(OpAsmParser &parser,
- VersionedProperties &prop) {
- if (parser.parseInteger(prop.value1) || parser.parseVerticalBar() ||
- parser.parseInteger(prop.value2))
- return failure();
- return success();
-}
-
-static bool parseUsingPropertyInCustom(OpAsmParser &parser, int64_t value[3]) {
- return parser.parseLSquare() || parser.parseInteger(value[0]) ||
- parser.parseComma() || parser.parseInteger(value[1]) ||
- parser.parseComma() || parser.parseInteger(value[2]) ||
- parser.parseRSquare();
-}
-
-static void printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op,
- ArrayRef<int64_t> value) {
- printer << '[' << value << ']';
-}
-
-static bool parseIntProperty(OpAsmParser &parser, int64_t &value) {
- return failed(parser.parseInteger(value));
-}
-
-static void printIntProperty(OpAsmPrinter &printer, Operation *op,
- int64_t value) {
- printer << value;
-}
-
-static bool parseSumProperty(OpAsmParser &parser, int64_t &second,
- int64_t first) {
- int64_t sum;
- auto loc = parser.getCurrentLocation();
- if (parser.parseInteger(second) || parser.parseEqual() ||
- parser.parseInteger(sum))
- return true;
- if (sum != second + first) {
- parser.emitError(loc, "Expected sum to equal first + second");
- return true;
- }
- return false;
-}
-
-static void printSumProperty(OpAsmPrinter &printer, Operation *op,
- int64_t second, int64_t first) {
- printer << second << " = " << (second + first);
-}
-
-//===----------------------------------------------------------------------===//
-// Tensor/Buffer Ops
-//===----------------------------------------------------------------------===//
-
-void ReadBufferOp::getEffects(
- SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
- &effects) {
- // The buffer operand is read.
- effects.emplace_back(MemoryEffects::Read::get(), getBuffer(),
- SideEffects::DefaultResource::get());
- // The buffer contents are dumped.
- effects.emplace_back(MemoryEffects::Write::get(),
- SideEffects::DefaultResource::get());
-}
-
-//===----------------------------------------------------------------------===//
-// Test Dataflow
-//===----------------------------------------------------------------------===//
-
-CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() {
- return getCallee();
-}
-
-void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) {
- setCalleeAttr(callee.get<SymbolRefAttr>());
-}
-
-Operation::operand_range TestCallAndStoreOp::getArgOperands() {
- return getCalleeOperands();
-}
-
-MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
- return getCalleeOperandsMutable();
-}
-
-CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() {
- return getCallee();
-}
-
-void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) {
- setCalleeAttr(callee.get<SymbolRefAttr>());
-}
-
-Operation::operand_range TestCallOnDeviceOp::getArgOperands() {
- return getForwardedOperands();
-}
-
-MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
- return getForwardedOperandsMutable();
-}
-
-void TestStoreWithARegion::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
- if (point.isParent())
- regions.emplace_back(&getBody(), getBody().front().getArguments());
- else
- regions.emplace_back();
-}
-
-void TestStoreWithALoopRegion::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
- // Both the operation itself and the region may be branching into the body or
- // back into the operation itself. It is possible for the operation not to
- // enter the body.
- regions.emplace_back(
- RegionSuccessor(&getBody(), getBody().front().getArguments()));
- regions.emplace_back();
-}
-
-LogicalResult
-TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader,
- ::mlir::OperationState &state) {
- auto &prop = state.getOrAddProperties<Properties>();
- if (::mlir::failed(reader.readAttribute(prop.dims)))
- return ::mlir::failure();
-
- // Check if we have a version. If not, assume we are parsing the current
- // version.
- auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
- if (succeeded(maybeVersion)) {
- // If version is less than 2.0, there is no additional attribute to parse.
- // We can materialize missing properties post parsing before verification.
- const auto *version =
- reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
- if ((version->major_ < 2)) {
- return success();
- }
- }
-
- if (::mlir::failed(reader.readAttribute(prop.modifier)))
- return ::mlir::failure();
- return ::mlir::success();
-}
-
-void TestVersionedOpA::writeProperties(::mlir::DialectBytecodeWriter &writer) {
- auto &prop = getProperties();
- writer.writeAttribute(prop.dims);
-
- auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
- if (succeeded(maybeVersion)) {
- // If version is less than 2.0, there is no additional attribute to write.
- const auto *version =
- reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
- if ((version->major_ < 2)) {
- llvm::outs() << "downgrading op properties...\n";
- return;
- }
- }
- writer.writeAttribute(prop.modifier);
-}
-
-::mlir::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
- ::mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) {
- uint64_t value1, value2 = 0;
- if (failed(reader.readVarInt(value1)))
- return failure();
-
- // Check if we have a version. If not, assume we are parsing the current
- // version.
- auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
- bool needToParseAnotherInt = true;
- if (succeeded(maybeVersion)) {
- // If version is less than 2.0, there is no additional attribute to parse.
- // We can materialize missing properties post parsing before verification.
- const auto *version =
- reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
- if ((version->major_ < 2))
- needToParseAnotherInt = false;
- }
- if (needToParseAnotherInt && failed(reader.readVarInt(value2)))
- return failure();
-
- prop.value1 = value1;
- prop.value2 = value2;
- return success();
-}
-
-void TestOpWithVersionedProperties::writeToMlirBytecode(
- ::mlir::DialectBytecodeWriter &writer,
- const test::VersionedProperties &prop) {
- writer.writeVarInt(prop.value1);
- writer.writeVarInt(prop.value2);
-}
-
-#include "TestOpEnums.cpp.inc"
-#include "TestOpInterfaces.cpp.inc"
-#include "TestTypeInterfaces.cpp.inc"
-
-#define GET_OP_CLASSES
-#include "TestOps.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index d5b2fbeafc4104..c05e15fc642a25 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -43,19 +43,18 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/SetVector.h"
#include <memory>
namespace mlir {
-class DLTIDialect;
class RewritePatternSet;
-} // namespace mlir
+} // end namespace mlir
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
-#include "TestOpInterfaces.h.inc"
#include "TestOpsDialect.h.inc"
namespace test {
@@ -75,49 +74,8 @@ struct TestDialectVersion : public mlir::DialectVersion {
uint32_t minor_ = 0;
};
-// Define some classes to exercise the Properties feature.
-
-/// Properties storage holding an integer and a label shared between copies.
-/// Equality compares `value` and the pointed-to label strings, not pointer
-/// identity.
-struct PropertiesWithCustomPrint {
- /// A shared_ptr to a const object is safe: it is equivalent to a value-based
- /// member. Here the label will be deallocated when the last operation
- /// referring to it is destroyed. However there is no pool-allocation: this is
- /// offloaded to the client.
- std::shared_ptr<const std::string> label;
- int value;
- // NOTE(review): when the values match, both `label` pointers are
- // dereferenced, so they must be non-null at that point.
- bool operator==(const PropertiesWithCustomPrint &rhs) const {
- return value == rhs.value && *label == *rhs.label;
- }
-};
-/// Property storage wrapping a single string. Two instances compare equal when
-/// their `content` strings are equal.
-class MyPropStruct {
-public:
- std::string content;
- // These three methods are invoked through the `MyStructProperty` wrapper
- // defined in TestOps.td.
- mlir::Attribute asAttribute(mlir::MLIRContext *ctx) const;
- static mlir::LogicalResult
- setFromAttr(MyPropStruct &prop, mlir::Attribute attr,
- llvm::function_ref<mlir::InFlightDiagnostic()> emitError);
- llvm::hash_code hash() const;
- bool operator==(const MyPropStruct &rhs) const {
- return content == rhs.content;
- }
-};
-/// Two-integer properties used to exercise versioned bytecode round-tripping.
-struct VersionedProperties {
- // For the sake of testing, assume that this object was associated to version
- // 1.2 of the test dialect when having only one int value. In the current
- // version 2.0, the property has two values. We also assume that the class is
- // upgrade-able if value2 = 0.
- int value1;
- // Only present in the 2.0 layout (see the comment above).
- int value2;
- bool operator==(const VersionedProperties &rhs) const {
- return value1 == rhs.value1 && value2 == rhs.value2;
- }
-};
} // namespace test
-#define GET_OP_CLASSES
-#include "TestOps.h.inc"
-
namespace test {
// Op deliberately defined in C++ code rather than ODS to test that C++
@@ -138,6 +96,10 @@ class ManualCppOpWithFold
void registerTestDialect(::mlir::DialectRegistry ®istry);
void populateTestReductionPatterns(::mlir::RewritePatternSet &patterns);
+void testSideEffectOpGetEffect(
+ mlir::Operation *op,
+ llvm::SmallVectorImpl<
+ mlir::SideEffects::EffectInstance<mlir::TestEffects::Effect>> &effects);
} // namespace test
#endif // MLIR_TESTDIALECT_H
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 66578b246afab1..a3a8913d5964c6 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/InliningUtils.h"
diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp
new file mode 100644
index 00000000000000..6e75dd39322810
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp
@@ -0,0 +1,377 @@
+//===- TestFormatUtils.cpp - MLIR Test Dialect Assembly Format Utilities --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestFormatUtils.h"
+#include "mlir/IR/Builders.h"
+
+using namespace mlir;
+using namespace test;
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOperands
+//===----------------------------------------------------------------------===//
+
+/// Parses `operand [`,` optOperand] `->` `(` varOperands `)`. The optional
+/// operand is only materialized when a comma follows the first operand.
+ParseResult test::parseCustomDirectiveOperands(
+    OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
+  if (parser.parseOperand(operand))
+    return failure();
+
+  // A second, optional operand is introduced by a comma.
+  const bool hasOptional = succeeded(parser.parseOptionalComma());
+  if (hasOptional) {
+    optOperand.emplace();
+    if (parser.parseOperand(*optOperand))
+      return failure();
+  }
+
+  // The variadic operands follow an arrow, wrapped in parentheses.
+  return failure(parser.parseArrow() || parser.parseLParen() ||
+                 parser.parseOperandList(varOperands) ||
+                 parser.parseRParen());
+}
+
+/// Prints `operand [, optOperand] -> (varOperands)`, the inverse of
+/// parseCustomDirectiveOperands.
+void test::printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
+                                        Value operand, Value optOperand,
+                                        OperandRange varOperands) {
+  printer << operand;
+  if (optOperand) {
+    printer << ", ";
+    printer << optOperand;
+  }
+  printer << " -> (";
+  printer << varOperands;
+  printer << ")";
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveResults
+//===----------------------------------------------------------------------===//
+
+/// Parses `:` type [`,` optType] `->` `(` types `)`.
+ParseResult
+test::parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
+                                  Type &optOperandType,
+                                  SmallVectorImpl<Type> &varOperandTypes) {
+  // Leading `: type`.
+  if (parser.parseColon() || parser.parseType(operandType))
+    return failure();
+
+  // Optional `, type`.
+  if (succeeded(parser.parseOptionalComma()) &&
+      parser.parseType(optOperandType))
+    return failure();
+
+  // Trailing `-> (types...)`.
+  return failure(parser.parseArrow() || parser.parseLParen() ||
+                 parser.parseTypeList(varOperandTypes) ||
+                 parser.parseRParen());
+}
+
+/// Prints ` : type [, optType] -> (types)`, the inverse of
+/// parseCustomDirectiveResults.
+void test::printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
+                                       Type operandType, Type optOperandType,
+                                       TypeRange varOperandTypes) {
+  printer << " : ";
+  printer << operandType;
+  if (optOperandType) {
+    printer << ", ";
+    printer << optOperandType;
+  }
+  printer << " -> (";
+  printer << varOperandTypes;
+  printer << ")";
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveWithTypeRefs
+//===----------------------------------------------------------------------===//
+
+/// Parses `type_refs_capture` followed by a results directive and succeeds only
+/// if the re-parsed types match the type references passed in.
+ParseResult test::parseCustomDirectiveWithTypeRefs(
+    OpAsmParser &parser, Type operandType, Type optOperandType,
+    const SmallVectorImpl<Type> &varOperandTypes) {
+  Type reparsedType, reparsedOptType;
+  SmallVector<Type, 1> reparsedVarTypes;
+  if (parser.parseKeyword("type_refs_capture") ||
+      parseCustomDirectiveResults(parser, reparsedType, reparsedOptType,
+                                  reparsedVarTypes))
+    return failure();
+
+  // Every re-parsed type must equal its captured counterpart.
+  const bool allMatch = operandType == reparsedType &&
+                        optOperandType == reparsedOptType &&
+                        varOperandTypes == reparsedVarTypes;
+  return success(allMatch);
+}
+
+/// Prints the `type_refs_capture` keyword followed by the results directive.
+void test::printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
+                                            Operation *op, Type operandType,
+                                            Type optOperandType,
+                                            TypeRange varOperandTypes) {
+  printer << " type_refs_capture ";
+  // Delegate to the plain results printer so both forms stay in sync.
+  printCustomDirectiveResults(printer, op, operandType, optOperandType,
+                              varOperandTypes);
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOperandsAndTypes
+//===----------------------------------------------------------------------===//
+
+/// Parses an operands directive immediately followed by a results directive.
+ParseResult test::parseCustomDirectiveOperandsAndTypes(
+    OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
+    Type &operandType, Type &optOperandType,
+    SmallVectorImpl<Type> &varOperandTypes) {
+  if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands))
+    return failure();
+  return parseCustomDirectiveResults(parser, operandType, optOperandType,
+                                     varOperandTypes);
+}
+
+/// Prints operands and then their types, mirroring the parser's order.
+void test::printCustomDirectiveOperandsAndTypes(
+    OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
+    OperandRange varOperands, Type operandType, Type optOperandType,
+    TypeRange varOperandTypes) {
+  printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
+  printCustomDirectiveResults(printer, op, operandType, optOperandType,
+                              varOperandTypes);
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveRegions
+//===----------------------------------------------------------------------===//
+
+/// Parses a region, optionally followed by `,` and exactly one more region,
+/// which is appended to `varRegions`.
+ParseResult test::parseCustomDirectiveRegions(
+    OpAsmParser &parser, Region &region,
+    SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
+  if (parser.parseRegion(region))
+    return failure();
+  // No comma: there is no trailing variadic region.
+  if (failed(parser.parseOptionalComma()))
+    return success();
+  std::unique_ptr<Region> varRegion = std::make_unique<Region>();
+  if (parser.parseRegion(*varRegion))
+    return failure();
+  varRegions.emplace_back(std::move(varRegion));
+  return success();
+}
+
+/// Prints `region`, then `, ` followed by every region in `varRegions` (printed
+/// back-to-back without separators) when that list is non-empty.
+void test::printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
+                                       Region &region,
+                                       MutableArrayRef<Region> varRegions) {
+  printer.printRegion(region);
+  if (!varRegions.empty()) {
+    printer << ", ";
+    // Distinct name so the loop variable does not shadow the `region` param.
+    for (Region &varRegion : varRegions)
+      printer.printRegion(varRegion);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSuccessors
+//===----------------------------------------------------------------------===//
+
+/// Parses a successor, optionally followed by `,` and a second successor.
+/// The second successor is appended to `varSuccessors` twice (the printer
+/// only emits the first entry of that list).
+ParseResult
+test::parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
+                                     SmallVectorImpl<Block *> &varSuccessors) {
+  if (parser.parseSuccessor(successor))
+    return failure();
+  if (succeeded(parser.parseOptionalComma())) {
+    Block *trailing = nullptr;
+    if (parser.parseSuccessor(trailing))
+      return failure();
+    varSuccessors.append(2, trailing);
+  }
+  return success();
+}
+
+/// Prints the main successor and, if present, only the first variadic one.
+void test::printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
+                                          Block *successor,
+                                          SuccessorRange varSuccessors) {
+  printer << successor;
+  if (varSuccessors.empty())
+    return;
+  printer << ", " << varSuccessors.front();
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttributes
+//===----------------------------------------------------------------------===//
+
+/// Parses an integer attribute, optionally followed by `,` and a second one.
+ParseResult test::parseCustomDirectiveAttributes(OpAsmParser &parser,
+                                                 IntegerAttr &attr,
+                                                 IntegerAttr &optAttr) {
+  if (parser.parseAttribute(attr))
+    return failure();
+  // Without a comma, the optional attribute is simply absent.
+  if (failed(parser.parseOptionalComma()))
+    return success();
+  return parser.parseAttribute(optAttr);
+}
+
+/// Prints `attribute [, optAttribute]`.
+void test::printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
+                                          Attribute attribute,
+                                          Attribute optAttribute) {
+  printer << attribute;
+  if (!optAttribute)
+    return;
+  printer << ", " << optAttribute;
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttrDict
+//===----------------------------------------------------------------------===//
+
+/// Parses an optional attribute dictionary directly into `attrs`.
+ParseResult test::parseCustomDirectiveAttrDict(OpAsmParser &parser,
+                                               NamedAttrList &attrs) {
+  ParseResult dictResult = parser.parseOptionalAttrDict(attrs);
+  return dictResult;
+}
+
+/// Prints the entries of `attrs` as an optional attribute dictionary.
+void test::printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
+                                        DictionaryAttr attrs) {
+  auto entries = attrs.getValue();
+  printer.printOptionalAttrDict(entries);
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalOperandRef
+//===----------------------------------------------------------------------===//
+
+/// Parses an integer count and checks it against the already-parsed optional
+/// operand: a count of 0 requires the operand to be absent, any other count
+/// requires it to be present. A mismatch fails without emitting a diagnostic.
+ParseResult test::parseCustomDirectiveOptionalOperandRef(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
+  int64_t operandCount = 0;
+  if (parser.parseInteger(operandCount))
+    return failure();
+  const bool hasOperand = optOperand.has_value();
+  const bool wantsOperand = operandCount != 0;
+  return success(hasOperand == wantsOperand);
+}
+
+/// Prints `1` when the optional operand is present and `0` otherwise.
+void test::printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
+                                                  Operation *op,
+                                                  Value optOperand) {
+  if (optOperand)
+    printer << "1";
+  else
+    printer << "0";
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalOperand
+//===----------------------------------------------------------------------===//
+
+/// Parses an optional parenthesized operand: `(` operand `)`.
+ParseResult test::parseCustomOptionalOperand(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
+  // No leading `(` means there is no operand; that is not an error.
+  if (failed(parser.parseOptionalLParen()))
+    return success();
+  optOperand.emplace();
+  return failure(parser.parseOperand(*optOperand) || parser.parseRParen());
+}
+
+/// Prints `(operand) ` when the optional operand is present; nothing otherwise.
+void test::printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
+                                      Value optOperand) {
+  if (!optOperand)
+    return;
+  printer << "(" << optOperand << ") ";
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSwitchCases
+//===----------------------------------------------------------------------===//
+
+/// Parses zero or more `case <integer> <region>` entries. Each case value is
+/// collected into `cases` and each region is appended to `caseRegions`.
+ParseResult
+test::parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
+                       SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
+  SmallVector<int64_t> caseValues;
+  while (succeeded(p.parseOptionalKeyword("case"))) {
+    int64_t value;
+    // The region is allocated before its value is parsed, so it is already
+    // owned by `caseRegions` if parsing fails part-way.
+    Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
+    if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
+      return failure();
+    caseValues.push_back(value);
+  }
+  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
+  return success();
+}
+
+/// Prints each case on its own line as `case <value> <region>`, omitting the
+/// entry block arguments. Pairs beyond the shorter of the two lists are dropped.
+void test::printSwitchCases(OpAsmPrinter &p, Operation *op,
+                            DenseI64ArrayAttr cases, RegionRange caseRegions) {
+  ArrayRef<int64_t> caseValues = cases.asArrayRef();
+  for (auto [caseValue, caseRegion] : llvm::zip(caseValues, caseRegions)) {
+    p.printNewline();
+    p << "case " << caseValue << ' ';
+    p.printRegion(*caseRegion, /*printEntryBlockArgs=*/false);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// CustomUsingPropertyInCustom
+//===----------------------------------------------------------------------===//
+
+/// Parses `[` int `,` int `,` int `]` into value[0..2].
+/// Returns true on failure, false on success.
+bool test::parseUsingPropertyInCustom(OpAsmParser &parser, int64_t value[3]) {
+  if (parser.parseLSquare())
+    return true;
+  for (int idx = 0; idx < 3; ++idx) {
+    // Elements after the first are comma-separated.
+    if (idx != 0 && parser.parseComma())
+      return true;
+    if (parser.parseInteger(value[idx]))
+      return true;
+  }
+  return static_cast<bool>(parser.parseRSquare());
+}
+
+/// Prints the values as a bracketed list, e.g. `[1, 2, 3]`.
+void test::printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op,
+                                      ArrayRef<int64_t> value) {
+  printer << '[';
+  printer << value;
+  printer << ']';
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveIntProperty
+//===----------------------------------------------------------------------===//
+
+/// Parses a single integer into `value`. Returns true on failure.
+bool test::parseIntProperty(OpAsmParser &parser, int64_t &value) {
+  ParseResult intResult = parser.parseInteger(value);
+  return failed(intResult);
+}
+
+/// Prints the integer property as-is.
+void test::printIntProperty(OpAsmPrinter &printer, Operation *op,
+                            int64_t value) {
+  printer << value;
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSumProperty
+//===----------------------------------------------------------------------===//
+
+/// Parses `second = sum` and checks that `sum == second + first`, emitting a
+/// diagnostic at the start of the directive otherwise. Returns true on failure.
+bool test::parseSumProperty(OpAsmParser &parser, int64_t &second,
+                            int64_t first) {
+  auto startLoc = parser.getCurrentLocation();
+  int64_t total = 0;
+  if (parser.parseInteger(second) || parser.parseEqual() ||
+      parser.parseInteger(total))
+    return true;
+  const bool consistent = total == second + first;
+  if (!consistent)
+    parser.emitError(startLoc, "Expected sum to equal first + second");
+  return !consistent;
+}
+
+/// Prints `second = <second + first>`, the inverse of parseSumProperty.
+void test::printSumProperty(OpAsmPrinter &printer, Operation *op,
+                            int64_t second, int64_t first) {
+  const int64_t total = second + first;
+  printer << second << " = " << total;
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalCustomParser
+//===----------------------------------------------------------------------===//
+
+/// Parses `foo <integer-attr>`. Yields an empty (valueless) result when the
+/// `foo` keyword is absent; otherwise the result of parsing the attribute.
+OptionalParseResult test::parseOptionalCustomParser(AsmParser &p,
+                                                    IntegerAttr &result) {
+  if (failed(p.parseOptionalKeyword("foo")))
+    return {};
+  return p.parseAttribute(result);
+}
+
+/// Prints `foo ` followed by the attribute, matching parseOptionalCustomParser.
+void test::printOptionalCustomParser(AsmPrinter &p, Operation *,
+                                     IntegerAttr result) {
+  p << "foo ";
+  p.printAttribute(result);
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttrElideType
+//===----------------------------------------------------------------------===//
+
+/// Parses an attribute whose type is supplied by `type` rather than spelled out.
+ParseResult test::parseAttrElideType(AsmParser &parser, TypeAttr type,
+                                     Attribute &attr) {
+  Type expectedType = type.getValue();
+  return parser.parseAttribute(attr, expectedType);
+}
+
+/// Prints the attribute without its type; the type is recovered from `type`
+/// when parsing.
+void test::printAttrElideType(AsmPrinter &printer, Operation *op, TypeAttr type,
+                              Attribute attr) {
+  printer.printAttributeWithoutType(attr);
+}
diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.h b/mlir/test/lib/Dialect/Test/TestFormatUtils.h
new file mode 100644
index 00000000000000..7e9cd834278e34
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.h
@@ -0,0 +1,211 @@
+//===- TestFormatUtils.h - MLIR Test Dialect Assembly Format Utilities ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTFORMATUTILS_H
+#define MLIR_TESTFORMATUTILS_H
+
+#include "mlir/IR/OpImplementation.h"
+
+namespace test {
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOperands
+//===----------------------------------------------------------------------===//
+
+/// Parses `operand [`,` optOperand] `->` `(` varOperands `)`. `optOperand` is
+/// only set when a comma follows the first operand.
+mlir::ParseResult parseCustomDirectiveOperands(
+    mlir::OpAsmParser &parser, mlir::OpAsmParser::UnresolvedOperand &operand,
+    std::optional<mlir::OpAsmParser::UnresolvedOperand> &optOperand,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &varOperands);
+
+/// Prints the syntax accepted by parseCustomDirectiveOperands.
+void printCustomDirectiveOperands(mlir::OpAsmPrinter &printer,
+                                  mlir::Operation *, mlir::Value operand,
+                                  mlir::Value optOperand,
+                                  mlir::OperandRange varOperands);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveResults
+//===----------------------------------------------------------------------===//
+
+/// Parses `:` type [`,` optType] `->` `(` types `)`.
+mlir::ParseResult
+parseCustomDirectiveResults(mlir::OpAsmParser &parser, mlir::Type &operandType,
+                            mlir::Type &optOperandType,
+                            llvm::SmallVectorImpl<mlir::Type> &varOperandTypes);
+
+/// Prints the syntax accepted by parseCustomDirectiveResults.
+void printCustomDirectiveResults(mlir::OpAsmPrinter &printer, mlir::Operation *,
+                                 mlir::Type operandType,
+                                 mlir::Type optOperandType,
+                                 mlir::TypeRange varOperandTypes);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveWithTypeRefs
+//===----------------------------------------------------------------------===//
+
+/// Parses `type_refs_capture` followed by a results directive and fails unless
+/// the parsed types equal the type references passed in.
+mlir::ParseResult parseCustomDirectiveWithTypeRefs(
+    mlir::OpAsmParser &parser, mlir::Type operandType,
+    mlir::Type optOperandType,
+    const llvm::SmallVectorImpl<mlir::Type> &varOperandTypes);
+
+/// Prints `type_refs_capture` followed by the results directive.
+void printCustomDirectiveWithTypeRefs(mlir::OpAsmPrinter &printer,
+                                      mlir::Operation *op,
+                                      mlir::Type operandType,
+                                      mlir::Type optOperandType,
+                                      mlir::TypeRange varOperandTypes);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOperandsAndTypes
+//===----------------------------------------------------------------------===//
+
+/// Parses an operands directive immediately followed by a results directive.
+mlir::ParseResult parseCustomDirectiveOperandsAndTypes(
+    mlir::OpAsmParser &parser, mlir::OpAsmParser::UnresolvedOperand &operand,
+    std::optional<mlir::OpAsmParser::UnresolvedOperand> &optOperand,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &varOperands,
+    mlir::Type &operandType, mlir::Type &optOperandType,
+    llvm::SmallVectorImpl<mlir::Type> &varOperandTypes);
+
+/// Prints operands and then their types, mirroring the parser's order.
+void printCustomDirectiveOperandsAndTypes(
+    mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::Value operand,
+    mlir::Value optOperand, mlir::OperandRange varOperands,
+    mlir::Type operandType, mlir::Type optOperandType,
+    mlir::TypeRange varOperandTypes);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveRegions
+//===----------------------------------------------------------------------===//
+
+/// Parses a region, optionally followed by `,` and one more region that is
+/// appended to `varRegions`.
+mlir::ParseResult parseCustomDirectiveRegions(
+    mlir::OpAsmParser &parser, mlir::Region &region,
+    llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &varRegions);
+
+/// Prints `region`, then `, ` and the regions of `varRegions` if non-empty.
+void printCustomDirectiveRegions(
+    mlir::OpAsmPrinter &printer, mlir::Operation *, mlir::Region &region,
+    llvm::MutableArrayRef<mlir::Region> varRegions);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSuccessors
+//===----------------------------------------------------------------------===//
+
+/// Parses a successor, optionally followed by `,` and a second successor that
+/// is appended to `varSuccessors` twice.
+mlir::ParseResult parseCustomDirectiveSuccessors(
+    mlir::OpAsmParser &parser, mlir::Block *&successor,
+    llvm::SmallVectorImpl<mlir::Block *> &varSuccessors);
+
+/// Prints `successor` and, if any, only the first of `varSuccessors`.
+void printCustomDirectiveSuccessors(mlir::OpAsmPrinter &printer,
+                                    mlir::Operation *, mlir::Block *successor,
+                                    mlir::SuccessorRange varSuccessors);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttributes
+//===----------------------------------------------------------------------===//
+
+/// Parses an integer attribute, optionally followed by `,` and a second one.
+mlir::ParseResult parseCustomDirectiveAttributes(mlir::OpAsmParser &parser,
+                                                 mlir::IntegerAttr &attr,
+                                                 mlir::IntegerAttr &optAttr);
+
+/// Prints `attribute [, optAttribute]`.
+void printCustomDirectiveAttributes(mlir::OpAsmPrinter &printer,
+                                    mlir::Operation *,
+                                    mlir::Attribute attribute,
+                                    mlir::Attribute optAttribute);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttrDict
+//===----------------------------------------------------------------------===//
+
+/// Parses an optional attribute dictionary into `attrs`.
+mlir::ParseResult parseCustomDirectiveAttrDict(mlir::OpAsmParser &parser,
+                                               mlir::NamedAttrList &attrs);
+
+/// Prints the entries of `attrs` as an optional attribute dictionary.
+void printCustomDirectiveAttrDict(mlir::OpAsmPrinter &printer,
+                                  mlir::Operation *op,
+                                  mlir::DictionaryAttr attrs);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalOperandRef
+//===----------------------------------------------------------------------===//
+
+/// Parses an integer that must be 0 when `optOperand` is absent and nonzero
+/// when it is present.
+mlir::ParseResult parseCustomDirectiveOptionalOperandRef(
+    mlir::OpAsmParser &parser,
+    std::optional<mlir::OpAsmParser::UnresolvedOperand> &optOperand);
+
+/// Prints `1` if `optOperand` is present, `0` otherwise.
+void printCustomDirectiveOptionalOperandRef(mlir::OpAsmPrinter &printer,
+                                            mlir::Operation *op,
+                                            mlir::Value optOperand);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalOperand
+//===----------------------------------------------------------------------===//
+
+/// Parses an optional `(` operand `)`.
+mlir::ParseResult parseCustomOptionalOperand(
+    mlir::OpAsmParser &parser,
+    std::optional<mlir::OpAsmParser::UnresolvedOperand> &optOperand);
+
+/// Prints `(operand) ` when `optOperand` is present.
+void printCustomOptionalOperand(mlir::OpAsmPrinter &printer, mlir::Operation *,
+                                mlir::Value optOperand);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSwitchCases
+//===----------------------------------------------------------------------===//
+
+/// Parses zero or more `case <integer> <region>` entries.
+mlir::ParseResult parseSwitchCases(
+    mlir::OpAsmParser &p, mlir::DenseI64ArrayAttr &cases,
+    llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &caseRegions);
+
+/// Prints each case on a new line as `case <value> <region>`.
+void printSwitchCases(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                      mlir::DenseI64ArrayAttr cases,
+                      mlir::RegionRange caseRegions);
+
+//===----------------------------------------------------------------------===//
+// CustomUsingPropertyInCustom
+//===----------------------------------------------------------------------===//
+
+/// Parses `[` int `,` int `,` int `]` into value[0..2]. Returns true on failure.
+bool parseUsingPropertyInCustom(mlir::OpAsmParser &parser, int64_t value[3]);
+
+/// Prints the values as a bracketed list.
+void printUsingPropertyInCustom(mlir::OpAsmPrinter &printer,
+                                mlir::Operation *op,
+                                llvm::ArrayRef<int64_t> value);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveIntProperty
+//===----------------------------------------------------------------------===//
+
+/// Parses a single integer. Returns true on failure.
+bool parseIntProperty(mlir::OpAsmParser &parser, int64_t &value);
+
+/// Prints the integer as-is.
+void printIntProperty(mlir::OpAsmPrinter &printer, mlir::Operation *op,
+                      int64_t value);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSumProperty
+//===----------------------------------------------------------------------===//
+
+/// Parses `second = sum` and checks `sum == first + second`. Returns true on
+/// failure.
+bool parseSumProperty(mlir::OpAsmParser &parser, int64_t &second,
+                      int64_t first);
+
+/// Prints `second = <second + first>`.
+void printSumProperty(mlir::OpAsmPrinter &printer, mlir::Operation *op,
+                      int64_t second, int64_t first);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalCustomParser
+//===----------------------------------------------------------------------===//
+
+/// Parses `foo <integer-attr>`; yields an empty result when `foo` is absent.
+mlir::OptionalParseResult parseOptionalCustomParser(mlir::AsmParser &p,
+                                                    mlir::IntegerAttr &result);
+
+/// Prints `foo ` followed by the attribute.
+void printOptionalCustomParser(mlir::AsmPrinter &p, mlir::Operation *,
+                               mlir::IntegerAttr result);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttrElideType
+//===----------------------------------------------------------------------===//
+
+/// Parses an attribute whose type is supplied by `type`.
+mlir::ParseResult parseAttrElideType(mlir::AsmParser &parser,
+                                     mlir::TypeAttr type,
+                                     mlir::Attribute &attr);
+
+/// Prints the attribute without its type.
+void printAttrElideType(mlir::AsmPrinter &printer, mlir::Operation *op,
+                        mlir::TypeAttr type, mlir::Attribute attr);
+
+} // end namespace test
+
+#endif // MLIR_TESTFORMATUTILS_H
diff --git a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
index 3673d62bea2c94..dc6413b25707e3 100644
--- a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
+++ b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestInterfaces.cpp
index 64ec82ecb24ff8..14099bb4bb16ba 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.cpp
@@ -6,3 +6,5 @@ bool mlir::TestEffects::Effect::classof(
const mlir::SideEffects::Effect *effect) {
return isa<mlir::TestEffects::Concrete>(effect);
}
+
+#include "TestOpInterfaces.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.h b/mlir/test/lib/Dialect/Test/TestInterfaces.h
index 3239584a93326d..d58d1aafbe66c2 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.h
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.h
@@ -34,4 +34,6 @@ struct Concrete : public Effect::Base<Concrete> {};
} // namespace TestEffects
} // namespace mlir
+#include "TestOpInterfaces.h.inc"
+
#endif // MLIR_TEST_LIB_DIALECT_TEST_TESTINTERFACES_H
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
new file mode 100644
index 00000000000000..7263774ca158eb
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -0,0 +1,1161 @@
+//===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "TestOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+
+using namespace mlir;
+using namespace test;
+
+//===----------------------------------------------------------------------===//
+// TestBranchOp
+//===----------------------------------------------------------------------===//
+
+/// The op has exactly one successor, and all target operands are forwarded
+/// to it.
+SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  MutableOperandRange forwarded = getTargetOperandsMutable();
+  return SuccessorOperands(forwarded);
+}
+
+//===----------------------------------------------------------------------===//
+// TestProducingBranchOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the operands forwarded to successor `index`. Successor 0 receives
+/// the `firstOperands` group and successor 1 receives `secondOperands`, so
+/// each operand group is paired with the successor at the same position (the
+/// same in-order mapping TestInternalBranchOp uses).
+/// NOTE(review): assumes the ODS successors are declared as `first, second`;
+/// confirm against TestOps.td and any lit test relying on the old mapping.
+SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
+  assert(index <= 1 && "invalid successor index");
+  if (index == 0)
+    return SuccessorOperands(getFirstOperandsMutable());
+  return SuccessorOperands(getSecondOperandsMutable());
+}
+
+//===----------------------------------------------------------------------===//
+// TestInternalBranchOp
+//===----------------------------------------------------------------------===//
+
+SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
+ assert(index <= 1 && "invalid successor index");
+ if (index == 0)
+ return SuccessorOperands(0, getSuccessOperandsMutable());
+ return SuccessorOperands(1, getErrorOperandsMutable());
+}
+
+//===----------------------------------------------------------------------===//
+// TestCallOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the `callee` attribute is present and that it resolves, from
+/// the nearest enclosing symbol table, to an op implementing
+/// FunctionOpInterface.
+LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto calleeRef = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!calleeRef)
+    return emitOpError("requires a 'callee' symbol reference attribute");
+
+  auto callee =
+      symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, calleeRef);
+  if (callee)
+    return success();
+  return emitOpError() << "'" << calleeRef.getValue()
+                       << "' does not reference a valid function";
+}
+
+//===----------------------------------------------------------------------===//
+// FoldToCallOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Rewrites a FoldToCallOp into a `func.call` of the same callee, with no
+/// results and no operands.
+struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
+  using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FoldToCallOp op,
+                                PatternRewriter &rewriter) const override {
+    auto callee = op.getCalleeAttr();
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, /*results=*/TypeRange(),
+                                              callee,
+                                              /*operands=*/ValueRange());
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the single canonicalization of this op: rewriting it into a
+/// `func.call`.
+void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<FoldToCallOpPattern>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// IsolatedRegionOp - test parsing passthrough operands
+//===----------------------------------------------------------------------===//
+
+ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  // The single operand has `index` type. Its SSA name doubles as the name of
+  // the body's entry block argument.
+  OpAsmParser::Argument regionArg;
+  regionArg.type = parser.getBuilder().getIndexType();
+  if (failed(parser.parseOperand(regionArg.ssaName)))
+    return failure();
+  if (failed(parser.resolveOperand(regionArg.ssaName, regionArg.type,
+                                   result.operands)))
+    return failure();
+
+  // Shadowing is enabled so the region argument may reuse the operand's name.
+  Region *body = result.addRegion();
+  return parser.parseRegion(*body, regionArg, /*enableNameShadowing=*/true);
+}
+
+void IsolatedRegionOp::print(OpAsmPrinter &p) {
+  // Print the operand, then make the region's entry argument print under the
+  // operand's name, and finally print the body without its block header.
+  Value input = getOperand();
+  p << ' ';
+  p.printOperand(input);
+  p.shadowRegionArgs(getRegion(), input);
+  p << ' ';
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// SSACFGRegionOp
+//===----------------------------------------------------------------------===//
+
+/// Every region of this op is an SSA CFG region, regardless of `index`.
+RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
+  (void)index;
+  return RegionKind::SSACFG;
+}
+
+//===----------------------------------------------------------------------===//
+// GraphRegionOp
+//===----------------------------------------------------------------------===//
+
+/// Every region of this op is a graph region, regardless of `index`.
+RegionKind GraphRegionOp::getRegionKind(unsigned index) {
+  (void)index;
+  return RegionKind::Graph;
+}
+
+//===----------------------------------------------------------------------===//
+// AffineScopeOp
+//===----------------------------------------------------------------------===//
+
+ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
+  // The custom syntax consists solely of the body region, parsed without any
+  // entry block arguments.
+  Region *bodyRegion = result.addRegion();
+  return parser.parseRegion(*bodyRegion, /*arguments=*/{},
+                            /*enableNameShadowing=*/false);
+}
+
+void AffineScopeOp::print(OpAsmPrinter &p) {
+  // Mirror of `parse`: a space, then the body without its entry block header.
+  p << ' ';
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// TestRemoveOpWithInnerOps
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Unconditionally erases a TestOpWithRegionPattern, including its region.
+struct TestRemoveOpWithInnerOps
+    : public OpRewritePattern<TestOpWithRegionPattern> {
+  using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
+
+  /// Assigns the pattern its debug name.
+  void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
+
+  LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
+                                PatternRewriter &rewriter) const override {
+    // There is no match condition: every instance is removed.
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// TestOpWithRegionPattern
+//===----------------------------------------------------------------------===//
+
+/// Registers the pattern that erases this op together with its nested ops.
+void TestOpWithRegionPattern::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<TestRemoveOpWithInnerOps>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpWithRegionFold
+//===----------------------------------------------------------------------===//
+
+/// Always folds to the op's operand.
+OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
+  Value forwarded = getOperand();
+  return forwarded;
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpConstant
+//===----------------------------------------------------------------------===//
+
+/// A constant folds to its `value` attribute.
+OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) {
+  return getValue();
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpWithVariadicResultsAndFolder
+//===----------------------------------------------------------------------===//
+
+/// Folds by forwarding the operands, in order, as the fold results.
+LogicalResult TestOpWithVariadicResultsAndFolder::fold(
+    FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
+  for (Value operand : getOperands())
+    results.push_back(operand);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpInPlaceFold
+//===----------------------------------------------------------------------===//
+
+/// Folds in place: when the operand is a constant integer and "attr" is not
+/// yet set, the constant is copied into "attr" and the op's own result is
+/// returned to signal an in-place update.
+OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
+  // Exercise the fact that an operation created with createOrFold should be
+  // allowed to access its parent block.
+  assert(getOperation()->getBlock() &&
+         "expected that operation is not unlinked");
+
+  // Nothing to do once the attribute has been populated.
+  if (getProperties().attr)
+    return {};
+
+  // Only report an in-place fold when the op was actually modified. A
+  // non-integer constant must not set "attr" to null while still claiming
+  // success.
+  auto intAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getOp());
+  if (!intAttr)
+    return {};
+
+  getProperties().attr = intAttr;
+  return getResult();
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithInferTypeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+/// The first two operands must share a type, and the single inferred result
+/// takes that type.
+LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
+    MLIRContext *, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Type lhsType = operands[0].getType();
+  Type rhsType = operands[1].getType();
+  if (lhsType == rhsType) {
+    inferredReturnTypes.assign({lhsType});
+    return success();
+  }
+  return emitOptionalError(location, "operand type mismatch ", lhsType, " vs ",
+                           rhsType);
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithShapedTypeInferTypeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+/// Infers a single 1-D shaped result with an i17 element type. Its only
+/// dimension is the leading dimension of the first operand (dynamic if that
+/// operand is unranked), and it carries the operand's tensor encoding, if any.
+LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
+    MLIRContext *context, std::optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes,
+    OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  auto operandType = operands.front().getType();
+  auto sval = dyn_cast<ShapedType>(operandType);
+  if (!sval)
+    return emitOptionalError(location, "only shaped type operands allowed");
+  // A rank-0 operand has no leading dimension; reading `front()` of its
+  // empty shape would be out of bounds.
+  if (sval.hasRank() && sval.getRank() == 0)
+    return emitOptionalError(location,
+                             "expected shaped type operand of non-zero rank");
+  int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
+  auto type = IntegerType::get(context, 17);
+
+  Attribute encoding;
+  if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
+    encoding = rankedTy.getEncoding();
+  inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
+  return success();
+}
+
+/// Reifies the result shape as a single value: the size of dimension 0 of the
+/// first operand.
+LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
+  Value leadingDim =
+      builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0);
+  shapes = SmallVector<Value, 1>{leadingDim};
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithResultShapeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+/// Reifies one shape per operand, visiting the operands in reverse order.
+/// Each shape is a 1-D index tensor built with `tensor.from_elements` from
+/// the `tensor.dim` of every dimension of that (ranked tensor) operand.
+LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
+  Location loc = getLoc();
+  shapes.reserve(operands.size());
+  for (Value operand : llvm::reverse(operands)) {
+    int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
+    SmallVector<Value, 4> dimSizes;
+    dimSizes.reserve(rank);
+    for (int64_t dim = 0; dim < rank; ++dim)
+      dimSizes.push_back(builder.createOrFold<tensor::DimOp>(loc, operand, dim));
+    auto shapeType = RankedTensorType::get({rank}, builder.getIndexType());
+    shapes.push_back(
+        builder.create<tensor::FromElementsOp>(loc, shapeType, dimSizes));
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithResultShapePerDimInterfaceOp
+//===----------------------------------------------------------------------===//
+
+/// Reifies per-dimension shapes for each operand, visiting the operands in
+/// reverse order. Static sizes become index attributes; dynamic sizes are
+/// materialized with `tensor.dim`.
+LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+  Location loc = getLoc();
+  shapes.reserve(getNumOperands());
+  for (Value operand : llvm::reverse(getOperands())) {
+    auto tensorType = cast<RankedTensorType>(operand.getType());
+    SmallVector<OpFoldResult, 4> dimSizes;
+    for (int64_t dim = 0, rank = tensorType.getRank(); dim < rank; ++dim) {
+      if (tensorType.isDynamicDim(dim)) {
+        dimSizes.push_back(
+            builder.createOrFold<tensor::DimOp>(loc, operand, dim));
+        continue;
+      }
+      dimSizes.push_back(builder.getIndexAttr(tensorType.getDimSize(dim)));
+    }
+    shapes.emplace_back(std::move(dimSizes));
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// SideEffectOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A side-effect resource used only by the test dialect, distinct from
+/// SideEffects::DefaultResource and reported under the name "<Test>".
+class TestResource : public SideEffects::Resource::Base<TestResource> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
+
+  StringRef getName() final { return "<Test>"; }
+};
+} // namespace
+
+/// Reports memory effects described by the optional "effects" array
+/// attribute. Each entry is a dictionary with an "effect" kind string
+/// ("allocate", "free", "read" or "write"). The optional key "test_resource"
+/// selects TestResource. The optional keys "on_result" and "on_reference"
+/// attach the effect to the op's result or to a symbol reference.
+void SideEffectOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  auto effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
+  if (!effectsAttr)
+    return;
+
+  for (Attribute entry : effectsAttr) {
+    auto desc = cast<DictionaryAttr>(entry);
+    StringRef kind = cast<StringAttr>(desc.get("effect")).getValue();
+
+    MemoryEffects::Effect *effect =
+        StringSwitch<MemoryEffects::Effect *>(kind)
+            .Case("allocate", MemoryEffects::Allocate::get())
+            .Case("free", MemoryEffects::Free::get())
+            .Case("read", MemoryEffects::Read::get())
+            .Case("write", MemoryEffects::Write::get());
+
+    SideEffects::Resource *resource = SideEffects::DefaultResource::get();
+    if (desc.get("test_resource"))
+      resource = TestResource::get();
+
+    if (desc.get("on_result")) {
+      effects.emplace_back(effect, getResult(), resource);
+      continue;
+    }
+    if (Attribute ref = desc.get("on_reference")) {
+      effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
+      continue;
+    }
+    effects.emplace_back(effect, resource);
+  }
+}
+
+/// Test-dialect effects are computed by `testSideEffectOpGetEffect`, which is
+/// defined outside this file.
+void SideEffectOp::getEffects(
+    SmallVectorImpl<TestEffects::EffectInstance> &effects) {
+  Operation *op = getOperation();
+  testSideEffectOpGetEffect(op, effects);
+}
+
+//===----------------------------------------------------------------------===//
+// StringAttrPrettyNameOp
+//===----------------------------------------------------------------------===//
+
+// This op has fancy handling of its SSA result name.
+ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
+                                          OperationState &result) {
+  // Every result is an i32.
+  for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
+    result.addTypes(parser.getBuilder().getIntegerType(32));
+
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+    return failure();
+
+  // If the attribute dictionary contains no 'names' attribute, infer it from
+  // the SSA names (if specified).
+  bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
+    return attr.getName() == "names";
+  });
+
+  // An explicit 'names' attribute wins; with no results there is nothing to
+  // infer either.
+  if (hadNames || parser.getNumResults() == 0)
+    return success();
+
+  SmallVector<StringRef, 4> names;
+  auto *context = result.getContext();
+
+  for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
+    auto resultName = parser.getResultName(i);
+    StringRef nameStr;
+    // Numeric SSA names (e.g. `%0`) carry no useful name and map to "". The
+    // byte is cast to `unsigned char` because passing a negative `char` to
+    // `isdigit` is undefined behavior.
+    if (!resultName.first.empty() &&
+        !isdigit(static_cast<unsigned char>(resultName.first[0])))
+      nameStr = resultName.first;
+
+    names.push_back(nameStr);
+  }
+
+  auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
+  result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
+  return success();
+}
+
+void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
+  // The "names" attribute is elided only when every entry equals the result
+  // name the printer actually chose. The names can disagree, e.g. when there
+  // are conflicts.
+  auto names = getNames();
+  bool namesMatch = names.size() == getNumResults();
+
+  SmallString<32> printedName;
+  for (size_t i = 0, e = getNumResults(); namesMatch && i != e; ++i) {
+    printedName.clear();
+    llvm::raw_svector_ostream os(printedName);
+    p.printOperand(getResult(i), os);
+
+    // Compare without the leading '%' sigil.
+    auto expected = dyn_cast<StringAttr>(names[i]);
+    namesMatch = expected && os.str().drop_front() == expected.getValue();
+  }
+
+  if (namesMatch)
+    p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
+  else
+    p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
+}
+
+// We set the SSA name in the asm syntax to the contents of the name
+// attribute.
+void StringAttrPrettyNameOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  // Pair names with results positionally; empty or non-string entries leave
+  // the result unnamed. `llvm::zip` stops at the shorter range, so a "names"
+  // array whose length differs from the result count (a case `print`
+  // explicitly tolerates) can no longer index past the last result.
+  for (auto [result, name] :
+       llvm::zip(getOperation()->getResults(), getNames()))
+    if (auto str = dyn_cast<StringAttr>(name))
+      if (!str.getValue().empty())
+        setNameFn(result, str.getValue());
+}
+
+//===----------------------------------------------------------------------===//
+// CustomResultsNameOp
+//===----------------------------------------------------------------------===//
+
+/// Names each result after the string at the same position in "names";
+/// empty or non-string entries leave the result unnamed.
+void CustomResultsNameOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  // `llvm::zip` stops at the shorter range, so an over-long "names" array
+  // cannot index past the last result.
+  for (auto [result, name] :
+       llvm::zip(getOperation()->getResults(), getNames()))
+    if (auto str = dyn_cast<StringAttr>(name))
+      if (!str.empty())
+        setNameFn(result, str.getValue());
+}
+
+//===----------------------------------------------------------------------===//
+// ResultTypeWithTraitOp
+//===----------------------------------------------------------------------===//
+
+/// Checks that the first result type carries TestTypeTrait.
+LogicalResult ResultTypeWithTraitOp::verify() {
+  Type resultType = (*this)->getResultTypes()[0];
+  if (!resultType.hasTrait<TypeTrait::TestTypeTrait>())
+    return emitError("result type should have trait 'TestTypeTrait'");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AttrWithTraitOp
+//===----------------------------------------------------------------------===//
+
+/// Checks that the `attr` attribute carries TestAttrTrait.
+LogicalResult AttrWithTraitOp::verify() {
+  if (!getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
+    return emitError("'attr' attribute should have trait 'TestAttrTrait'");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// RegionIfOp
+//===----------------------------------------------------------------------===//
+
+void RegionIfOp::print(OpAsmPrinter &p) {
+  // Syntax: operands `:` operand types, an optional `->` result list, then
+  // the then/else/join regions, each introduced by its keyword and printed
+  // with entry block arguments and terminators.
+  p << " ";
+  p.printOperands(getOperands());
+  p << ": " << getOperandTypes();
+  p.printArrowTypeList(getResultTypes());
+
+  auto printRegionWithPrefix = [&](const char *prefix, Region &region) {
+    p << prefix;
+    p.printRegion(region, /*printEntryBlockArgs=*/true,
+                  /*printBlockTerminators=*/true);
+  };
+  printRegionWithPrefix(" then ", getThenRegion());
+  printRegionWithPrefix(" else ", getElseRegion());
+  printRegionWithPrefix(" join ", getJoinRegion());
+}
+
+/// Parses `operands : types (-> results)? then {..} else {..} join {..}`.
+ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
+  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
+  SmallVector<Type, 2> operandTypes;
+
+  result.regions.reserve(3);
+  Region *thenRegion = result.addRegion();
+  Region *elseRegion = result.addRegion();
+  Region *joinRegion = result.addRegion();
+
+  // Capture where the operand list starts, so that operand/type count
+  // mismatches are reported there rather than after the join region.
+  auto operandsLoc = parser.getCurrentLocation();
+
+  // Parse operand, type and arrow type lists.
+  if (parser.parseOperandList(operandInfos) ||
+      parser.parseColonTypeList(operandTypes) ||
+      parser.parseArrowTypeList(result.types))
+    return failure();
+
+  // Parse all attached regions.
+  if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
+      parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
+      parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
+    return failure();
+
+  return parser.resolveOperands(operandInfos, operandTypes, operandsLoc,
+                                result.operands);
+}
+
+/// Only the `then` and `else` regions are entered from the parent op, and
+/// both receive all of the op's operands.
+OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  bool isEntryRegion = point == getThenRegion() || point == getElseRegion();
+  assert(isEntryRegion && "invalid region index");
+  (void)isEntryRegion;
+  return getOperands();
+}
+
+/// Control flow: parent -> {then, else}; then/else -> join; join -> parent
+/// results.
+void RegionIfOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  // The then and else regions are the entry regions of this op.
+  if (point.isParent()) {
+    regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
+    regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
+    return;
+  }
+
+  // The join region returns control to the parent op's results.
+  if (point == getJoinRegion()) {
+    regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  // The then/else regions always branch to the join region.
+  regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
+}
+
+/// Each of the three regions (then, else, join) is invoked at most once.
+void RegionIfOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands,
+    SmallVectorImpl<InvocationBounds> &invocationBounds) {
+  invocationBounds.clear();
+  for (int i = 0; i < 3; ++i)
+    invocationBounds.emplace_back(0, 1);
+}
+
+//===----------------------------------------------------------------------===//
+// AnyCondOp
+//===----------------------------------------------------------------------===//
+
+/// The parent op branches into the only region, and the region branches back
+/// to the parent op's results.
+void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  if (!point.isParent()) {
+    regions.emplace_back(getResults());
+    return;
+  }
+  regions.emplace_back(&getRegion());
+}
+
+/// The single region is invoked exactly once.
+void AnyCondOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands,
+    SmallVectorImpl<InvocationBounds> &invocationBounds) {
+  invocationBounds.push_back(InvocationBounds(/*lb=*/1, /*ub=*/1));
+}
+
+//===----------------------------------------------------------------------===//
+// SingleBlockImplicitTerminatorOp
+//===----------------------------------------------------------------------===//
+
+/// Compile-time checks that the implicit-terminator trait helpers
+/// (`has_implicit_terminator_t` detection and
+/// `hasSingleBlockImplicitTerminator`) both recognize
+/// SingleBlockImplicitTerminatorOp.
+static_assert(
+ llvm::is_detected<OpTrait::has_implicit_terminator_t,
+ SingleBlockImplicitTerminatorOp>::value,
+ "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
+static_assert(OpTrait::hasSingleBlockImplicitTerminator<
+ SingleBlockImplicitTerminatorOp>::value,
+ "hasSingleBlockImplicitTerminator does not match "
+ "SingleBlockImplicitTerminatorOp");
+
+//===----------------------------------------------------------------------===//
+// SingleNoTerminatorCustomAsmOp
+//===----------------------------------------------------------------------===//
+
+/// The custom syntax is just the body region, with no entry block arguments.
+ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
+                                                 OperationState &state) {
+  Region *body = state.addRegion();
+  return parser.parseRegion(*body, /*arguments=*/{},
+                            /*enableNameShadowing=*/false);
+}
+
+void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
+  // The single block has no terminator, but terminator printing is still
+  // explicitly disabled so that path gets exercised by tests.
+  constexpr bool kPrintEntryBlockArgs = false;
+  constexpr bool kPrintBlockTerminators = false;
+  printer.printRegion(getRegion(), kPrintEntryBlockArgs,
+                      kPrintBlockTerminators);
+}
+
+//===----------------------------------------------------------------------===//
+// TestVerifiersOp
+//===----------------------------------------------------------------------===//
+
+/// Checks that, by the time this verifier runs, the single-block trait holds
+/// and the operand's defining op (if any) verifies. On success it emits a
+/// remark.
+LogicalResult TestVerifiersOp::verify() {
+  if (!getRegion().hasOneBlock())
+    return emitOpError("`hasOneBlock` trait hasn't been verified");
+
+  if (Operation *producer = getInput().getDefiningOp())
+    if (failed(mlir::verify(producer)))
+      return emitOpError("operand hasn't been verified");
+
+  // Use the free function: `emitRemark(msg)` on the op would trigger an
+  // infinite verifier loop.
+  mlir::emitRemark(getLoc(), "success run of verifier");
+  return success();
+}
+
+/// Region verifier: checks the single-block trait again and that every op
+/// nested directly in the region's blocks verifies. On success it emits a
+/// remark.
+LogicalResult TestVerifiersOp::verifyRegions() {
+  if (!getRegion().hasOneBlock())
+    return emitOpError("`hasOneBlock` trait hasn't been verified");
+
+  for (Block &block : getRegion()) {
+    for (Operation &nested : block) {
+      if (failed(mlir::verify(&nested)))
+        return emitOpError("nested op hasn't been verified");
+    }
+  }
+
+  // Use the free function: `emitRemark(msg)` on the op would trigger an
+  // infinite verifier loop.
+  mlir::emitRemark(getLoc(), "success run of region verifier");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Test InferIntRangeInterface
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// TestWithBoundsOp
+
+/// The result's range is taken verbatim from the op's umin/umax/smin/smax
+/// attributes.
+void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                         SetIntRangeFn setResultRanges) {
+  ConstantIntRanges bounds(getUmin(), getUmax(), getSmin(), getSmax());
+  setResultRanges(getResult(), bounds);
+}
+
+//===----------------------------------------------------------------------===//
+// TestWithBoundsRegionOp
+
+/// Syntax: optional attribute dictionary, the untyped entry argument (always
+/// `index`), then the body region.
+ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
+                                          OperationState &result) {
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  OpAsmParser::Argument regionArg;
+  regionArg.type = parser.getBuilder().getIndexType();
+  if (failed(parser.parseArgument(regionArg)))
+    return failure();
+
+  Region *body = result.addRegion();
+  return parser.parseRegion(*body, regionArg, /*enableNameShadowing=*/false);
+}
+
+void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
+  // Mirror of `parse`: attributes, the entry argument without its type, then
+  // the body without its block header.
+  p.printOptionalAttrDict((*this)->getAttrs());
+  BlockArgument arg = getRegion().getArgument(0);
+  p << ' ';
+  p.printRegionArgument(arg, /*argAttrs=*/{}, /*omitType=*/true);
+  p << ' ';
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
+}
+
+/// The bound attributes describe the region's entry argument rather than a
+/// result of the op.
+void TestWithBoundsRegionOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+  ConstantIntRanges bounds(getUmin(), getUmax(), getSmin(), getSmax());
+  setResultRanges(getRegion().getArgument(0), bounds);
+}
+
+//===----------------------------------------------------------------------===//
+// TestIncrementOp
+
+/// Shifts every bound of the input range up by one, saturating rather than
+/// wrapping at the unsigned/signed extremes.
+void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                        SetIntRangeFn setResultRanges) {
+  const ConstantIntRanges &input = argRanges[0];
+  const APInt one(input.umin().getBitWidth(), 1);
+  APInt umin = input.umin().uadd_sat(one);
+  APInt umax = input.umax().uadd_sat(one);
+  APInt smin = input.smin().sadd_sat(one);
+  APInt smax = input.smax().sadd_sat(one);
+  setResultRanges(getResult(), ConstantIntRanges(umin, umax, smin, smax));
+}
+
+//===----------------------------------------------------------------------===//
+// TestReflectBoundsOp
+
+/// Records the inferred input range on the op as index attributes, then
+/// passes the range through unchanged to the result.
+void TestReflectBoundsOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+  const ConstantIntRanges &range = argRanges[0];
+  Builder b(getContext());
+  setUminAttr(b.getIndexAttr(range.umin().getZExtValue()));
+  setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue()));
+  setSminAttr(b.getIndexAttr(range.smin().getSExtValue()));
+  setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue()));
+  setResultRanges(getResult(), range);
+}
+
+//===----------------------------------------------------------------------===//
+// ConversionFuncOp
+//===----------------------------------------------------------------------===//
+
+/// Parses the shared function-op syntax. Variadic signatures are disallowed,
+/// so the builder callback ignores its variadic flag and trailing string
+/// parameter.
+ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
+                          ArrayRef<Type> results,
+                          function_interface_impl::VariadicFlag,
+                          std::string &) {
+    return builder.getFunctionType(argTypes, results);
+  };
+
+  StringAttr typeAttrName = getFunctionTypeAttrName(result.name);
+  StringAttr argAttrsName = getArgAttrsAttrName(result.name);
+  StringAttr resAttrsName = getResAttrsAttrName(result.name);
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, typeAttrName, buildFuncType,
+      argAttrsName, resAttrsName);
+}
+
+/// Prints the shared function-op syntax; this op is never variadic.
+void ConversionFuncOp::print(OpAsmPrinter &p) {
+  StringAttr typeAttrName = getFunctionTypeAttrName();
+  StringAttr argAttrsName = getArgAttrsAttrName();
+  StringAttr resAttrsName = getResAttrsAttrName();
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
+                                           typeAttrName, argAttrsName,
+                                           resAttrsName);
+}
+
+//===----------------------------------------------------------------------===//
+// ReifyBoundOp
+//===----------------------------------------------------------------------===//
+
+/// Maps the `type` attribute ("EQ", "LB" or "UB") to a presburger bound type.
+mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
+  using mlir::presburger::BoundType;
+  auto kind = getType();
+  if (kind == "EQ")
+    return BoundType::EQ;
+  if (kind == "LB")
+    return BoundType::LB;
+  if (kind == "UB")
+    return BoundType::UB;
+  llvm_unreachable("invalid bound type");
+}
+
+/// Verifies the variable/`dim` pairing: shaped variables need `dim`, index
+/// variables must not have it, and any other type is rejected. Also checks
+/// that `scalable` excludes `constant`, and that it is set exactly when
+/// `vscale_min`/`vscale_max` are present.
+LogicalResult ReifyBoundOp::verify() {
+  if (isa<ShapedType>(getVar().getType())) {
+    if (!getDim().has_value())
+      return emitOpError("expected 'dim' attribute for shaped type variable");
+  } else if (getVar().getType().isIndex()) {
+    if (getDim().has_value())
+      return emitOpError("unexpected 'dim' attribute for index variable");
+  } else {
+    return emitOpError("expected index-typed variable or shape type variable");
+  }
+  if (getConstant() && getScalable())
+    return emitOpError("'scalable' and 'constant' are mutually exclusive");
+  // Each diagnostic names the attribute that is actually inconsistent.
+  if (getScalable() != getVscaleMin().has_value())
+    return emitOpError("expected 'vscale_min' if and only if 'scalable'");
+  if (getScalable() != getVscaleMax().has_value())
+    return emitOpError("expected 'vscale_max' if and only if 'scalable'");
+  return success();
+}
+
+/// A shaped variable is qualified by its `dim`; an index variable stands
+/// alone.
+ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
+  if (auto dim = getDim())
+    return ValueBoundsConstraintSet::Variable(getVar(), *dim);
+  return ValueBoundsConstraintSet::Variable(getVar());
+}
+
+//===----------------------------------------------------------------------===//
+// CompareOp
+//===----------------------------------------------------------------------===//
+
+/// Maps the `cmp` attribute ("EQ", "LT", "LE", "GT" or "GE") to a
+/// ValueBoundsConstraintSet comparison operator.
+ValueBoundsConstraintSet::ComparisonOperator
+CompareOp::getComparisonOperator() {
+  using CmpOp = ValueBoundsConstraintSet::ComparisonOperator;
+  auto cmp = getCmp();
+  if (cmp == "EQ")
+    return CmpOp::EQ;
+  if (cmp == "LT")
+    return CmpOp::LT;
+  if (cmp == "LE")
+    return CmpOp::LE;
+  if (cmp == "GT")
+    return CmpOp::GT;
+  if (cmp == "GE")
+    return CmpOp::GE;
+  llvm_unreachable("invalid comparison operator");
+}
+
+/// Without `lhs_map` the LHS is the first operand. Otherwise it is the map
+/// applied to the leading operands, one per map input.
+mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
+  auto lhsMap = getLhsMap();
+  if (!lhsMap)
+    return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
+  SmallVector<Value> mapOperands(
+      getVarOperands().slice(0, lhsMap->getNumInputs()));
+  return ValueBoundsConstraintSet::Variable(*lhsMap, mapOperands);
+}
+
+/// The RHS operands start right after those consumed by the LHS (one operand,
+/// or one per `lhs_map` input). Without `rhs_map` the RHS is that single
+/// operand; otherwise it is `rhs_map` applied to the next operands.
+mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
+  auto lhsMap = getLhsMap();
+  int64_t rhsBegin = lhsMap ? lhsMap->getNumInputs() : 1;
+  auto rhsMap = getRhsMap();
+  if (!rhsMap)
+    return ValueBoundsConstraintSet::Variable(getVarOperands()[rhsBegin]);
+  SmallVector<Value> mapOperands(
+      getVarOperands().slice(rhsBegin, rhsMap->getNumInputs()));
+  return ValueBoundsConstraintSet::Variable(*rhsMap, mapOperands);
+}
+
+/// Rejects `compose` combined with explicit maps, and checks that the operand
+/// count matches what the LHS and RHS consume.
+LogicalResult CompareOp::verify() {
+  if (getCompose() && (getLhsMap() || getRhsMap()))
+    return emitOpError(
+        "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
+
+  // Each side consumes one operand, or one per input of its map.
+  auto numOperandsFor = [](auto map) -> int64_t {
+    return map ? map->getNumInputs() : 1;
+  };
+  int64_t expectedNumOperands =
+      numOperandsFor(getLhsMap()) + numOperandsFor(getRhsMap());
+  size_t actualNumOperands = getVarOperands().size();
+  if (actualNumOperands == size_t(expectedNumOperands))
+    return success();
+  return emitOpError("expected ") << expectedNumOperands
+                                  << " operands, but got " << actualNumOperands;
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpFoldWithFoldAdaptor
+//===----------------------------------------------------------------------===//
+
+/// Folds to a weighted sum of the constant integer operands: 1x `op`, 2x each
+/// `variadic` operand and 3x each operand of the `var_of_var` groups. It then
+/// adds 4 per element iterated by `adaptor.getBody()` (presumably its blocks).
+/// Non-integer or non-constant operands contribute nothing.
+OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
+  auto constantValue = [](Attribute attr) -> int64_t {
+    auto intAttr = dyn_cast_or_null<IntegerAttr>(attr);
+    return intAttr ? intAttr.getValue().getSExtValue() : 0;
+  };
+
+  int64_t sum = constantValue(adaptor.getOp());
+  for (Attribute attr : adaptor.getVariadic())
+    sum += 2 * constantValue(attr);
+  for (ArrayRef<Attribute> group : adaptor.getVarOfVar())
+    for (Attribute attr : group)
+      sum += 3 * constantValue(attr);
+
+  auto bodyBegin = adaptor.getBody().begin();
+  auto bodyEnd = adaptor.getBody().end();
+  sum += 4 * std::distance(bodyBegin, bodyEnd);
+
+  return IntegerAttr::get(getType(), sum);
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithInferTypeAdaptorInterfaceOp
+//===----------------------------------------------------------------------===//
+
+ /// Infers a single result whose type is the type of `x`. Fails with an
+ /// "operand type mismatch" diagnostic (emitted only if a location is given)
+ /// when `x` and `y` have different types.
+ LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
+     MLIRContext *, std::optional<Location> location,
+     OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
+     SmallVectorImpl<Type> &inferredReturnTypes) {
+   Type xType = adaptor.getX().getType();
+   Type yType = adaptor.getY().getType();
+   if (xType == yType) {
+     inferredReturnTypes.assign({xType});
+     return success();
+   }
+   return emitOptionalError(location, "operand type mismatch ", xType, " vs ",
+                            yType);
+ }
+
+//===----------------------------------------------------------------------===//
+// OpWithRefineTypeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+ // TODO: We should be able to only define either inferReturnType or
+ // refineReturnType, currently only refineReturnType can be omitted.
+ /// Infers the result types by delegating to refineReturnTypes, starting from
+ /// an empty result-type list.
+ LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
+     MLIRContext *context, std::optional<Location> location, ValueRange operands,
+     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+     SmallVectorImpl<Type> &returnTypes) {
+   // With an empty list, refineReturnTypes fills in the type instead of
+   // checking it against a previously known one.
+   returnTypes.clear();
+   LogicalResult refined = OpWithRefineTypeInterfaceOp::refineReturnTypes(
+       context, location, operands, attributes, properties, regions,
+       returnTypes);
+   return refined;
+ }
+
+ /// Refines (or infers) the single result type. Both operands must have the
+ /// same type, and the result gets the first operand's type. If `returnTypes`
+ /// is empty on entry, it is populated. If it already holds a non-null type
+ /// that differs from the first operand's type, refinement fails. Diagnostics
+ /// are emitted only when a location is given.
+ LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
+     MLIRContext *, std::optional<Location> location, ValueRange operands,
+     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+     SmallVectorImpl<Type> &returnTypes) {
+   // Guard the operands[0]/operands[1] accesses below, so a malformed operand
+   // list produces a diagnostic instead of an out-of-bounds access.
+   if (operands.size() < 2)
+     return emitOptionalError(location, "expected at least two operands");
+   Type firstType = operands[0].getType();
+   Type secondType = operands[1].getType();
+   if (firstType != secondType) {
+     return emitOptionalError(location, "operand type mismatch ", firstType,
+                              " vs ", secondType);
+   }
+   // TODO: Add helper to make this more concise to write.
+   if (returnTypes.empty())
+     returnTypes.resize(1, nullptr);
+   if (returnTypes[0] && returnTypes[0] != firstType)
+     return emitOptionalError(location,
+                              "required first operand and result to match");
+   returnTypes[0] = firstType;
+   return success();
+ }
+
+//===----------------------------------------------------------------------===//
+// OpWithShapedTypeInferTypeAdaptorInterfaceOp
+//===----------------------------------------------------------------------===//
+
+ /// Infers one ranked 1-D result with a 17-bit integer element type. Its only
+ /// dimension is the leading dimension of `operand1` (dynamic if `operand1`
+ /// is unranked), and the encoding of a ranked tensor operand is propagated.
+ /// Fails if `operand1` is not a shaped type, or is ranked with rank 0.
+ LogicalResult
+ OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
+     MLIRContext *context, std::optional<Location> location,
+     OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
+     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+   auto operandType = adaptor.getOperand1().getType();
+   auto sval = dyn_cast<ShapedType>(operandType);
+   if (!sval)
+     return emitOptionalError(location, "only shaped type operands allowed");
+   // A rank-0 operand has no leading dimension; front() on its empty shape
+   // would read out of bounds.
+   if (sval.hasRank() && sval.getRank() == 0)
+     return emitOptionalError(location, "expected operand of rank at least 1");
+   int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
+   auto type = IntegerType::get(context, 17);
+
+   // Propagate the encoding of a ranked tensor operand, if any.
+   Attribute encoding;
+   if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
+     encoding = rankedTy.getEncoding();
+   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
+   return success();
+ }
+
+ /// Reifies the result shape as one value: the size of dimension 0 of the
+ /// first operand, created as a tensor::DimOp (folded when possible).
+ LogicalResult
+ OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
+     OpBuilder &builder, ValueRange operands,
+     llvm::SmallVectorImpl<Value> &shapes) {
+   Value leadingDim =
+       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0);
+   shapes = SmallVector<Value, 1>{leadingDim};
+   return success();
+ }
+
+//===----------------------------------------------------------------------===//
+// TestOpWithPropertiesAndInferredType
+//===----------------------------------------------------------------------===//
+
+ /// Infers one integer result whose bit width is `lhs` plus the `rhs`
+ /// property, both read through an adaptor built from the given state.
+ LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes(
+     MLIRContext *context, std::optional<Location>, ValueRange operands,
+     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+     SmallVectorImpl<Type> &inferredReturnTypes) {
+   Adaptor adaptor(operands, attributes, properties, regions);
+   auto width = adaptor.getLhs() + adaptor.getProperties().rhs;
+   inferredReturnTypes.push_back(IntegerType::get(context, width));
+   return success();
+ }
+
+//===----------------------------------------------------------------------===//
+// LoopBlockOp
+//===----------------------------------------------------------------------===//
+
+ /// Control can always branch into the body region, forwarding to its
+ /// arguments. When branching from within the body (not from the parent op),
+ /// control can additionally exit to the op's results.
+ void LoopBlockOp::getSuccessorRegions(
+     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+   regions.emplace_back(&getBody(), getBody().getArguments());
+   if (!point.isParent())
+     regions.emplace_back((*this)->getResults());
+ }
+
+ /// Returns the `init` operands, which are forwarded when entering the body.
+ /// The body region is the only valid entry point.
+ OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+   assert(point == getBody());
+   MutableOperandRange initOperands(getInitMutable());
+   return initOperands;
+ }
+
+//===----------------------------------------------------------------------===//
+// LoopBlockTerminatorOp
+//===----------------------------------------------------------------------===//
+
+ /// When branching back to the parent op, the exit-arg operands are forwarded;
+ /// otherwise (back into the body), the next-iter-arg operands are forwarded.
+ MutableOperandRange
+ LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+   if (!point.isParent())
+     return getNextIterArgMutable();
+   return getExitArgMutable();
+ }
+
+//===----------------------------------------------------------------------===//
+// SwitchWithNoBreakOp
+//===----------------------------------------------------------------------===//
+
+void TestNoTerminatorOp::getSuccessorRegions(
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {}
+
+//===----------------------------------------------------------------------===//
+// Test InferIntRangeInterface
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
+ // Just a simple fold for testing purposes that reads an operands constant
+ // value and returns it.
+ if (!attributes.empty())
+ return attributes.front();
+ return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// Tensor/Buffer Ops
+//===----------------------------------------------------------------------===//
+
+ /// Declares two effects on the default resource: a read of the `buffer`
+ /// operand, and a write representing the buffer contents being dumped.
+ void ReadBufferOp::getEffects(
+     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+         &effects) {
+   auto *defaultResource = SideEffects::DefaultResource::get();
+   effects.emplace_back(MemoryEffects::Read::get(), getBuffer(),
+                        defaultResource);
+   effects.emplace_back(MemoryEffects::Write::get(), defaultResource);
+ }
+
+//===----------------------------------------------------------------------===//
+// Test Dataflow
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// TestCallAndStoreOp
+
+ /// The callable is the op's `callee` symbol reference.
+ CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() {
+   return CallInterfaceCallable(getCallee());
+ }
+
+ /// Stores `callee` into the `callee` attribute; `callee` must hold a
+ /// SymbolRefAttr, as required by `callee.get<SymbolRefAttr>()`.
+ void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+   SymbolRefAttr calleeSymbol = callee.get<SymbolRefAttr>();
+   setCalleeAttr(calleeSymbol);
+ }
+
+ /// The call arguments are exactly the callee operands.
+ Operation::operand_range TestCallAndStoreOp::getArgOperands() {
+   Operation::operand_range calleeArgs = getCalleeOperands();
+   return calleeArgs;
+ }
+
+ /// Mutable view of the callee operands.
+ MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
+   MutableOperandRange calleeArgs = getCalleeOperandsMutable();
+   return calleeArgs;
+ }
+
+//===----------------------------------------------------------------------===//
+// TestCallOnDeviceOp
+
+ /// The callable is the op's `callee` symbol reference.
+ CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() {
+   return CallInterfaceCallable(getCallee());
+ }
+
+ /// Stores `callee` into the `callee` attribute; `callee` must hold a
+ /// SymbolRefAttr, as required by `callee.get<SymbolRefAttr>()`.
+ void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+   SymbolRefAttr calleeSymbol = callee.get<SymbolRefAttr>();
+   setCalleeAttr(calleeSymbol);
+ }
+
+ /// The call arguments are the forwarded operands.
+ Operation::operand_range TestCallOnDeviceOp::getArgOperands() {
+   Operation::operand_range forwarded = getForwardedOperands();
+   return forwarded;
+ }
+
+ /// Mutable view of the forwarded operands.
+ MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
+   MutableOperandRange forwarded = getForwardedOperandsMutable();
+   return forwarded;
+ }
+
+//===----------------------------------------------------------------------===//
+// TestStoreWithARegion
+
+ /// From the parent op, control enters the body, forwarding to the arguments
+ /// of its first block. From inside the body, the only successor is a
+ /// default-constructed RegionSuccessor.
+ void TestStoreWithARegion::getSuccessorRegions(
+     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+   if (!point.isParent()) {
+     regions.emplace_back();
+     return;
+   }
+   Block &entryBlock = getBody().front();
+   regions.emplace_back(&getBody(), entryBlock.getArguments());
+ }
+
+//===----------------------------------------------------------------------===//
+// TestStoreWithALoopRegion
+
+ /// Independently of `point`, reports two successors: the body region (with
+ /// its first block's arguments) and a default-constructed RegionSuccessor.
+ void TestStoreWithALoopRegion::getSuccessorRegions(
+     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+   // Both the operation itself and the region may branch into the body or
+   // back into the operation itself; the body may never be entered at all.
+   Region &body = getBody();
+   regions.emplace_back(&body, body.front().getArguments());
+   regions.emplace_back();
+ }
+
+//===----------------------------------------------------------------------===//
+// TestVersionedOpA
+//===----------------------------------------------------------------------===//
+
+ /// Reads TestVersionedOpA's properties from bytecode. `dims` is always read.
+ /// `modifier` is read only when no test dialect version is recorded (the
+ /// current version is assumed) or when the recorded major version is >= 2.
+ /// For older versions, `modifier` is left for post-parse materialization.
+ LogicalResult
+ TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader,
+                                  mlir::OperationState &state) {
+   auto &prop = state.getOrAddProperties<Properties>();
+   if (mlir::failed(reader.readAttribute(prop.dims)))
+     return mlir::failure();
+
+   // Check if we have a version. If not, assume we are parsing the current
+   // version.
+   auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
+   if (succeeded(maybeVersion)) {
+     // If version is less than 2.0, there is no additional attribute to parse.
+     // We can materialize missing properties post parsing before verification.
+     // Downcast with static_cast: unlike reinterpret_cast it is the
+     // well-defined base-to-derived pointer conversion.
+     const auto *version =
+         static_cast<const TestDialectVersion *>(*maybeVersion);
+     if (version->major_ < 2)
+       return success();
+   }
+
+   if (mlir::failed(reader.readAttribute(prop.modifier)))
+     return mlir::failure();
+   return mlir::success();
+ }
+
+ /// Writes TestVersionedOpA's properties to bytecode. `dims` is always
+ /// written. `modifier` is written unless the target test dialect version has
+ /// a major version < 2; in that case "downgrading op properties..." is
+ /// printed to llvm::outs() and `modifier` is dropped.
+ void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
+   auto &prop = getProperties();
+   writer.writeAttribute(prop.dims);
+
+   auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
+   if (succeeded(maybeVersion)) {
+     // If version is less than 2.0, there is no additional attribute to write.
+     // Downcast with static_cast: unlike reinterpret_cast it is the
+     // well-defined base-to-derived pointer conversion.
+     const auto *version =
+         static_cast<const TestDialectVersion *>(*maybeVersion);
+     if (version->major_ < 2) {
+       llvm::outs() << "downgrading op properties...\n";
+       return;
+     }
+   }
+   writer.writeAttribute(prop.modifier);
+ }
+
+//===----------------------------------------------------------------------===//
+// TestOpWithVersionedProperties
+//===----------------------------------------------------------------------===//
+
+ /// Reads VersionedProperties from bytecode. `value1` is always read as a
+ /// varint. `value2` is read as a second varint unless the recorded test
+ /// dialect major version is < 2, in which case it defaults to 0.
+ mlir::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
+     mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) {
+   // Initialize both values so neither is ever read uninitialized.
+   uint64_t value1 = 0;
+   uint64_t value2 = 0;
+   if (failed(reader.readVarInt(value1)))
+     return failure();
+
+   // Check if we have a version. If not, assume we are parsing the current
+   // version.
+   auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
+   bool needToParseAnotherInt = true;
+   if (succeeded(maybeVersion)) {
+     // If version is less than 2.0, there is no additional attribute to parse.
+     // We can materialize missing properties post parsing before verification.
+     // Downcast with static_cast: unlike reinterpret_cast it is the
+     // well-defined base-to-derived pointer conversion.
+     const auto *version =
+         static_cast<const TestDialectVersion *>(*maybeVersion);
+     if (version->major_ < 2)
+       needToParseAnotherInt = false;
+   }
+   if (needToParseAnotherInt && failed(reader.readVarInt(value2)))
+     return failure();
+
+   prop.value1 = value1;
+   prop.value2 = value2;
+   return success();
+ }
+
+ /// Writes VersionedProperties as two varints: `value1`, then `value2`.
+ /// NOTE(review): `value2` is written regardless of the target dialect
+ /// version, whereas readFromMlirBytecode skips it for major versions < 2.
+ /// Confirm this asymmetry is intended when emitting older bytecode.
+ void TestOpWithVersionedProperties::writeToMlirBytecode(
+     mlir::DialectBytecodeWriter &writer,
+     const test::VersionedProperties &prop) {
+   for (int value : {prop.value1, prop.value2})
+     writer.writeVarInt(value);
+ }
diff --git a/mlir/test/lib/Dialect/Test/TestOps.cpp b/mlir/test/lib/Dialect/Test/TestOps.cpp
new file mode 100644
index 00000000000000..ce7e476be74e65
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestOps.cpp
@@ -0,0 +1,18 @@
+//===- TestOps.cpp - MLIR Test Dialect Operations ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestOps.h"
+#include "TestDialect.h"
+#include "TestFormatUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+
+using namespace mlir;
+using namespace test;
+
+#define GET_OP_CLASSES
+#include "TestOps.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
new file mode 100644
index 00000000000000..f9925855bb9db6
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -0,0 +1,149 @@
+//===- TestOps.h - MLIR Test Dialect Operations ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTOPS_H
+#define MLIR_TESTOPS_H
+
+#include "TestAttributes.h"
+#include "TestInterfaces.h"
+#include "TestTypes.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/DLTI/Traits.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace test {
+class TestDialect;
+
+//===----------------------------------------------------------------------===//
+// TestResource
+//===----------------------------------------------------------------------===//
+
+ /// A test resource for side effects, named "<Test>".
+ struct TestResource : public mlir::SideEffects::Resource::Base<TestResource> {
+   /// Returns the name of this resource.
+   llvm::StringRef getName() final { return llvm::StringRef("<Test>"); }
+ };
+
+//===----------------------------------------------------------------------===//
+// PropertiesWithCustomPrint
+//===----------------------------------------------------------------------===//
+
+ /// Properties with a custom textual form: a shared immutable string label
+ /// and an integer value.
+ struct PropertiesWithCustomPrint {
+   /// A shared_ptr to a const object is safe: it is equivalent to a value-based
+   /// member. Here the label will be deallocated when the last operation
+   /// referring to it is destroyed. However there is no pool-allocation: this is
+   /// offloaded to the client.
+   std::shared_ptr<const std::string> label;
+   int value;
+   /// Equal when the values match and the labels are either the same pointer
+   /// (including both null) or both non-null with equal contents.
+   bool operator==(const PropertiesWithCustomPrint &rhs) const {
+     if (value != rhs.value)
+       return false;
+     // Never dereference a null label (e.g. default-constructed properties).
+     if (label == rhs.label)
+       return true;
+     return label && rhs.label && *label == *rhs.label;
+   }
+ };
+
+ /// Property hooks for PropertiesWithCustomPrint; defined outside this header.
+ mlir::LogicalResult setPropertiesFromAttribute(
+     PropertiesWithCustomPrint &prop, mlir::Attribute attr,
+     llvm::function_ref<mlir::InFlightDiagnostic()> emitError);
+ mlir::DictionaryAttr
+ getPropertiesAsAttribute(mlir::MLIRContext *ctx,
+                          const PropertiesWithCustomPrint &prop);
+ llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop);
+ void customPrintProperties(mlir::OpAsmPrinter &p,
+                            const PropertiesWithCustomPrint &prop);
+ mlir::ParseResult customParseProperties(mlir::OpAsmParser &parser,
+                                         PropertiesWithCustomPrint &prop);
+
+//===----------------------------------------------------------------------===//
+// MyPropStruct
+//===----------------------------------------------------------------------===//
+
+ /// Property storage holding a single string; it provides conversion to an
+ /// attribute (asAttribute), from an attribute (setFromAttr), and a hash.
+ class MyPropStruct {
+ public:
+ /// The stored string; equality compares only this field.
+ std::string content;
+ // These three methods are invoked through the `MyStructProperty` wrapper
+ // defined in TestOps.td
+ mlir::Attribute asAttribute(mlir::MLIRContext *ctx) const;
+ static mlir::LogicalResult
+ setFromAttr(MyPropStruct &prop, mlir::Attribute attr,
+ llvm::function_ref<mlir::InFlightDiagnostic()> emitError);
+ llvm::hash_code hash() const;
+ bool operator==(const MyPropStruct &rhs) const {
+ return content == rhs.content;
+ }
+ };
+
+ /// Bytecode (de)serialization hooks for MyPropStruct; defined outside this
+ /// header.
+ /// NOTE(review): writeToMlirBytecode takes `prop` by non-const reference,
+ /// although writing presumably should not mutate it. Confirm against the
+ /// definition before changing the signature.
+ mlir::LogicalResult readFromMlirBytecode(mlir::DialectBytecodeReader &reader,
+ MyPropStruct &prop);
+ void writeToMlirBytecode(mlir::DialectBytecodeWriter &writer,
+ MyPropStruct &prop);
+
+//===----------------------------------------------------------------------===//
+// VersionedProperties
+//===----------------------------------------------------------------------===//
+
+ /// Properties used to exercise versioned bytecode encoding.
+ struct VersionedProperties {
+ // For the sake of testing, assume that this object was associated to version
+ // 1.2 of the test dialect when having only one int value. In the current
+ // version 2.0, the property has two values. We also assume that the class is
+ // upgrade-able if value2 = 0.
+ int value1;
+ int value2;
+ /// Field-wise equality on value1 and value2.
+ bool operator==(const VersionedProperties &rhs) const {
+ return value1 == rhs.value1 && value2 == rhs.value2;
+ }
+ };
+
+ /// Property hooks for VersionedProperties; defined outside this header.
+ mlir::LogicalResult setPropertiesFromAttribute(
+ VersionedProperties &prop, mlir::Attribute attr,
+ llvm::function_ref<mlir::InFlightDiagnostic()> emitError);
+ mlir::DictionaryAttr getPropertiesAsAttribute(mlir::MLIRContext *ctx,
+ const VersionedProperties &prop);
+ llvm::hash_code computeHash(const VersionedProperties &prop);
+ void customPrintProperties(mlir::OpAsmPrinter &p,
+ const VersionedProperties &prop);
+ mlir::ParseResult customParseProperties(mlir::OpAsmParser &parser,
+ VersionedProperties &prop);
+
+//===----------------------------------------------------------------------===//
+// Bytecode Support
+//===----------------------------------------------------------------------===//
+
+ /// Bytecode (de)serialization hooks for properties stored as int64_t arrays;
+ /// defined outside this header. The reader fills a caller-provided
+ /// MutableArrayRef rather than resizing storage, so the array size is
+ /// presumably known before reading. Confirm against the definition.
+ mlir::LogicalResult readFromMlirBytecode(mlir::DialectBytecodeReader &reader,
+ llvm::MutableArrayRef<int64_t> prop);
+ void writeToMlirBytecode(mlir::DialectBytecodeWriter &writer,
+ llvm::ArrayRef<int64_t> prop);
+
+} // namespace test
+
+#define GET_OP_CLASSES
+#include "TestOps.h.inc"
+
+#endif // MLIR_TESTOPS_H
diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
index 84e6a43655cacd..c376d6c73c6452 100644
--- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
@@ -8,6 +8,7 @@
#include "TestOpsSyntax.h"
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/Support/Base64.h"
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 76dc825fe44515..0c1731ba5f07c8 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "TestTypes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
diff --git a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
index fa093cafcb0dc3..57e7d658fb501f 100644
--- a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
+++ b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp
index d9b67ef95ace83..031e1062dac76d 100644
--- a/mlir/test/lib/Dialect/Test/TestTraits.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 7a195eb25a3ba1..1593b6d7d7534b 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -139,6 +139,7 @@ static void printBarString(AsmPrinter &printer, StringRef foo) {
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//
+#include "TestTypeInterfaces.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index b1b5921d8faddd..da5604944d5a3b 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -31,11 +31,11 @@ class TestAttrWithFormatAttr;
/// FieldInfo represents a field in the StructType data type. It is used as a
/// parameter in TestTypeDefs.td.
struct FieldInfo {
- ::llvm::StringRef name;
- ::mlir::Type type;
+ llvm::StringRef name;
+ mlir::Type type;
// Custom allocation called from generated constructor code: copies `name` into storage from `alloc`; `type` is copied as-is.
- FieldInfo allocateInto(::mlir::TypeStorageAllocator &alloc) const {
+ FieldInfo allocateInto(mlir::TypeStorageAllocator &alloc) const {
return FieldInfo{alloc.copyInto(name), type};
}
};
diff --git a/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp b/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
index e668224d343234..4894ad5294990a 100644
--- a/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
+++ b/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/BuiltinOps.h"
diff --git a/mlir/test/lib/IR/TestClone.cpp b/mlir/test/lib/IR/TestClone.cpp
index 7b18f219b915f4..b742b316c77126 100644
--- a/mlir/test/lib/IR/TestClone.cpp
+++ b/mlir/test/lib/IR/TestClone.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/test/lib/IR/TestSideEffects.cpp b/mlir/test/lib/IR/TestSideEffects.cpp
index 09ad1363228243..8e13dd9751398c 100644
--- a/mlir/test/lib/IR/TestSideEffects.cpp
+++ b/mlir/test/lib/IR/TestSideEffects.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp
index 0e1368f2e0ecaf..b470b15c533b57 100644
--- a/mlir/test/lib/IR/TestSymbolUses.cpp
+++ b/mlir/test/lib/IR/TestSymbolUses.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/test/lib/IR/TestTypes.cpp b/mlir/test/lib/IR/TestTypes.cpp
index 2bd63a48f77d1d..c6bce111d3ea7f 100644
--- a/mlir/test/lib/IR/TestTypes.cpp
+++ b/mlir/test/lib/IR/TestTypes.cpp
@@ -8,6 +8,7 @@
#include "TestTypes.h"
#include "TestDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
diff --git a/mlir/test/lib/IR/TestVisitorsGeneric.cpp b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
index 00148df26e3512..4556671df0ba0b 100644
--- a/mlir/test/lib/IR/TestVisitorsGeneric.cpp
+++ b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
index 477b75916f80c8..2762e254903245 100644
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/test/lib/Transforms/TestInlining.cpp b/mlir/test/lib/Transforms/TestInlining.cpp
index 9821179d05e891..223cc78dd1e21d 100644
--- a/mlir/test/lib/Transforms/TestInlining.cpp
+++ b/mlir/test/lib/Transforms/TestInlining.cpp
@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
diff --git a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
index 61e1fbcf3feaf3..82fa6cdb68d23c 100644
--- a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
+++ b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/unittests/IR/AdaptorTest.cpp b/mlir/unittests/IR/AdaptorTest.cpp
index 66ce53bbbadec9..0a5fa8d3c475c3 100644
--- a/mlir/unittests/IR/AdaptorTest.cpp
+++ b/mlir/unittests/IR/AdaptorTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
#include "../../test/lib/Dialect/Test/TestOpsSyntax.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
diff --git a/mlir/unittests/IR/IRMapping.cpp b/mlir/unittests/IR/IRMapping.cpp
index 83627975006ee8..b88009d1e3c361 100644
--- a/mlir/unittests/IR/IRMapping.cpp
+++ b/mlir/unittests/IR/IRMapping.cpp
@@ -11,6 +11,7 @@
#include "gtest/gtest.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
using namespace mlir;
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index 58049a9969e3ab..b6066dd5685dc6 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -19,6 +19,7 @@
#include "../../test/lib/Dialect/Test/TestAttributes.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
#include "../../test/lib/Dialect/Test/TestTypes.h"
#include "mlir/IR/OwningOpRef.h"
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index 5ab4d9a106231a..42196b003e7dad 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -15,6 +15,7 @@
#include "../../test/lib/Dialect/Test/TestAttributes.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
#include "../../test/lib/Dialect/Test/TestTypes.h"
using namespace mlir;
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index 9d75615b39c0c1..f94dc784458077 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -8,6 +8,7 @@
#include "mlir/IR/OperationSupport.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/BitVector.h"
diff --git a/mlir/unittests/IR/PatternMatchTest.cpp b/mlir/unittests/IR/PatternMatchTest.cpp
index 30b72618e45f0b..75d5228c82d99b 100644
--- a/mlir/unittests/IR/PatternMatchTest.cpp
+++ b/mlir/unittests/IR/PatternMatchTest.cpp
@@ -10,6 +10,7 @@
#include "gtest/gtest.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
using namespace mlir;
diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp
index 52347dcabe0381..c83ac9088114ce 100644
--- a/mlir/unittests/TableGen/OpBuildGen.cpp
+++ b/mlir/unittests/TableGen/OpBuildGen.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
More information about the Mlir-commits
mailing list