[Mlir-commits] [mlir] ee09087 - Change the printing/parsing behavior for Attributes used in declarative assembly format
Mehdi Amini
llvmlistbot at llvm.org
Tue Dec 7 18:02:46 PST 2021
Author: Mehdi Amini
Date: 2021-12-08T02:02:37Z
New Revision: ee0908703d2917d7310b71c5078fef44e8270317
URL: https://github.com/llvm/llvm-project/commit/ee0908703d2917d7310b71c5078fef44e8270317
DIFF: https://github.com/llvm/llvm-project/commit/ee0908703d2917d7310b71c5078fef44e8270317.diff
LOG: Change the printing/parsing behavior for Attributes used in declarative assembly format
In the declarative assembly format, attributes are now printed without their `#dialect.mnemonic` prefix (and types without their `!dialect.mnemonic` prefix), keeping only the `<...>` part.
Differential Revision: https://reviews.llvm.org/D113873
Added:
mlir/test/lib/Dialect/Test/TestDialect.td
Modified:
mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
mlir/include/mlir/IR/DialectImplementation.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/Parser/AsmParserImpl.h
mlir/test/Dialect/ArmSVE/roundtrip.mlir
mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
mlir/test/Dialect/Async/async-to-async-runtime.mlir
mlir/test/Dialect/Async/runtime.mlir
mlir/test/Dialect/Linalg/vectorization.mlir
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
mlir/test/lib/Dialect/Test/TestAttrDefs.td
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestTypeDefs.td
mlir/test/lib/Dialect/Test/TestTypes.h
mlir/test/mlir-tblgen/attr-or-type-format.td
mlir/test/mlir-tblgen/op-format.mlir
mlir/test/mlir-tblgen/testdialect-typedefs.mlir
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
index b3d41401b1394..b0cfd5856e06e 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
@@ -31,6 +31,7 @@ def ArmSVE_Dialect : Dialect {
vector operations, including a scalable vector type and intrinsics for
some Arm SVE instructions.
}];
+ let useDefaultTypePrinterParser = 1;
}
//===----------------------------------------------------------------------===//
@@ -66,20 +67,6 @@ def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
"Type":$elementType
);
- let printer = [{
- $_printer << "<";
- for (int64_t dim : getShape())
- $_printer << dim << 'x';
- $_printer << getElementType() << '>';
- }];
-
- let parser = [{
- VectorType vector;
- if ($_parser.parseType(vector))
- return Type();
- return get($_ctxt, vector.getShape(), vector.getElementType());
- }];
-
let extraClassDeclaration = [{
bool hasStaticShape() const {
return llvm::none_of(getShape(), ShapedType::isDynamic);
diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index 1d0c4887e25e5..70afd029c94f9 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -64,19 +64,19 @@ struct FieldParser<
AttributeT>> {
static FailureOr<AttributeT> parse(AsmParser &parser) {
AttributeT value;
- if (parser.parseAttribute(value))
+ if (parser.parseCustomAttributeWithFallback(value))
return failure();
return value;
}
};
/// Parse a type.
template <typename TypeT>
struct FieldParser<
TypeT, std::enable_if_t<std::is_base_of<Type, TypeT>::value, TypeT>> {
static FailureOr<TypeT> parse(AsmParser &parser) {
TypeT value;
- if (parser.parseType(value))
+ if (parser.parseCustomTypeWithFallback(value))
return failure();
return value;
}
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 26212d397575e..7e2326ab395ea 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2984,6 +2984,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
: DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
AttrOrTypeDef<"Type", name, traits, baseCppClass> {
+ // Make it possible to use such types as parameters for other types.
+ string cppType = dialect.cppNamespace # "::" # cppClassName;
+
// A constant builder provided when the type has no parameters.
let builderCall = !if(!empty(parameters),
"$_builder.getType<" # dialect.cppNamespace #
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index dab6e106f9512..609f6be55a2c9 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -50,6 +50,36 @@ class AsmPrinter {
virtual void printType(Type type);
virtual void printAttribute(Attribute attr);
+ /// Trait to check if `AttrOrType` provides a `print` method.
+ template <typename AttrOrType>
+ using has_print_method =
+ decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
+ template <typename AttrOrType>
+ using detect_has_print_method =
+ llvm::is_detected<has_print_method, AttrOrType>;
+
+ /// Print the provided attribute or type in the context of an operation
+ /// custom printer: unless an alias is printed, this invokes the class's
+ /// `print` method directly, skipping the `#dialect.mnemonic` prefix.
+ template <typename AttrOrType,
+ std::enable_if_t<detect_has_print_method<AttrOrType>::value>
+ *sfinae = nullptr>
+ void printStrippedAttrOrType(AttrOrType attrOrType) {
+ if (succeeded(printAlias(attrOrType)))
+ return;
+ attrOrType.print(*this);
+ }
+
+ /// SFINAE for printing the provided attribute or type in the context of an
+ /// operation custom printer in the case where its class does not define a
+ /// `print` method: falls back to the regular `operator<<` printing.
+ template <typename AttrOrType,
+ std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
+ *sfinae = nullptr>
+ void printStrippedAttrOrType(AttrOrType attrOrType) {
+ *this << attrOrType;
+ }
+
/// Print the given attribute without its type. The corresponding parser must
/// provide a valid type for the attribute.
virtual void printAttributeWithoutType(Attribute attr);
@@ -102,6 +132,14 @@ class AsmPrinter {
AsmPrinter(const AsmPrinter &) = delete;
void operator=(const AsmPrinter &) = delete;
+ /// Print the alias for the given attribute, return failure if no alias could
+ /// be printed.
+ virtual LogicalResult printAlias(Attribute attr);
+
+ /// Print the alias for the given type, return failure if no alias could
+ /// be printed.
+ virtual LogicalResult printAlias(Type type);
+
/// The internal implementation of the printer.
Impl *impl;
};
@@ -608,6 +646,13 @@ class AsmParser {
/// Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
+ /// Parse a custom attribute with the provided callback, unless the next
+ /// token is `#`, in which case the generic parser is invoked.
+ virtual ParseResult parseCustomAttributeWithFallback(
+ Attribute &result, Type type,
+ function_ref<ParseResult(Attribute &result, Type type)>
+ parseAttribute) = 0;
+
/// Parse an attribute of a specific kind and type.
template <typename AttrType>
ParseResult parseAttribute(AttrType &result, Type type = {}) {
@@ -639,9 +684,9 @@ class AsmParser {
return parseAttribute(result, Type(), attrName, attrs);
}
- /// Parse an arbitrary attribute of a given type and return it in result. This
- /// also adds the attribute to the specified attribute list with the specified
- /// name.
+ /// Parse an arbitrary attribute of a given type and populate it in `result`.
+ /// This also adds the attribute to the specified attribute list with the
+ /// specified name.
template <typename AttrType>
ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
NamedAttrList &attrs) {
@@ -661,6 +706,82 @@ class AsmParser {
return success();
}
+ /// Trait to check if `AttrType` provides a `parse` method.
+ template <typename AttrType>
+ using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
+ std::declval<Type>()));
+ template <typename AttrType>
+ using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
+
+ /// Parse a custom attribute of a given type unless the next token is `#`, in
+ /// which case the generic parser is invoked. The parsed attribute is
+ /// populated in `result` and also added to the specified attribute list with
+ /// the specified name.
+ template <typename AttrType>
+ std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
+ parseCustomAttributeWithFallback(AttrType &result, Type type,
+ StringRef attrName, NamedAttrList &attrs) {
+ llvm::SMLoc loc = getCurrentLocation();
+
+ // Parse any kind of attribute.
+ Attribute attr;
+ if (parseCustomAttributeWithFallback(
+ attr, type, [&](Attribute &result, Type type) -> ParseResult {
+ result = AttrType::parse(*this, type);
+ if (!result)
+ return failure();
+ return success();
+ }))
+ return failure();
+
+ // Check for the right kind of attribute.
+ result = attr.dyn_cast<AttrType>();
+ if (!result)
+ return emitError(loc, "invalid kind of attribute specified");
+
+ attrs.append(attrName, result);
+ return success();
+ }
+
+ /// SFINAE parsing method for attributes that don't implement a parse method.
+ template <typename AttrType>
+ std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
+ parseCustomAttributeWithFallback(AttrType &result, Type type,
+ StringRef attrName, NamedAttrList &attrs) {
+ return parseAttribute(result, type, attrName, attrs);
+ }
+
+ /// Parse a custom attribute of a given type unless the next token is `#`, in
+ /// which case the generic parser is invoked. The parsed attribute is
+ /// populated in `result`.
+ template <typename AttrType>
+ std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
+ parseCustomAttributeWithFallback(AttrType &result) {
+ llvm::SMLoc loc = getCurrentLocation();
+
+ // Parse any kind of attribute.
+ Attribute attr;
+ if (parseCustomAttributeWithFallback(
+ attr, {}, [&](Attribute &result, Type type) -> ParseResult {
+ result = AttrType::parse(*this, type);
+ return success(!!result);
+ }))
+ return failure();
+
+ // Check for the right kind of attribute.
+ result = attr.dyn_cast<AttrType>();
+ if (!result)
+ return emitError(loc, "invalid kind of attribute specified");
+ return success();
+ }
+
+ /// SFINAE parsing method for attributes that don't implement a parse method.
+ template <typename AttrType>
+ std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
+ parseCustomAttributeWithFallback(AttrType &result) {
+ return parseAttribute(result);
+ }
+
/// Parse an arbitrary optional attribute of a given type and return it in
/// result.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
@@ -740,6 +861,11 @@ class AsmParser {
/// Parse a type.
virtual ParseResult parseType(Type &result) = 0;
+ /// Parse a custom type with the provided callback, unless the next
+ /// token is `!`, in which case the generic parser is invoked.
+ virtual ParseResult parseCustomTypeWithFallback(
+ Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;
+
/// Parse an optional type.
virtual OptionalParseResult parseOptionalType(Type &result) = 0;
@@ -753,7 +879,7 @@ class AsmParser {
if (parseType(type))
return failure();
- // Check for the right kind of attribute.
+ // Check for the right kind of type.
result = type.dyn_cast<TypeT>();
if (!result)
return emitError(loc, "invalid kind of type specified");
@@ -761,6 +887,44 @@ class AsmParser {
return success();
}
+ /// Trait to check if `TypeT` provides a `parse` method.
+ template <typename TypeT>
+ using type_has_parse_method =
+ decltype(TypeT::parse(std::declval<AsmParser &>()));
+ template <typename TypeT>
+ using detect_type_has_parse_method =
+ llvm::is_detected<type_has_parse_method, TypeT>;
+
+ /// Parse a custom type of the given class unless the next token is `!`, in
+ /// which case the generic parser is invoked. The parsed type is populated in
+ /// `result`.
+ template <typename TypeT>
+ std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
+ parseCustomTypeWithFallback(TypeT &result) {
+ llvm::SMLoc loc = getCurrentLocation();
+
+ // Parse any kind of Type.
+ Type type;
+ if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
+ result = TypeT::parse(*this);
+ return success(!!result);
+ }))
+ return failure();
+
+ // Check for the right kind of Type.
+ result = type.dyn_cast<TypeT>();
+ if (!result)
+ return emitError(loc, "invalid kind of Type specified");
+ return success();
+ }
+
+ /// SFINAE parsing method for types that don't implement a parse method.
+ template <typename TypeT>
+ std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
+ parseCustomTypeWithFallback(TypeT &result) {
+ return parseType(result);
+ }
+
/// Parse a type list.
ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
do {
@@ -792,7 +956,7 @@ class AsmParser {
if (parseColonType(type))
return failure();
- // Check for the right kind of attribute.
+ // Check for the right kind of type.
result = type.dyn_cast<TypeType>();
if (!result)
return emitError(loc, "invalid kind of type specified");
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index 9f6632164b04e..482f0c351449b 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -53,21 +53,21 @@ void ArmSVEDialect::initialize() {
// ScalableVectorType
//===----------------------------------------------------------------------===//
-Type ArmSVEDialect::parseType(DialectAsmParser &parser) const {
- llvm::SMLoc typeLoc = parser.getCurrentLocation();
- {
- Type genType;
- auto parseResult = generatedTypeParser(parser, "vector", genType);
- if (parseResult.hasValue())
- return genType;
- }
- parser.emitError(typeLoc, "unknown type in ArmSVE dialect");
- return Type();
+void ScalableVectorType::print(AsmPrinter &printer) const {
+ printer << "<";
+ for (int64_t dim : getShape())
+ printer << dim << 'x';
+ printer << getElementType() << '>';
}
-void ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
- if (failed(generatedTypePrinter(type, os)))
- llvm_unreachable("unexpected 'arm_sve' type kind");
+Type ScalableVectorType::parse(AsmParser &parser) {
+ SmallVector<int64_t> dims;
+ Type eltType;
+ if (parser.parseLess() ||
+ parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
+ parser.parseType(eltType) || parser.parseGreater())
+ return {};
+ return ScalableVectorType::get(eltType.getContext(), dims, eltType);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 1b18b19df7e82..bbe7b7413f2cd 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -170,7 +170,7 @@ static constexpr const CombiningKind combiningKindsList[] = {
};
void CombiningKindAttr::print(AsmPrinter &printer) const {
- printer << "kind<";
+ printer << "<";
auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
return bitEnumContains(this->getKind(), kind);
});
@@ -215,10 +215,12 @@ Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
void VectorDialect::printAttribute(Attribute attr,
DialectAsmPrinter &os) const {
- if (auto ck = attr.dyn_cast<CombiningKindAttr>())
+ if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
+ os << "kind";
ck.print(os);
- else
- llvm_unreachable("Unknown attribute type");
+ return;
+ }
+ llvm_unreachable("Unknown attribute type");
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 6d50838cc9ead..61364758c641c 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1188,7 +1188,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
/// Ex:
/// ```
/// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
-/// %1 = vector.multi_reduction #vector.kind<add>, %0 [1]
+/// %1 = vector.multi_reduction <add>, %0 [1]
/// : vector<8x32x16xf32> to vector<8x16xf32>
/// ```
/// Gets converted to:
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 454a42e9f40b7..d3c239d9489f3 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -474,6 +474,14 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
void printAttributeWithoutType(Attribute attr) override {
printAttribute(attr);
}
+ LogicalResult printAlias(Attribute attr) override {
+ initializer.visit(attr);
+ return success();
+ }
+ LogicalResult printAlias(Type type) override {
+ initializer.visit(type);
+ return success();
+ }
/// Print the given set of attributes with names not included within
/// 'elidedAttrs'.
@@ -1252,8 +1260,16 @@ class AsmPrinter::Impl {
void printAttribute(Attribute attr,
AttrTypeElision typeElision = AttrTypeElision::Never);
+ /// Print the alias for the given attribute, return failure if no alias could
+ /// be printed.
+ LogicalResult printAlias(Attribute attr);
+
void printType(Type type);
+ /// Print the alias for the given type, return failure if no alias could
+ /// be printed.
+ LogicalResult printAlias(Type type);
+
/// Print the given location to the stream. If `allowAlias` is true, this
/// allows for the internal location to use an attribute alias.
void printLocation(LocationAttr loc, bool allowAlias = false);
@@ -1594,6 +1610,14 @@ static void printElidedElementsAttr(raw_ostream &os) {
os << R"(opaque<"_", "0xDEADBEEF">)";
}
+LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
+ return success(state && succeeded(state->getAliasState().getAlias(attr, os)));
+}
+
+LogicalResult AsmPrinter::Impl::printAlias(Type type) {
+ return success(state && succeeded(state->getAliasState().getAlias(type, os)));
+}
+
void AsmPrinter::Impl::printAttribute(Attribute attr,
AttrTypeElision typeElision) {
if (!attr) {
@@ -1602,7 +1626,7 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
}
// Try to print an alias for this attribute.
- if (state && succeeded(state->getAliasState().getAlias(attr, os)))
+ if (succeeded(printAlias(attr)))
return;
if (!isa<BuiltinDialect>(attr.getDialect()))
@@ -2104,6 +2128,16 @@ void AsmPrinter::printAttribute(Attribute attr) {
impl->printAttribute(attr);
}
+LogicalResult AsmPrinter::printAlias(Attribute attr) {
+ assert(impl && "expected AsmPrinter::printAlias to be overriden");
+ return impl->printAlias(attr);
+}
+
+LogicalResult AsmPrinter::printAlias(Type type) {
+ assert(impl && "expected AsmPrinter::printAlias to be overriden");
+ return impl->printAlias(type);
+}
+
void AsmPrinter::printAttributeWithoutType(Attribute attr) {
assert(impl &&
"expected AsmPrinter::printAttributeWithoutType to be overriden");
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 211687166b501..38c7e165f343f 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -374,6 +374,7 @@ BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
//===----------------------------------------------------------------------===//
// BoolAttr
+//===----------------------------------------------------------------------===//
bool BoolAttr::getValue() const {
auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 33ed6b60932d4..10c38a86314fa 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
@@ -633,7 +634,7 @@ bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
return true;
// Allow custom dialect attributes.
- if (!::mlir::isa<BuiltinDialect>(memorySpace.getDialect()))
+ if (!isa<BuiltinDialect>(memorySpace.getDialect()))
return true;
return false;
diff --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h
index 70039c2736a42..e8c70b7df965c 100644
--- a/mlir/lib/Parser/AsmParserImpl.h
+++ b/mlir/lib/Parser/AsmParserImpl.h
@@ -343,6 +343,29 @@ class AsmParserImpl : public BaseT {
return success(static_cast<bool>(result));
}
+ /// Parse a custom attribute with the provided callback, unless the next
+ /// token is `#`, in which case the generic parser is invoked.
+ ParseResult parseCustomAttributeWithFallback(
+ Attribute &result, Type type,
+ function_ref<ParseResult(Attribute &result, Type type)> parseAttribute)
+ override {
+ if (parser.getToken().isNot(Token::hash_identifier))
+ return parseAttribute(result, type);
+ result = parser.parseAttribute(type);
+ return success(static_cast<bool>(result));
+ }
+
+ /// Parse a custom type with the provided callback, unless the next
+ /// token is `!`, in which case the generic parser is invoked.
+ ParseResult parseCustomTypeWithFallback(
+ Type &result,
+ function_ref<ParseResult(Type &result)> parseType) override {
+ if (parser.getToken().isNot(Token::exclamation_identifier))
+ return parseType(result);
+ result = parser.parseType();
+ return success(static_cast<bool>(result));
+ }
+
OptionalParseResult parseOptionalAttribute(Attribute &result,
Type type) override {
return parser.parseOptionalAttribute(result, type);
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 0c1da2741eb8b..6f4247f30625c 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -3,7 +3,7 @@
func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
- // CHECK: arm_sve.sdot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
+ // CHECK: arm_sve.sdot {{.*}}: <16xi8> to <4xi32
%0 = arm_sve.sdot %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
@@ -12,7 +12,7 @@ func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
- // CHECK: arm_sve.smmla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
+ // CHECK: arm_sve.smmla {{.*}}: <16xi8> to <4xi3
%0 = arm_sve.smmla %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
@@ -21,7 +21,7 @@ func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
- // CHECK: arm_sve.udot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
+ // CHECK: arm_sve.udot {{.*}}: <16xi8> to <4xi32
%0 = arm_sve.udot %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
@@ -30,7 +30,7 @@ func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
- // CHECK: arm_sve.ummla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
+ // CHECK: arm_sve.ummla {{.*}}: <16xi8> to <4xi3
%0 = arm_sve.ummla %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
index 9feb788402257..c570c584cbe67 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
@@ -16,7 +16,7 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
%0 = arith.addf %arg0, %arg0 : f32
// CHECK: %[[VAL_STORAGE:.*]] = async.runtime.create : !async.value<f32>
%1 = async.runtime.create: !async.value<f32>
-// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : !async.value<f32>
+// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : <f32>
async.runtime.store %0, %1: !async.value<f32>
// CHECK: async.runtime.set_available %[[VAL_STORAGE]] : !async.value<f32>
async.runtime.set_available %1: !async.value<f32>
@@ -32,9 +32,9 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
// CHECK: ^[[BRANCH_OK]]:
-// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : !async.value<f32>
+// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : <f32>
// CHECK: %[[RETURNED:.*]] = arith.mulf %[[ARG]], %[[LOADED]] : f32
-// CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : !async.value<f32>
+// CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : <f32>
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]
@@ -84,8 +84,8 @@ func @simple_caller() -> f32 {
// CHECK: cond_br %[[IS_VALUE_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK:.*]]
// CHECK: ^[[BRANCH_VALUE_OK]]:
-// CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : !async.value<f32>
-// CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : !async.value<f32>
+// CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : <f32>
+// CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : <f32>
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]
@@ -133,7 +133,7 @@ func @double_caller() -> f32 {
// CHECK: cond_br %[[IS_VALUE_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_1:.*]]
// CHECK: ^[[BRANCH_VALUE_OK_1]]:
-// CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : !async.value<f32>
+// CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : <f32>
// CHECK: %[[RETURNED_TO_CALLER_2:.*]]:2 = call @simple_callee(%[[LOADED_1]]) : (f32) -> (!async.token, !async.value<f32>)
// CHECK: %[[SAVED_2:.*]] = async.coro.save %[[HDL]]
// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER_2]]#0, %[[HDL]]
@@ -150,8 +150,8 @@ func @double_caller() -> f32 {
// CHECK: cond_br %[[IS_VALUE_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_2:.*]]
// CHECK: ^[[BRANCH_VALUE_OK_2]]:
-// CHECK: %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : !async.value<f32>
-// CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : !async.value<f32>
+// CHECK: %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : <f32>
+// CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : <f32>
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index c024957dd7664..34532e56db7b7 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -245,7 +245,7 @@ func @execute_and_return_f32() -> f32 {
}
// CHECK: async.runtime.await %[[RET]]#1 : !async.value<f32>
- // CHECK: %[[VALUE:.*]] = async.runtime.load %[[RET]]#1 : !async.value<f32>
+ // CHECK: %[[VALUE:.*]] = async.runtime.load %[[RET]]#1 : <f32>
%0 = async.await %result : !async.value<f32>
// CHECK: return %[[VALUE]]
@@ -323,7 +323,7 @@ func @async_value_operands() {
// // Load from the async.value argument after error checking.
// CHECK: ^[[CONTINUATION:.*]]:
-// CHECK: %[[LOADED:.*]] = async.runtime.load %[[ARG]] : !async.value<f32
+// CHECK: %[[LOADED:.*]] = async.runtime.load %[[ARG]] : <f32
// CHECK: arith.addf %[[LOADED]], %[[LOADED]] : f32
// CHECK: async.runtime.set_available %[[TOKEN]]
diff --git a/mlir/test/Dialect/Async/runtime.mlir b/mlir/test/Dialect/Async/runtime.mlir
index 1daa4ea64b2fc..2841fa11d36bc 100644
--- a/mlir/test/Dialect/Async/runtime.mlir
+++ b/mlir/test/Dialect/Async/runtime.mlir
@@ -129,16 +129,16 @@ func @resume(%arg0: !async.coro.handle) {
// CHECK-LABEL: @store
func @store(%arg0: f32, %arg1: !async.value<f32>) {
- // CHECK: async.runtime.store %arg0, %arg1 : !async.value<f32>
- async.runtime.store %arg0, %arg1 : !async.value<f32>
+ // CHECK: async.runtime.store %arg0, %arg1 : <f32>
+ async.runtime.store %arg0, %arg1 : <f32>
return
}
// CHECK-LABEL: @load
func @load(%arg0: !async.value<f32>) -> f32 {
- // CHECK: %0 = async.runtime.load %arg0 : !async.value<f32>
+ // CHECK: %0 = async.runtime.load %arg0 : <f32>
// CHECK: return %0 : f32
- %0 = async.runtime.load %arg0 : !async.value<f32>
+ %0 = async.runtime.load %arg0 : <f32>
return %0 : f32
}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index b7ef524475487..934f8e3499365 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -6,7 +6,7 @@
func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584xf32>
-// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [0] : vector<1584xf32> to f32
+// CHECK: vector.multi_reduction <add>, %{{.*}} [0] : vector<1584xf32> to f32
// CHECK: arith.addf %{{.*}}, %{{.*}} : f32
linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
outs(%C: memref<f32>)
@@ -19,7 +19,7 @@ func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32
func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
-// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
+// CHECK: vector.multi_reduction <add>, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584xf32>
linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
outs(%C: memref<1584xf32>)
@@ -31,7 +31,7 @@ func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: me
// CHECK-LABEL: contraction_matmul
func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
-// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
+// CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
outs(%C: memref<1584x1584xf32>)
@@ -43,7 +43,7 @@ func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %
// CHECK-LABEL: contraction_batch_matmul
func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32>
-// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
+// CHECK: vector.multi_reduction <add>, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
linalg.batch_matmul
ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
@@ -71,7 +71,7 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
// CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
- // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
+ // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
// CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
linalg.generic #matmul_trait
@@ -105,7 +105,7 @@ func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
// CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
- // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
+ // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
// CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32>
linalg.generic #matmul_transpose_out_trait
@@ -139,7 +139,7 @@ func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<8x32x16xi32>
// CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
// CHECK: %[[MUL:.*]] = arith.muli %{{.*}}, %{{.*}} : vector<8x32x16xi32>
- // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
+ // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
// CHECK: arith.addi %[[R]], %{{.*}} : vector<8x32xi32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
@@ -160,7 +160,7 @@ func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
- // CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
+ // CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8x32xf32>
linalg.matmul
ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
@@ -523,7 +523,7 @@ func @matmul_tensors(
// linalg matmul lowers gets expanded to a 3D reduction, canonicalization later
// convert it to a 2D contract.
// CHECK: %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32>
- // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
+ // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
// CHECK: %[[ADD:.*]] = arith.addf %[[R]], %[[V2]] : vector<8x12xf32>
// CHECK: %[[W:.*]] = vector.transfer_write %[[ADD]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
%0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
@@ -744,7 +744,7 @@ func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
// CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32>
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
// CHECK: math.exp {{.*}} : vector<4x16x8xf32>
- // CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
+ // CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
// CHECK: addf {{.*}} : vector<4x16xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32>
// CHECK: return {{.*}} : tensor<4x16xf32>
@@ -779,7 +779,7 @@ func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: ten
// CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
// CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
// CHECK: addf {{.*}} : vector<2x3x4x5xf32>
- // CHECK: vector.multi_reduction #vector.kind<add>, {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
+ // CHECK: vector.multi_reduction <add>, {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
// CHECK: addf {{.*}} : vector<2x5xf32>
// CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32>
// CHECK: return {{.*}} : tensor<5x2xf32>
@@ -808,7 +808,7 @@ func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32>
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
- // CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind<maxf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
+ // CHECK: %[[R:.+]] = vector.multi_reduction <maxf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: maxf %[[R]], %[[CMINF]] : vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%ident = arith.constant -3.40282e+38 : f32
@@ -833,7 +833,7 @@ func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
- // CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind<minf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
+ // CHECK: %[[R:.+]] = vector.multi_reduction <minf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: arith.minf %[[R]], %[[CMAXF]] : vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%maxf32 = arith.constant 3.40282e+38 : f32
@@ -857,7 +857,7 @@ func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
- // CHECK: vector.multi_reduction #vector.kind<mul>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
+ // CHECK: vector.multi_reduction <mul>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: mulf {{.*}} : vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%ident = arith.constant 1.0 : f32
@@ -881,7 +881,7 @@ func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
- // CHECK: vector.multi_reduction #vector.kind<or>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
+ // CHECK: vector.multi_reduction <or>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
%ident = arith.constant false
%init = linalg.init_tensor [4] : tensor<4xi1>
@@ -904,7 +904,7 @@ func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
- // CHECK: vector.multi_reduction #vector.kind<and>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
+ // CHECK: vector.multi_reduction <and>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
%ident = arith.constant true
%init = linalg.init_tensor [4] : tensor<4xi1>
@@ -927,7 +927,7 @@ func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
- // CHECK: vector.multi_reduction #vector.kind<xor>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
+ // CHECK: vector.multi_reduction <xor>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
%ident = arith.constant false
%init = linalg.init_tensor [4] : tensor<4xi1>
@@ -979,7 +979,7 @@ func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) ->
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M6]]} : tensor<4x1xf32>, vector<4x4xf32>
// CHECK: subf {{.*}} : vector<4x4xf32>
// CHECK: math.exp {{.*}} : vector<4x4xf32>
- // CHECK: vector.multi_reduction #vector.kind<add>, {{.*}} : vector<4x4xf32> to vector<4xf32>
+ // CHECK: vector.multi_reduction <add>, {{.*}} : vector<4x4xf32> to vector<4xf32>
// CHECK: addf {{.*}} : vector<4xf32>
// CHECK: vector.transfer_write {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<4xf32>
%c0 = arith.constant 0.0 : f32
@@ -1019,7 +1019,7 @@ func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
// CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
// CHECK-SAME: : tensor<32xf32>, vector<32xf32>
// CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector<f32>
- // CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind<add>, %[[r]] [0]
+ // CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]] [0]
// CHECK-SAME: : vector<32xf32> to f32
// CHECK: %[[a:.*]] = arith.addf %[[red]], %[[f0]] : f32
// CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector<f32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 3557cebae0af6..9b496f857b1ab 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1027,7 +1027,7 @@ func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v
// CHECK-LABEL: func @vector_multi_reduction_single_parallel(
// CHECK-SAME: %[[v:.*]]: vector<2xf32>
func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [] : vector<2xf32> to vector<2xf32>
+ %0 = vector.multi_reduction <mul>, %arg0 [] : vector<2xf32> to vector<2xf32>
// CHECK: return %[[v]] : vector<2xf32>
return %0 : vector<2xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 63e30cf8912e7..195e720d2ad99 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -3,7 +3,7 @@
// -----
func @broadcast_to_scalar(%arg0: f32) -> f32 {
- // expected-error@+1 {{'vector.broadcast' op result #0 must be vector of any type values, but got 'f32'}}
+ // expected-error@+1 {{custom op 'vector.broadcast' invalid kind of type specified}}
%0 = vector.broadcast %arg0 : f32 to f32
}
@@ -1022,7 +1022,7 @@ func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
// -----
func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) {
- // expected-error@+1 {{must be vector of any type values}}
+ // expected-error@+1 {{'vector.bitcast' invalid kind of type specified}}
%0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to f32
}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 43c5abdd9ef8c..a68d105186bfc 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -685,9 +685,9 @@ func @extract_insert_map(%v: vector<32xf32>, %v2: vector<16x32xf32>,
// CHECK-LABEL: @multi_reduction
func @multi_reduction(%0: vector<4x8x16x32xf32>) -> f32 {
- %1 = vector.multi_reduction #vector.kind<add>, %0 [1, 3] :
+ %1 = vector.multi_reduction <add>, %0 [1, 3] :
vector<4x8x16x32xf32> to vector<4x16xf32>
- %2 = vector.multi_reduction #vector.kind<add>, %1 [0, 1] :
+ %2 = vector.multi_reduction <add>, %1 [0, 1] :
vector<4x16xf32> to f32
return %2 : f32
}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 7e6c1713d455e..76936c5d06e9c 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s
func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+ %0 = vector.multi_reduction <mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
// CHECK-LABEL: func @vector_multi_reduction
@@ -18,7 +18,7 @@ func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
// CHECK: return %[[RESULT_VEC]]
func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 {
- %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0, 1] : vector<2x4xf32> to f32
+ %0 = vector.multi_reduction <mul>, %arg0 [0, 1] : vector<2x4xf32> to f32
return %0 : f32
}
// CHECK-LABEL: func @vector_multi_reduction_to_scalar
@@ -30,7 +30,7 @@ func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 {
// CHECK: return %[[RES]]
func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
- %0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+ %0 = vector.multi_reduction <add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
return %0 : vector<2x3xi32>
}
// CHECK-LABEL: func @vector_reduction_inner
@@ -66,7 +66,7 @@ func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> {
- %0 = vector.multi_reduction #vector.kind<add>, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
+ %0 = vector.multi_reduction <add>, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
return %0 : vector<2x5xf32>
}
@@ -78,7 +78,7 @@ func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x
// CHECK: return %[[RESULT]]
func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> {
- %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>
+ %0 = vector.multi_reduction <mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>
return %0 : vector<2x4xf32>
}
// CHECK-LABEL: func @vector_multi_reduction_ordering
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 3f0e184274dce..f94a0b6e1a960 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns="use-outer-reductions" | FileCheck %s
func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+ %0 = vector.multi_reduction <mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
@@ -18,7 +18,7 @@ func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction #vector.kind<minf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+ %0 = vector.multi_reduction <minf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
@@ -35,7 +35,7 @@ func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction #vector.kind<maxf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+ %0 = vector.multi_reduction <maxf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
@@ -52,7 +52,7 @@ func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> {
- %0 = vector.multi_reduction #vector.kind<and>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+ %0 = vector.multi_reduction <and>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
return %0 : vector<2xi32>
}
@@ -69,7 +69,7 @@ func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> {
// CHECK: return %[[RESULT_VEC]] : vector<2xi32>
func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> {
- %0 = vector.multi_reduction #vector.kind<or>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+ %0 = vector.multi_reduction <or>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
return %0 : vector<2xi32>
}
@@ -86,7 +86,7 @@ func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> {
// CHECK: return %[[RESULT_VEC]] : vector<2xi32>
func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> {
- %0 = vector.multi_reduction #vector.kind<xor>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+ %0 = vector.multi_reduction <xor>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
return %0 : vector<2xi32>
}
@@ -104,7 +104,7 @@ func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> {
func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
- %0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+ %0 = vector.multi_reduction <add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
return %0 : vector<2x3xi32>
}
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index ccab0617e241e..85389b2a767d6 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -12,7 +12,7 @@
func @multidimreduction_contract(
%arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>) -> vector<8x16xf32> {
%0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
- %1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
+ %1 = vector.multi_reduction <add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
return %1 : vector<8x16xf32>
}
@@ -30,7 +30,7 @@ func @multidimreduction_contract(
func @multidimreduction_contract_int(
%arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>) -> vector<8x16xi32> {
%0 = arith.muli %arg0, %arg1 : vector<8x32x16xi32>
- %1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
+ %1 = vector.multi_reduction <add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
return %1 : vector<8x16xi32>
}
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index a9721cf31de81..8d1723adfc681 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -14,7 +14,7 @@
#define TEST_ATTRDEFS
// To get the test dialect definition.
-include "TestOps.td"
+include "TestDialect.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/SubElementInterfaces.td"
@@ -121,6 +121,29 @@ def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
);
}
+// A more complex parameterized attribute with multiple level of nesting.
+def CompoundNestedInner : Test_Attr<"CompoundNestedInner"> {
+ let mnemonic = "cmpnd_nested_inner";
+ // List of type parameters.
+ let parameters = (
+ ins
+ "int":$some_int,
+ CompoundAttrA:$cmpdA
+ );
+ let assemblyFormat = "`<` $some_int $cmpdA `>`";
+}
+
+def CompoundNestedOuter : Test_Attr<"CompoundNestedOuter"> {
+ let mnemonic = "cmpnd_nested_outer";
+
+ // List of type parameters.
+ let parameters = (
+ ins
+ CompoundNestedInner:$inner
+ );
+ let assemblyFormat = "`<` `i` $inner `>`";
+}
+
def TestParamOne : AttrParameter<"int64_t", ""> {}
def TestParamTwo : AttrParameter<"std::string", "", "llvm::StringRef"> {
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
new file mode 100644
index 0000000000000..756c7c10c92db
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -0,0 +1,46 @@
+//===-- TestDialect.td - Test dialect definition -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_DIALECT
+#define TEST_DIALECT
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+ let name = "test";
+ let cppNamespace = "::test";
+ let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
+ let hasCanonicalizer = 1;
+ let hasConstantMaterializer = 1;
+ let hasOperationAttrVerify = 1;
+ let hasRegionArgAttrVerify = 1;
+ let hasRegionResultAttrVerify = 1;
+ let hasOperationInterfaceFallback = 1;
+ let hasNonDefaultDestructor = 1;
+ let useDefaultAttributePrinterParser = 1;
+ let dependentDialects = ["::mlir::DLTIDialect"];
+
+ let extraClassDeclaration = [{
+ void registerAttributes();
+ void registerTypes();
+
+ // Provides a custom printing/parsing for some operations.
+ ::llvm::Optional<ParseOpHook>
+ getParseOperationHook(::llvm::StringRef opName) const override;
+ ::llvm::unique_function<void(::mlir::Operation *,
+ ::mlir::OpAsmPrinter &printer)>
+ getOperationPrinter(::mlir::Operation *op) const override;
+
+ private:
+ // Storage for a custom fallback interface.
+ void *fallbackEffectOpInterfaces;
+
+ }];
+}
+
+#endif // TEST_DIALECT
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2bea95017fa5f..4f6abed2eda26 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -9,6 +9,7 @@
#ifndef TEST_OPS
#define TEST_OPS
+include "TestDialect.td"
include "mlir/Dialect/DLTI/DLTIBase.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
@@ -23,40 +24,11 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "TestInterfaces.td"
-def Test_Dialect : Dialect {
- let name = "test";
- let cppNamespace = "::test";
- let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
- let hasCanonicalizer = 1;
- let hasConstantMaterializer = 1;
- let hasOperationAttrVerify = 1;
- let hasRegionArgAttrVerify = 1;
- let hasRegionResultAttrVerify = 1;
- let hasOperationInterfaceFallback = 1;
- let hasNonDefaultDestructor = 1;
- let useDefaultAttributePrinterParser = 1;
- let dependentDialects = ["::mlir::DLTIDialect"];
-
- let extraClassDeclaration = [{
- void registerAttributes();
- void registerTypes();
-
- // Provides a custom printing/parsing for some operations.
- ::llvm::Optional<ParseOpHook>
- getParseOperationHook(::llvm::StringRef opName) const override;
- ::llvm::unique_function<void(::mlir::Operation *,
- ::mlir::OpAsmPrinter &printer)>
- getOperationPrinter(::mlir::Operation *op) const override;
-
- private:
- // Storage for a custom fallback interface.
- void *fallbackEffectOpInterfaces;
-
- }];
-}
// Include the attribute definitions.
include "TestAttrDefs.td"
+// Include the type definitions.
+include "TestTypeDefs.td"
class TEST_Op<string mnemonic, list<OpTrait> traits = []> :
@@ -1933,6 +1905,16 @@ def FormatNestedAttr : TEST_Op<"format_nested_attr"> {
let assemblyFormat = "$nested attr-dict-with-keyword";
}
+def FormatNestedCompoundAttr : TEST_Op<"format_cpmd_nested_attr"> {
+ let arguments = (ins CompoundNestedOuter:$nested);
+ let assemblyFormat = "`nested` $nested attr-dict-with-keyword";
+}
+
+def FormatNestedType : TEST_Op<"format_cpmd_nested_type"> {
+ let arguments = (ins CompoundNestedOuterType:$nested);
+ let assemblyFormat = "$nested `nested` type($nested) attr-dict-with-keyword";
+}
+
//===----------------------------------------------------------------------===//
// Custom Directives
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 553fc54e9acae..7bd77e9ce42be 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -14,8 +14,9 @@
#define TEST_TYPEDEFS
// To get the test dialect def.
-include "TestOps.td"
+include "TestDialect.td"
include "TestAttrDefs.td"
+include "TestInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
@@ -49,6 +50,29 @@ def CompoundTypeA : Test_Type<"CompoundA"> {
}];
}
+// A more complex and nested parameterized type.
+def CompoundNestedInnerType : Test_Type<"CompoundNestedInner"> {
+ let mnemonic = "cmpnd_inner";
+ // List of type parameters.
+ let parameters = (
+ ins
+ "int":$some_int,
+ CompoundTypeA:$cmpdA
+ );
+ let assemblyFormat = "`<` $some_int $cmpdA `>`";
+}
+
+def CompoundNestedOuterType : Test_Type<"CompoundNestedOuter"> {
+ let mnemonic = "cmpnd_nested_outer";
+
+ // List of type parameters.
+ let parameters = (
+ ins
+ CompoundNestedInnerType:$inner
+ );
+ let assemblyFormat = "`<` `i` $inner `>`";
+}
+
// An example of how one could implement a standard integer.
def IntegerType : Test_Type<"TestInteger"> {
let mnemonic = "int";
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index 22a18c5461d50..0b050034a6b43 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -25,6 +25,7 @@
#include "mlir/Interfaces/DataLayoutInterfaces.h"
namespace test {
+class TestAttrWithFormatAttr;
/// FieldInfo represents a field in the StructType data type. It is used as a
/// parameter in TestTypeDefs.td.
@@ -63,13 +64,13 @@ struct FieldParser<test::CustomParam> {
return test::CustomParam{value.getValue()};
}
};
-} // end namespace mlir
-
inline mlir::AsmPrinter &operator<<(mlir::AsmPrinter &printer,
- const test::CustomParam &param) {
+ test::CustomParam param) {
return printer << param.value;
}
+} // end namespace mlir
+
#include "TestTypeInterfaces.h.inc"
#define GET_TYPEDEF_CLASSES
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 888b9856da761..a0588bfb237df 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -61,7 +61,7 @@ def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
// ATTR: printer << ' ' << "hello";
// ATTR: printer << ' ' << "=";
// ATTR: printer << ' ';
-// ATTR: printer << getValue();
+// ATTR: printer.printStrippedAttrOrType(getValue());
// ATTR: printer << ",";
// ATTR: printer << ' ';
// ATTR: ::printAttrParamA(printer, getComplex());
@@ -154,10 +154,10 @@ def AttrB : TestAttr<"TestB"> {
// ATTR: void TestFAttr::print(::mlir::AsmPrinter &printer) const {
// ATTR: printer << ' ';
-// ATTR: printer << getV0();
+// ATTR: printer.printStrippedAttrOrType(getV0());
// ATTR: printer << ",";
// ATTR: printer << ' ';
-// ATTR: printer << getV1();
+// ATTR: printer.printStrippedAttrOrType(getV1());
// ATTR: }
def AttrC : TestAttr<"TestF"> {
@@ -213,7 +213,7 @@ def AttrC : TestAttr<"TestF"> {
// TYPE: printer << ' ' << "bob";
// TYPE: printer << ' ' << "bar";
// TYPE: printer << ' ';
-// TYPE: printer << getValue();
+// TYPE: printer.printStrippedAttrOrType(getValue());
// TYPE: printer << ' ' << "complex";
// TYPE: printer << ' ' << "=";
// TYPE: printer << ' ';
@@ -361,21 +361,21 @@ def TypeB : TestType<"TestD"> {
// TYPE: printer << "v0";
// TYPE: printer << ' ' << "=";
// TYPE: printer << ' ';
-// TYPE: printer << getV0();
+// TYPE: printer.printStrippedAttrOrType(getV0());
// TYPE: printer << ",";
// TYPE: printer << ' ' << "v2";
// TYPE: printer << ' ' << "=";
// TYPE: printer << ' ';
-// TYPE: printer << getV2();
+// TYPE: printer.printStrippedAttrOrType(getV2());
// TYPE: printer << "v1";
// TYPE: printer << ' ' << "=";
// TYPE: printer << ' ';
-// TYPE: printer << getV1();
+// TYPE: printer.printStrippedAttrOrType(getV1());
// TYPE: printer << ",";
// TYPE: printer << ' ' << "v3";
// TYPE: printer << ' ' << "=";
// TYPE: printer << ' ';
-// TYPE: printer << getV3();
+// TYPE: printer.printStrippedAttrOrType(getV3());
// TYPE: }
def TypeC : TestType<"TestE"> {
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index dd996b8acc6bc..c3214c7afab4d 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -256,16 +256,50 @@ test.format_optional_else else
// Format a custom attribute
//===----------------------------------------------------------------------===//
-// CHECK: test.format_compound_attr #test.cmpnd_a<1, !test.smpla, [5, 6]>
-test.format_compound_attr #test.cmpnd_a<1, !test.smpla, [5, 6]>
+// CHECK: test.format_compound_attr <1, !test.smpla, [5, 6]>
+test.format_compound_attr <1, !test.smpla, [5, 6]>
-// CHECK: module attributes {test.nested = #test.cmpnd_nested<nested = #test.cmpnd_a<1, !test.smpla, [5, 6]>>} {
+//-----
+
+
+// CHECK: module attributes {test.nested = #test.cmpnd_nested<nested = <1, !test.smpla, [5, 6]>>} {
+module attributes {test.nested = #test.cmpnd_nested<nested = <1, !test.smpla, [5, 6]>>} {
+}
+
+//-----
+
+// Same as above, but fully spelling the inner attribute prefix `#test.cmpnd_a`.
+// CHECK: module attributes {test.nested = #test.cmpnd_nested<nested = <1, !test.smpla, [5, 6]>>} {
module attributes {test.nested = #test.cmpnd_nested<nested = #test.cmpnd_a<1, !test.smpla, [5, 6]>>} {
}
-// CHECK: test.format_nested_attr #test.cmpnd_nested<nested = #test.cmpnd_a<1, !test.smpla, [5, 6]>>
+// CHECK: test.format_nested_attr <nested = <1, !test.smpla, [5, 6]>>
+test.format_nested_attr #test.cmpnd_nested<nested = <1, !test.smpla, [5, 6]>>
+
+//-----
+
+// Same as above, but fully spelling the inner attribute prefix `#test.cmpnd_a`.
+// CHECK: test.format_nested_attr <nested = <1, !test.smpla, [5, 6]>>
test.format_nested_attr #test.cmpnd_nested<nested = #test.cmpnd_a<1, !test.smpla, [5, 6]>>
+//-----
+
+// CHECK: module attributes {test.someAttr = #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>}
+module attributes {test.someAttr = #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>}
+{
+}
+
+//-----
+
+// CHECK: module attributes {test.someAttr = #test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>}
+module attributes {test.someAttr = #test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>}
+{
+}
+
+//-----
+
+// CHECK: test.format_cpmd_nested_attr nested <i <42 <1, !test.smpla, [5, 6]>>>
+test.format_cpmd_nested_attr nested <i <42 <1, !test.smpla, [5, 6]>>>
//===----------------------------------------------------------------------===//
// Format custom directives
diff --git a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
index 783b4be704c77..4ab6b0e86a279 100644
--- a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
+++ b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
@@ -13,6 +13,22 @@ func @compoundA(%A : !test.cmpnd_a<1, !test.smpla, [5, 6]>)-> () {
return
}
+// CHECK: @compoundNested(%arg0: !test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>)
+func @compoundNested(%arg0: !test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>) -> () {
+ return
+}
+
+// Same as above, but we're parsing the complete spec for the inner type
+// CHECK: @compoundNestedExplicit(%arg0: !test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>)
+func @compoundNestedExplicit(%arg0: !test.cmpnd_nested_outer<i !test.cmpnd_inner<42 <1, !test.smpla, [5, 6]>>>) -> () {
+// Verify that the type prefix is elided and optional
+// CHECK: format_cpmd_nested_type %arg0 nested <i <42 <1, !test.smpla, [5, 6]>>>
+// CHECK: format_cpmd_nested_type %arg0 nested <i <42 <1, !test.smpla, [5, 6]>>>
+ test.format_cpmd_nested_type %arg0 nested !test.cmpnd_nested_outer<i !test.cmpnd_inner<42 <1, !test.smpla, [5, 6]>>>
+ test.format_cpmd_nested_type %arg0 nested <i <42 <1, !test.smpla, [5, 6]>>>
+ return
+}
+
// CHECK: @testInt(%arg0: !test.int<signed, 8>, %arg1: !test.int<unsigned, 2>, %arg2: !test.int<none, 1>)
func @testInt(%A : !test.int<s, 8>, %B : !test.int<unsigned, 2>, %C : !test.int<n, 1>) {
return
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 01ae08e57be9e..94e81f1085b5b 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -163,7 +163,8 @@ static const char *const defaultParameterParser =
"::mlir::FieldParser<$0>::parse($_parser)";
/// Default printer for attribute or type parameters.
-static const char *const defaultParameterPrinter = "$_printer << $_self";
+static const char *const defaultParameterPrinter =
+ "$_printer.printStrippedAttrOrType($_self)";
/// Print an error when failing to parse an element.
///
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 0827146da180f..08785516b99f2 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -496,13 +496,25 @@ static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
- if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes))
+ if (parser.parseCustomAttributeWithFallback({0}Attr, {1}, "{0}",
+ result.attributes)) {{
+ return ::mlir::failure();
+ }
+)";
+
+/// The code snippet used to generate a parser call for an attribute.
+///
+/// {0}: The name of the attribute.
+/// {1}: The type for the attribute.
+const char *const genericAttrParserCode = R"(
+ if (parser.parseAttribute({0}Attr, {1}, "{0}", result.attributes))
return ::mlir::failure();
)";
+
const char *const optionalAttrParserCode = R"(
{
::mlir::OptionalParseResult parseResult =
- parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes);
+ parser.parseOptionalAttribute({0}Attr, {1}, "{0}", result.attributes);
if (parseResult.hasValue() && failed(*parseResult))
return ::mlir::failure();
}
@@ -635,8 +647,12 @@ const char *const optionalTypeParserCode = R"(
}
)";
const char *const typeParserCode = R"(
- if (parser.parseType({0}RawTypes[0]))
- return ::mlir::failure();
+ {
+ {0} type;
+ if (parser.parseCustomTypeWithFallback(type))
+ return ::mlir::failure();
+ {1}RawTypes[0] = type;
+ }
)";
/// The code snippet used to generate a parser call for a functional type.
@@ -1269,12 +1285,19 @@ void OperationFormat::genElementParser(Element *element, MethodBody &body,
std::string attrTypeStr;
if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
llvm::raw_string_ostream os(attrTypeStr);
- os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
+ os << tgfmt(*typeBuilder, &attrTypeCtx);
+ } else {
+ attrTypeStr = "Type{}";
+ }
+ if (var->attr.isOptional()) {
+ body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
+ } else {
+ if (var->attr.getStorageType() == "::mlir::Attribute")
+ body << formatv(genericAttrParserCode, var->name, attrTypeStr);
+ else
+ body << formatv(attrParserCode, var->name, attrTypeStr);
}
- body << formatv(var->attr.isOptional() ? optionalAttrParserCode
- : attrParserCode,
- var->name, attrTypeStr);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
StringRef name = operand->getVar()->name;
@@ -1334,14 +1357,23 @@ void OperationFormat::genElementParser(Element *element, MethodBody &body,
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
- if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
+ if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
body << llvm::formatv(variadicOfVariadicTypeParserCode, listName);
- else if (lengthKind == ArgumentLengthKind::Variadic)
+ } else if (lengthKind == ArgumentLengthKind::Variadic) {
body << llvm::formatv(variadicTypeParserCode, listName);
- else if (lengthKind == ArgumentLengthKind::Optional)
+ } else if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(optionalTypeParserCode, listName);
- else
- body << formatv(typeParserCode, listName);
+ } else {
+ TypeSwitch<Element *>(dir->getOperand())
+ .Case<OperandVariable, ResultVariable>([&](auto operand) {
+ body << formatv(typeParserCode,
+ operand->getVar()->constraint.getCPPClassName(),
+ listName);
+ })
+ .Default([&](auto operand) {
+ body << formatv(typeParserCode, "Type", listName);
+ });
+ }
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
ArgumentLengthKind ignored;
body << formatv(functionalTypeParserCode,
@@ -1761,7 +1793,8 @@ static void genVariadicRegionPrinter(const Twine &regionListName,
/// Generate the C++ for an operand to a (*-)type directive.
static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
- MethodBody &body) {
+ MethodBody &body,
+ bool useArrayRef = true) {
if (isa<OperandsDirective>(arg))
return body << "getOperation()->getOperandTypes()";
if (isa<ResultsDirective>(arg))
@@ -1778,8 +1811,10 @@ static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
"({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
"::llvm::ArrayRef<::mlir::Type>())",
op.getGetterName(var->name));
- return body << "::llvm::ArrayRef<::mlir::Type>("
- << op.getGetterName(var->name) << "().getType())";
+ if (useArrayRef)
+ return body << "::llvm::ArrayRef<::mlir::Type>("
+ << op.getGetterName(var->name) << "().getType())";
+ return body << op.getGetterName(var->name) << "().getType()";
}
/// Generate the printer for an enum attribute.
@@ -1978,9 +2013,15 @@ void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
if (attr->getTypeBuilder())
body << " _odsPrinter.printAttributeWithoutType("
<< op.getGetterName(var->name) << "Attr());\n";
- else
+ else if (var->attr.isOptional())
+ body << "_odsPrinter.printAttribute(" << op.getGetterName(var->name)
+ << "Attr());\n";
+ else if (var->attr.getStorageType() == "::mlir::Attribute")
body << " _odsPrinter.printAttribute(" << op.getGetterName(var->name)
<< "Attr());\n";
+ else
+ body << "_odsPrinter.printStrippedAttrOrType("
+ << op.getGetterName(var->name) << "Attr());\n";
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
if (operand->getVar()->isVariadicOfVariadic()) {
body << " ::llvm::interleaveComma("
@@ -2033,8 +2074,29 @@ void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
return;
}
}
+ const NamedTypeConstraint *var = nullptr;
+ {
+ if (auto *operand = dyn_cast<OperandVariable>(dir->getOperand()))
+ var = operand->getVar();
+ else if (auto *operand = dyn_cast<ResultVariable>(dir->getOperand()))
+ var = operand->getVar();
+ }
+ if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
+ !var->isOptional()) {
+ std::string cppClass = var->constraint.getCPPClassName();
+ body << " {\n"
+ << " auto type = " << op.getGetterName(var->name)
+ << "().getType();\n"
+ << " if (auto validType = type.dyn_cast<" << cppClass << ">())\n"
+ << " _odsPrinter.printStrippedAttrOrType(validType);\n"
+ << " else\n"
+ << " _odsPrinter << type;\n"
+ << " }\n";
+ return;
+ }
body << " _odsPrinter << ";
- genTypeOperandPrinter(dir->getOperand(), op, body) << ";\n";
+ genTypeOperandPrinter(dir->getOperand(), op, body, /*useArrayRef=*/false)
+ << ";\n";
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
body << " _odsPrinter.printFunctionalType(";
genTypeOperandPrinter(dir->getInputs(), op, body) << ", ";
More information about the Mlir-commits mailing list