[Mlir-commits] [mlir] d572cd1 - Introduce MLIR Op Properties
Mehdi Amini
llvmlistbot at llvm.org
Mon May 1 15:36:14 PDT 2023
Author: Mehdi Amini
Date: 2023-05-01T15:35:48-07:00
New Revision: d572cd1b067f1177a981a4711bf2e501eaa8117b
URL: https://github.com/llvm/llvm-project/commit/d572cd1b067f1177a981a4711bf2e501eaa8117b
DIFF: https://github.com/llvm/llvm-project/commit/d572cd1b067f1177a981a4711bf2e501eaa8117b.diff
LOG: Introduce MLIR Op Properties
This new feature makes it possible to dedicate custom storage inline within
operations. This storage can be used as an alternative to attributes to store
data that is specific to an operation. Attributes can also be stored inside the
properties storage if desired, but any kind of data can be present as well. This
offers a way to store and mutate data without uniquing it in the Context, as is
done for Attributes.
See OpPropertiesTest.cpp for an example where a struct with a
std::vector<> is attached to an operation and mutated in-place:
// Example properties struct: plain C++ data members stored directly on the
// operation, including a std::vector<> that can be mutated in-place.
struct TestProperties {
int a = -1;
float b = -1.;
std::vector<int64_t> array = {-33};
};
More complex schemes (including reference-counting) are also possible.
The only constraint to enable storing a C++ object as "properties" on an
operation is to implement three functions:
- convert from the candidate object to an Attribute
- convert from the Attribute to the candidate object
- hash the object
Optionally, parsing and printing can also be customized with two extra
functions.
A new option is introduced in ODS to allow dialects to specify:
let usePropertiesForAttributes = 1;
When set to true, the inherent attributes for all the ops in this dialect
will be stored in properties instead of alongside the discardable
attributes.
The TestDialect showcases this feature.
Another change is that we introduce new APIs on the Operation class
to access the inherent attributes separately from the discardable ones.
We envision deprecating and removing `getAttr()`, `getAttrDictionary()`,
and other similar methods which don't make the distinction explicit, leading
to an entirely separate namespace for discardable attributes.
Differential Revision: https://reviews.llvm.org/D141742
Added:
mlir/include/mlir/IR/ODSSupport.h
mlir/include/mlir/TableGen/Property.h
mlir/lib/IR/ODSSupport.cpp
mlir/lib/TableGen/Property.cpp
mlir/test/IR/properties.mlir
mlir/unittests/IR/OpPropertiesTest.cpp
Modified:
mlir/docs/LangRef.md
mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
mlir/include/mlir/IR/DialectBase.td
mlir/include/mlir/IR/ExtensibleDialect.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/Interfaces/InferTypeOpInterface.h
mlir/include/mlir/Interfaces/InferTypeOpInterface.td
mlir/include/mlir/TableGen/Argument.h
mlir/include/mlir/TableGen/Class.h
mlir/include/mlir/TableGen/Dialect.h
mlir/include/mlir/TableGen/Operator.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/include/mlir/Transforms/OneToNTypeConversion.h
mlir/lib/AsmParser/Parser.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/CMakeLists.txt
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/OperationSupport.cpp
mlir/lib/Interfaces/InferTypeOpInterface.cpp
mlir/lib/Rewrite/ByteCode.cpp
mlir/lib/TableGen/CMakeLists.txt
mlir/lib/TableGen/CodeGenHelpers.cpp
mlir/lib/TableGen/Dialect.cpp
mlir/lib/TableGen/Operator.cpp
mlir/test/Bytecode/versioning/versioned_attr.mlir
mlir/test/Bytecode/versioning/versioned_op.mlir
mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir
mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
mlir/test/Dialect/Shape/invalid.mlir
mlir/test/Dialect/Tosa/invalid.mlir
mlir/test/IR/greedy-pattern-rewriter-driver.mlir
mlir/test/IR/invalid.mlir
mlir/test/IR/parser.mlir
mlir/test/IR/test-fold-adaptor.mlir
mlir/test/IR/test-manual-cpp-fold.mlir
mlir/test/IR/traits.mlir
mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
mlir/test/Transforms/decompose-call-graph-types.mlir
mlir/test/Transforms/test-legalizer.mlir
mlir/test/Transforms/test-operation-folder.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestDialect.h
mlir/test/lib/Dialect/Test/TestDialect.td
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/mlir-tblgen/constraint-unique.td
mlir/test/mlir-tblgen/interfaces-as-constraints.td
mlir/test/mlir-tblgen/op-attribute.td
mlir/test/mlir-tblgen/op-decl-and-defs.td
mlir/test/mlir-tblgen/op-format.td
mlir/test/mlir-tblgen/op-result.td
mlir/test/mlir-tblgen/pattern.mlir
mlir/test/mlir-tblgen/return-types.mlir
mlir/tools/mlir-tblgen/FormatGen.cpp
mlir/tools/mlir-tblgen/FormatGen.h
mlir/tools/mlir-tblgen/OpClass.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/OpFormatGen.cpp
mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp
mlir/unittests/IR/AdaptorTest.cpp
mlir/unittests/IR/CMakeLists.txt
mlir/unittests/IR/OperationSupportTest.cpp
mlir/unittests/Transforms/DialectConversion.cpp
Removed:
################################################################################
diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
index 5ce86eb88dc61..48e4903216531 100644
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -290,12 +290,14 @@ Syntax:
operation ::= op-result-list? (generic-operation | custom-operation)
trailing-location?
generic-operation ::= string-literal `(` value-use-list? `)` successor-list?
- region-list? dictionary-attribute? `:` function-type
+ dictionary-properties? region-list? dictionary-attribute?
+ `:` function-type
custom-operation ::= bare-id custom-operation-format
op-result-list ::= op-result (`,` op-result)* `=`
op-result ::= value-id (`:` integer-literal)
successor-list ::= `[` successor (`,` successor)* `]`
successor ::= caret-id (`:` block-arg-list)?
+dictionary-properties ::= `<` dictionary-attribute `>`
region-list ::= `(` region (`,` region)* `)`
dictionary-attribute ::= `{` (attribute-entry (`,` attribute-entry)*)? `}`
trailing-location ::= (`loc` `(` location `)`)?
@@ -312,9 +314,10 @@ semantics. For example, MLIR supports
The internal representation of an operation is simple: an operation is
identified by a unique string (e.g. `dim`, `tf.Conv2d`, `x86.repmovsb`,
`ppc.eieio`, etc), can return zero or more results, take zero or more operands,
-has a dictionary of [attributes](#attributes), has zero or more successors, and
-zero or more enclosed [regions](#regions). The generic printing form includes
-all these elements literally, with a function type to indicate the types of the
+has storage for [properties](#properties), has a dictionary of
+[attributes](#attributes), has zero or more successors, and zero or more
+enclosed [regions](#regions). The generic printing form includes all these
+elements literally, with a function type to indicate the types of the
results and operands.
Example:
@@ -328,8 +331,11 @@ Example:
%foo, %bar = "foo_div"() : () -> (f32, i32)
// Invoke a TensorFlow function called tf.scramble with two inputs
-// and an attribute "fruit".
-%2 = "tf.scramble"(%result#0, %bar) {fruit = "banana"} : (f32, i32) -> f32
+// and an attribute "fruit" stored in properties.
+%2 = "tf.scramble"(%result#0, %bar) <{fruit = "banana"}> : (f32, i32) -> f32
+
+// Invoke an operation with some discardable attributes
+%foo, %bar = "foo_div"() {some_attr = "value", other_attr = 42 : i64} : () -> (f32, i32)
```
In addition to the basic syntax above, dialects may register known operations.
@@ -733,6 +739,15 @@ The [builtin dialect](Dialects/Builtin.md) defines a set of types that are
directly usable by any other dialect in MLIR. These types cover a range from
primitive integer and floating-point types, function types, and more.
+## Properties
+
+Properties are extra data members stored directly on an Operation class. They
+provide a way to store [inherent attributes](#attributes) and other arbitrary
+data. The semantics of the data is specific to a given operation, and may be
+exposed through [Interfaces](Interfaces.md) accessors and other methods.
+Properties can always be serialized to an Attribute in order to be printed
+generically.
+
## Attributes
Syntax:
@@ -751,9 +766,10 @@ values. MLIR's builtin dialect provides a rich set of
arrays, dictionaries, strings, etc.). Additionally, dialects can define their
own [dialect attribute values](#dialect-attribute-values).
-The top-level attribute dictionary attached to an operation has special
-semantics. The attribute entries are considered to be of two different kinds
-based on whether their dictionary key has a dialect prefix:
+For dialects which haven't adopted properties yet, the top-level attribute
+dictionary attached to an operation has special semantics. The attribute
+entries are considered to be of two different kinds based on whether their
+dictionary key has a dialect prefix:
- *inherent attributes* are inherent to the definition of an operation's
semantics. The operation itself is expected to verify the consistency of
@@ -771,6 +787,10 @@ Note that attribute values are allowed to themselves be dictionary attributes,
but only the top-level dictionary attribute attached to the operation is subject
to the classification above.
+When properties are adopted, only discardable attributes are stored in the
+top-level dictionary, while inherent attributes are stored in the properties
+storage.
+
### Attribute Value Aliases
```
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 90e62aa11787c..f17a2e54a7313 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -145,6 +145,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
+ // Return early: without it the op would be rewritten a second time below.
+ if constexpr (SourceOp::hasProperties())
+ return rewrite(cast<SourceOp>(op),
+ OpAdaptor(operands, op->getAttrDictionary(),
+ cast<SourceOp>(op).getProperties()),
+ rewriter);
rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
rewriter);
}
@@ -154,6 +160,11 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
+ if constexpr (SourceOp::hasProperties())
+ return matchAndRewrite(cast<SourceOp>(op),
+ OpAdaptor(operands, op->getAttrDictionary(),
+ cast<SourceOp>(op).getProperties()),
+ rewriter);
return matchAndRewrite(cast<SourceOp>(op),
OpAdaptor(operands, op->getAttrDictionary()),
rewriter);
diff --git a/mlir/include/mlir/IR/DialectBase.td b/mlir/include/mlir/IR/DialectBase.td
index 3d665da7f04f9..ff8c2beb65757 100644
--- a/mlir/include/mlir/IR/DialectBase.td
+++ b/mlir/include/mlir/IR/DialectBase.td
@@ -103,6 +103,9 @@ class Dialect {
// If this dialect can be extended at runtime with new operations or types.
bit isExtensible = 0;
+
+ // If set, inherent attributes defined in ODS are stored as Properties.
+ bit usePropertiesForAttributes = 0;
}
#endif // DIALECTBASE_TD
diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index cb8f4fcfa7224..4ca27ea509868 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -26,6 +26,8 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <optional>
namespace mlir {
class AsmParser;
@@ -462,6 +464,35 @@ class DynamicOpDefinition : public OperationName::Impl {
return verifyRegionFn(op);
}
+ /// Properties are not supported for dynamic operations yet: these are stubs.
+ std::optional<Attribute> getInherentAttr(Operation *op,
+ StringRef name) final {
+ llvm::report_fatal_error("Unsupported getInherentAttr on Dynamic dialects");
+ }
+ void setInherentAttr(Operation *op, StringAttr name, Attribute value) final {
+ llvm::report_fatal_error("Unsupported setInherentAttr on Dynamic dialects");
+ }
+ void populateInherentAttrs(Operation *op, NamedAttrList &attrs) final {}
+ LogicalResult
+ verifyInherentAttrs(OperationName opName, NamedAttrList &attributes,
+ function_ref<InFlightDiagnostic()> getDiag) final {
+ return success();
+ }
+ int getOpPropertyByteSize() final { return 0; }
+ void initProperties(OperationName opName, OpaqueProperties storage,
+ OpaqueProperties init) final {}
+ void deleteProperties(OpaqueProperties prop) final {}
+ void populateDefaultProperties(OperationName opName,
+ OpaqueProperties properties) final {}
+
+ LogicalResult setPropertiesFromAttr(Operation *op, Attribute attr,
+ InFlightDiagnostic *diag) final {
+ return failure();
+ }
+ Attribute getPropertiesAsAttr(Operation *op) final { return {}; }
+ void copyProperties(OpaqueProperties lhs, OpaqueProperties rhs) final {}
+ llvm::hash_code hashProperties(OpaqueProperties prop) final { return {}; }
+
private:
DynamicOpDefinition(
StringRef name, ExtensibleDialect *dialect,
diff --git a/mlir/include/mlir/IR/ODSSupport.h b/mlir/include/mlir/IR/ODSSupport.h
new file mode 100644
index 0000000000000..1d3cbbd690034
--- /dev/null
+++ b/mlir/include/mlir/IR/ODSSupport.h
@@ -0,0 +1,45 @@
+//===- ODSSupport.h ---------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a number of support methods for ODS generated code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_ODSSUPPORT_H
+#define MLIR_IR_ODSSUPPORT_H
+
+#include "mlir/IR/Attributes.h"
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// Support for properties
+//===----------------------------------------------------------------------===//
+
+/// Convert an IntegerAttr attribute to an int64_t, or return an error if the
+/// attribute isn't an IntegerAttr. If the optional diagnostic is provided, an
+/// error message is also emitted.
+LogicalResult convertFromAttribute(int64_t &storage, Attribute attr,
+ InFlightDiagnostic *diag);
+
+/// Convert the provided int64_t to an IntegerAttr attribute.
+Attribute convertToAttribute(MLIRContext *ctx, int64_t storage);
+
+/// Convert a DenseI64ArrayAttr to the provided storage. It is expected that the
+/// storage has the same size as the array. An error is returned if the
+/// attribute isn't a DenseI64ArrayAttr or it does not have the same size. If
+/// the optional diagnostic is provided, an error message is also emitted.
+LogicalResult convertFromAttribute(MutableArrayRef<int64_t> storage,
+ Attribute attr, InFlightDiagnostic *diag);
+
+/// Convert the provided ArrayRef<int64_t> to a DenseI64ArrayAttr attribute.
+Attribute convertToAttribute(MLIRContext *ctx, ArrayRef<int64_t> storage);
+
+} // namespace mlir
+
+#endif // MLIR_IR_ODSSUPPORT_H
\ No newline at end of file
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 5d0843cdf7529..b02c0eafa0c24 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -179,6 +179,69 @@ class TypeConstraint<Pred predicate, string summary = "",
string cppClassName = cppClassNameParam;
}
+// Base class for defining properties.
+class Property<string storageTypeParam = "", string desc = ""> {
+ // User-readable one line summary used in error reporting messages. If empty,
+ // a generic message will be used.
+ string summary = desc;
+ // The full description of this property.
+ string description = "";
+ code storageType = storageTypeParam;
+ code interfaceType = storageTypeParam;
+
+ // The expression to convert from the storage type to the interface
+ // type. For example, an enum can be stored as an int but returned as an
+ // enum class.
+ //
+ // Format:
+ // - `$_storage` will contain the property in the storage type.
+ // - `$_ctxt` will contain an `MLIRContext *`.
+ code convertFromStorage = "$_storage";
+
+ // The call expression to build a property storage from the interface type.
+ //
+ // Format:
+ // - `$_storage` will contain the property in the storage type.
+ // - `$_value` will contain the property in the user interface type.
+ code assignToStorage = "$_storage = $_value";
+
+ // The call expression to convert from the storage type to an attribute.
+ //
+ // Format:
+ // - `$_storage` is the storage type value.
+ // - `$_ctxt` is an `MLIRContext *`.
+ //
+ // The expression must result in an Attribute.
+ code convertToAttribute = [{
+ convertToAttribute($_ctxt, $_storage)
+ }];
+
+ // The call expression to convert from an Attribute to the storage type.
+ //
+ // Format:
+ // - `$_storage` is the storage type value.
+ // - `$_attr` is the attribute.
+ // - `$_diag` is an optional `InFlightDiagnostic *` used to emit errors.
+ //
+ // The code must return a LogicalResult.
+ code convertFromAttribute = [{
+ return convertFromAttribute($_storage, $_attr, $_diag);
+ }];
+
+ // The call expression to hash the property.
+ //
+ // Format:
+ // - `$_storage` is the variable to hash.
+ //
+ // The expression should produce an llvm::hash_code.
+ code hashProperty = [{
+ llvm::hash_value($_storage);
+ }];
+
+ // Default value for the property.
+ string defaultValue = ?;
+}
+
// Subclass for constraints on an attribute.
class AttrConstraint<Pred predicate, string summary = ""> :
Constraint<predicate, summary>;
@@ -1090,6 +1153,16 @@ class DefaultValuedStrAttr<Attr attr, string val>
class DefaultValuedOptionalStrAttr<Attr attr, string val>
: DefaultValuedOptionalAttr<attr, "\"" # val # "\"">;
+//===----------------------------------------------------------------------===//
+// Primitive property kinds
+
+class ArrayProperty<string storageTypeParam = "", int n, string desc = ""> :
+ Property<storageTypeParam # "[" # n # "]", desc> {
+ let interfaceType = "::llvm::ArrayRef<" # storageTypeParam # ">";
+ let convertFromStorage = "$_storage";
+ let assignToStorage = "::llvm::copy($_value, $_storage)";
+}
+
//===----------------------------------------------------------------------===//
// Primitive attribute kinds
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 5a73d776996f9..d08d3de350f3e 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -71,6 +71,21 @@ void ensureRegionTerminator(
} // namespace impl
+/// Structure used by default as a "marker" when no "Properties" are set on an
+/// Operation.
+struct EmptyProperties {};
+
+/// Trait to detect whether an Operation defines a `Properties` type; otherwise
+/// it defaults to `EmptyProperties`.
+template <class Op, class = void>
+struct PropertiesSelector {
+ using type = EmptyProperties;
+};
+template <class Op>
+struct PropertiesSelector<Op, std::void_t<typename Op::Properties>> {
+ using type = typename Op::Properties;
+};
+
/// This is the concrete base class that holds the operation pointer and has
/// non-generic methods that only depend on State (to avoid having them
/// instantiated on template types that don't affect them.
@@ -206,6 +221,13 @@ class OpState {
/// in generic form.
static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect);
+ /// Parse properties as an Attribute.
+ static ParseResult genericParseProperties(OpAsmParser &parser,
+ Attribute &result);
+
+ /// Print the properties as an Attribute.
+ static void genericPrintProperties(OpAsmPrinter &p, Attribute properties);
+
/// Print an operation name, eliding the dialect prefix if necessary.
static void printOpName(Operation *op, OpAsmPrinter &p,
StringRef defaultDialect);
@@ -214,6 +236,14 @@ class OpState {
/// so we can cast it away here.
explicit OpState(Operation *state) : state(state) {}
+ /// For all ops which don't have properties, we keep a single instance of
+ /// `EmptyProperties` to be used where a reference to properties is needed:
+ /// this allows binding a pointer to the reference without triggering UB.
+ static EmptyProperties &getEmptyProperties() {
+ static EmptyProperties emptyProperties;
+ return emptyProperties;
+ }
+
private:
Operation *state;
@@ -1471,13 +1501,17 @@ namespace op_definition_impl {
/// Returns true if this given Trait ID matches the IDs of any of the provided
/// trait types `Traits`.
template <template <typename T> class... Traits>
-static bool hasTrait(TypeID traitID) {
+inline bool hasTrait(TypeID traitID) {
TypeID traitIDs[] = {TypeID::get<Traits>()...};
for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
if (traitIDs[i] == traitID)
return true;
return false;
}
+template <>
+inline bool hasTrait<>(TypeID traitID) {
+ return false;
+}
//===----------------------------------------------------------------------===//
// Trait Folding
@@ -1693,6 +1727,33 @@ class Op : public OpState, public Traits<ConcreteType>... {
(checkInterfaceTarget<Models>(), ...);
info->attachInterface<Models...>();
}
+ /// Convert the provided attribute to a property and assign it to the
+ /// provided properties. This default implementation forwards to a free
+ /// function `setPropertiesFromAttribute` that can be looked up with ADL in
+ /// the namespace where the properties are defined. It can also be overridden
+ /// in the derived ConcreteOp.
+ template <typename PropertiesTy>
+ static LogicalResult setPropertiesFromAttr(PropertiesTy &prop, Attribute attr,
+ InFlightDiagnostic *diag) {
+ return setPropertiesFromAttribute(prop, attr, diag);
+ }
+ /// Convert the provided properties to an attribute. This default
+ /// implementation forwards to a free function `getPropertiesAsAttribute` that
+ /// can be looked up with ADL in the namespace where the properties are
+ /// defined. It can also be overridden in the derived ConcreteOp.
+ template <typename PropertiesTy>
+ static Attribute getPropertiesAsAttr(MLIRContext *ctx,
+ const PropertiesTy &prop) {
+ return getPropertiesAsAttribute(ctx, prop);
+ }
+ /// Hash the provided properties. This default implementation forwards to a
+ /// free function `computeHash` that can be looked up with ADL in the
+ /// namespace where the properties are defined. It can also be overridden in
+ /// the derived ConcreteOp.
+ template <typename PropertiesTy>
+ static llvm::hash_code computePropertiesHash(const PropertiesTy &prop) {
+ return computeHash(prop);
+ }
private:
/// Trait to check if T provides a 'fold' method for a single result op.
/// Trait to check if T provides a 'fold' method for a single result op.
@@ -1733,10 +1794,35 @@ class Op : public OpState, public Traits<ConcreteType>... {
template <typename T>
using detect_has_print = llvm::is_detected<has_print, T>;
+ /// Trait to check if printProperties(OpAsmPrinter, T) exist
+ template <typename T, typename... Args>
+ using has_print_properties = decltype(printProperties(
+ std::declval<OpAsmPrinter &>(), std::declval<T>()));
+ template <typename T>
+ using detect_has_print_properties =
+ llvm::is_detected<has_print_properties, T>;
+
+ /// Trait to check if parseProperties(OpAsmParser, T) exist
+ template <typename T, typename... Args>
+ using has_parse_properties = decltype(parseProperties(
+ std::declval<OpAsmParser &>(), std::declval<T &>()));
+ template <typename T>
+ using detect_has_parse_properties =
+ llvm::is_detected<has_parse_properties, T>;
+
/// Trait to check if T provides a 'ConcreteEntity' type alias.
template <typename T>
using has_concrete_entity_t = typename T::ConcreteEntity;
+public:
+ /// Returns true if this operation defines a `Properties` inner type.
+ static constexpr bool hasProperties() {
+ return !std::is_same_v<
+ typename ConcreteType::template InferredProperties<ConcreteType>,
+ EmptyProperties>;
+ }
+
+private:
/// A struct-wrapped type alias to T::ConcreteEntity if provided and to
/// ConcreteType otherwise. This is akin to std::conditional but doesn't fail
/// on the missing typedef. Useful for checking if the interface is targeting
@@ -1801,11 +1887,18 @@ class Op : public OpState, public Traits<ConcreteType>... {
foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
OpFoldResult result;
- if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>)
- result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
- operands, op->getAttrDictionary(), op->getRegions()));
- else
+ if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>) {
+ if constexpr (hasProperties()) {
+ result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
+ operands, op->getAttrDictionary(),
+ cast<ConcreteOpT>(op).getProperties(), op->getRegions()));
+ } else {
+ result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
+ operands, op->getAttrDictionary(), {}, op->getRegions()));
+ }
+ } else {
result = cast<ConcreteOpT>(op).fold(operands);
+ }
// If the fold failed or was in-place, try to fold the traits of the
// operation.
@@ -1824,10 +1917,18 @@ class Op : public OpState, public Traits<ConcreteType>... {
SmallVectorImpl<OpFoldResult> &results) {
auto result = LogicalResult::failure();
if constexpr (has_fold_adaptor_v<ConcreteOpT>) {
- result = cast<ConcreteOpT>(op).fold(
- typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
- op->getRegions()),
- results);
+ if constexpr (hasProperties()) {
+ result = cast<ConcreteOpT>(op).fold(
+ typename ConcreteOpT::FoldAdaptor(
+ operands, op->getAttrDictionary(),
+ cast<ConcreteOpT>(op).getProperties(), op->getRegions()),
+ results);
+ } else {
+ result = cast<ConcreteOpT>(op).fold(
+ typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
+ {}, op->getRegions()),
+ results);
+ }
} else {
result = cast<ConcreteOpT>(op).fold(operands, results);
}
@@ -1859,6 +1960,48 @@ class Op : public OpState, public Traits<ConcreteType>... {
};
}
+public:
+ template <typename T>
+ using InferredProperties = typename PropertiesSelector<T>::type;
+ template <typename T = ConcreteType>
+ InferredProperties<T> &getProperties() {
+ if constexpr (!hasProperties())
+ return getEmptyProperties();
+ return *getOperation()
+ ->getPropertiesStorage()
+ .template as<InferredProperties<T> *>();
+ }
+
+ /// This hook populates any unset default attrs when mapped to properties.
+ template <typename T = ConcreteType>
+ static void populateDefaultProperties(OperationName opName,
+ InferredProperties<T> &properties) {}
+
+ /// Print the operation properties. Unless overridden, this method will try to
+ /// dispatch to a `printProperties` free-function if it exists, and otherwise
+ /// print by converting the properties to an Attribute.
+ template <typename T>
+ static void printProperties(MLIRContext *ctx, OpAsmPrinter &p,
+ const T &properties) {
+ if constexpr (detect_has_print_properties<T>::value)
+ return printProperties(p, properties);
+ genericPrintProperties(p,
+ ConcreteType::getPropertiesAsAttr(ctx, properties));
+ }
+
+ /// Parse the properties. Unless overridden, this method will try to dispatch
+ /// to a `parseProperties` free-function if it exists, or parse an Attribute.
+ template <typename T = ConcreteType>
+ static ParseResult parseProperties(OpAsmParser &parser,
+ OperationState &result) {
+ if constexpr (detect_has_parse_properties<InferredProperties<T>>::value) {
+ return parseProperties(
+ parser, result.getOrAddProperties<InferredProperties<T>>());
+ }
+ return genericParseProperties(parser, result.propertiesAttr);
+ }
+
+private:
/// Implementation of `PopulateDefaultAttrsFn` OperationName hook.
static OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn() {
return ConcreteType::populateDefaultAttrs;
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 02abca33f9460..e770c7453129b 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -957,13 +957,13 @@ class AsmParser {
/// populated in `result`.
template <typename AttrType>
std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
- parseCustomAttributeWithFallback(AttrType &result) {
+ parseCustomAttributeWithFallback(AttrType &result, Type type = {}) {
SMLoc loc = getCurrentLocation();
// Parse any kind of attribute.
Attribute attr;
if (parseCustomAttributeWithFallback(
- attr, {}, [&](Attribute &result, Type type) -> ParseResult {
+ attr, type, [&](Attribute &result, Type type) -> ParseResult {
result = AttrType::parse(*this, type);
return success(!!result);
}))
@@ -979,8 +979,8 @@ class AsmParser {
/// SFINAE parsing method for Attribute that don't implement a parse method.
template <typename AttrType>
std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
- parseCustomAttributeWithFallback(AttrType &result) {
- return parseAttribute(result);
+ parseCustomAttributeWithFallback(AttrType &result, Type type = {}) {
+ return parseAttribute(result, type);
}
/// Parse an arbitrary optional attribute of a given type and return it in
@@ -1368,6 +1368,7 @@ class OpAsmParser : public AsmParser {
std::optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
std::nullopt,
std::optional<ArrayRef<NamedAttribute>> parsedAttributes = std::nullopt,
+ std::optional<Attribute> parsedPropertiesAttribute = std::nullopt,
std::optional<FunctionType> parsedFnType = std::nullopt) = 0;
/// Parse a single SSA value operand name along with a result number if
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 632e0677483e5..8bd23a2a1fc57 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -22,6 +22,12 @@
#include <optional>
namespace mlir {
+namespace detail {
+/// This is a "tag" used for mapping the properties storage in
+/// llvm::TrailingObjects.
+enum class OpProperties : char {};
+} // namespace detail
+
/// Operation is the basic unit of execution within MLIR.
///
/// The following documentation are recommended to understand this class:
@@ -67,26 +73,35 @@ namespace mlir {
/// Some operations like branches also refer to other Block, in which case they
/// would have an array of `BlockOperand`.
///
+/// An Operation may contain optionally a "Properties" object: this is a
+/// pre-defined C++ object with a fixed size. This object is owned by the
+/// operation and deleted with the operation. It can be converted to an
+/// Attribute on demand, or loaded from an Attribute.
+///
+///
/// Finally an Operation also contain an optional `DictionaryAttr`, a Location,
/// and a pointer to its parent Block (if any).
class alignas(8) Operation final
: public llvm::ilist_node_with_parent<Operation, Block>,
private llvm::TrailingObjects<Operation, detail::OperandStorage,
- BlockOperand, Region, OpOperand> {
+ detail::OpProperties, BlockOperand, Region,
+ OpOperand> {
public:
/// Create a new Operation with the specific fields. This constructor
/// populates the provided attribute list with default attributes if
/// necessary.
static Operation *create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
- NamedAttrList &&attributes, BlockRange successors,
+ NamedAttrList &&attributes,
+ OpaqueProperties properties, BlockRange successors,
unsigned numRegions);
/// Create a new Operation with the specific fields. This constructor uses an
/// existing attribute dictionary to avoid uniquing a list of attributes.
static Operation *create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
- DictionaryAttr attributes, BlockRange successors,
+ DictionaryAttr attributes,
+ OpaqueProperties properties, BlockRange successors,
unsigned numRegions);
/// Create a new Operation from the fields stored in `state`.
@@ -96,6 +111,7 @@ class alignas(8) Operation final
static Operation *create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
NamedAttrList &&attributes,
+ OpaqueProperties properties,
BlockRange successors = {},
RegionRange regions = {});
@@ -414,24 +430,82 @@ class alignas(8) Operation final
// constants to names. Attributes may be dynamically added and removed over
// the lifetime of an operation.
+ /// Access an inherent attribute by name: returns an empty optional if there
+ /// is no inherent attribute with this name.
+ ///
+ /// This method is available as a transient facility in the migration process
+ /// to use Properties instead.
+ std::optional<Attribute> getInherentAttr(StringRef name);
+
+ /// Set an inherent attribute by name.
+ ///
+ /// This method is available as a transient facility in the migration process
+ /// to use Properties instead.
+ void setInherentAttr(StringAttr name, Attribute value);
+
+ /// Access a discardable attribute by name; returns a null Attribute if the
+ /// discardable attribute does not exist.
+ Attribute getDiscardableAttr(StringRef name) { return attrs.get(name); }
+
+ /// Access a discardable attribute by name; returns a null Attribute if the
+ /// discardable attribute does not exist.
+ Attribute getDiscardableAttr(StringAttr name) { return attrs.get(name); }
+
+ /// Set a discardable attribute by name (rebuilds the dictionary on change).
+ void setDiscardableAttr(StringAttr name, Attribute value) {
+ NamedAttrList attributes(attrs);
+ if (attributes.set(name, value) != value)
+ attrs = attributes.getDictionary(getContext());
+ }
+
+ /// Return all of the discardable attributes on this operation.
+ ArrayRef<NamedAttribute> getDiscardableAttrs() { return attrs.getValue(); }
+
+ /// Return all of the discardable attributes on this operation as a
+ /// DictionaryAttr.
+ DictionaryAttr getDiscardableAttrDictionary() { return attrs; }
+
/// Return all of the attributes on this operation.
- ArrayRef<NamedAttribute> getAttrs() { return attrs.getValue(); }
+ ArrayRef<NamedAttribute> getAttrs() {
+ if (!getPropertiesStorage())
+ return getDiscardableAttrs();
+ return getAttrDictionary().getValue();
+ }
/// Return all of the attributes on this operation as a DictionaryAttr.
- DictionaryAttr getAttrDictionary() { return attrs; }
-
- /// Set the attribute dictionary on this operation.
- void setAttrs(DictionaryAttr newAttrs) {
+ DictionaryAttr getAttrDictionary();
+
+ /// Set the attributes from a dictionary on this operation.
+ /// These methods are expensive: if the dictionary only contains discardable
+ /// attributes, `setDiscardableAttrs` is more efficient.
+ void setAttrs(DictionaryAttr newAttrs);
+ void setAttrs(ArrayRef<NamedAttribute> newAttrs);
+ /// Set the discardable attribute dictionary on this operation.
+ void setDiscardableAttrs(DictionaryAttr newAttrs) {
assert(newAttrs && "expected valid attribute dictionary");
attrs = newAttrs;
}
- void setAttrs(ArrayRef<NamedAttribute> newAttrs) {
- setAttrs(DictionaryAttr::get(getContext(), newAttrs));
+ void setDiscardableAttrs(ArrayRef<NamedAttribute> newAttrs) {
+ setDiscardableAttrs(DictionaryAttr::get(getContext(), newAttrs));
}
/// Return the specified attribute if present, null otherwise.
- Attribute getAttr(StringAttr name) { return attrs.get(name); }
- Attribute getAttr(StringRef name) { return attrs.get(name); }
+ /// These methods are expensive: if the attribute is known to be discardable,
+ /// `getDiscardableAttr` is more efficient.
+ Attribute getAttr(StringAttr name) {
+ if (getPropertiesStorageSize()) {
+ if (std::optional<Attribute> inherentAttr = getInherentAttr(name))
+ return *inherentAttr;
+ }
+ return attrs.get(name);
+ }
+ Attribute getAttr(StringRef name) {
+ if (getPropertiesStorageSize()) {
+ if (std::optional<Attribute> inherentAttr = getInherentAttr(name))
+ return *inherentAttr;
+ }
+ return attrs.get(name);
+ }
template <typename AttrClass>
AttrClass getAttrOfType(StringAttr name) {
@@ -444,8 +518,20 @@ class alignas(8) Operation final
/// Return true if the operation has an attribute with the provided name,
/// false otherwise.
- bool hasAttr(StringAttr name) { return attrs.contains(name); }
- bool hasAttr(StringRef name) { return attrs.contains(name); }
+ bool hasAttr(StringAttr name) {
+ if (getPropertiesStorageSize()) {
+ if (std::optional<Attribute> inherentAttr = getInherentAttr(name))
+ return static_cast<bool>(*inherentAttr);
+ }
+ return attrs.contains(name);
+ }
+ bool hasAttr(StringRef name) {
+ if (getPropertiesStorageSize()) {
+ if (std::optional<Attribute> inherentAttr = getInherentAttr(name))
+ return static_cast<bool>(*inherentAttr);
+ }
+ return attrs.contains(name);
+ }
template <typename AttrClass, typename NameT>
bool hasAttrOfType(NameT &&name) {
return static_cast<bool>(
@@ -455,6 +541,12 @@ class alignas(8) Operation final
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void setAttr(StringAttr name, Attribute value) {
+ if (getPropertiesStorageSize()) {
+ if (std::optional<Attribute> inherentAttr = getInherentAttr(name)) {
+ setInherentAttr(name, value);
+ return;
+ }
+ }
NamedAttrList attributes(attrs);
if (attributes.set(name, value) != value)
attrs = attributes.getDictionary(getContext());
@@ -467,6 +559,12 @@ class alignas(8) Operation final
/// attribute that was erased, or nullptr if there was no attribute with such
/// name.
Attribute removeAttr(StringAttr name) {
+ if (getPropertiesStorageSize()) {
+ if (std::optional<Attribute> inherentAttr = getInherentAttr(name)) {
+ setInherentAttr(name, {});
+ return *inherentAttr;
+ }
+ }
NamedAttrList attributes(attrs);
Attribute removedAttr = attributes.erase(name);
if (removedAttr)
@@ -511,7 +609,7 @@ class alignas(8) Operation final
return dialect_attr_iterator(attrs.end(), attrs.end());
}
- /// Set the dialect attributes for this operation, and preserve all dependent.
+ /// Set the dialect attributes for this operation, and preserve all inherent attributes.
template <typename DialectAttrT>
void setDialectAttrs(DialectAttrT &&dialectAttrs) {
NamedAttrList attrs;
@@ -735,6 +833,44 @@ class alignas(8) Operation final
/// handlers that may be listening.
InFlightDiagnostic emitRemark(const Twine &message = {});
+ /// Returns the properties storage size in bytes.
+ int getPropertiesStorageSize() const {
+ return static_cast<int>(propertiesStorageSize) * 8;
+ }
+ /// Returns the properties storage, or a null OpaqueProperties if none.
+ OpaqueProperties getPropertiesStorage() {
+ if (propertiesStorageSize)
+ return {
+ static_cast<void *>(getTrailingObjects<detail::OpProperties>())};
+ return {nullptr};
+ }
+ OpaqueProperties getPropertiesStorage() const {
+ if (propertiesStorageSize)
+ return {static_cast<void *>(const_cast<detail::OpProperties *>(
+ getTrailingObjects<detail::OpProperties>()))};
+ return {nullptr};
+ }
+
+ /// Return the properties converted to an attribute.
+ /// This is expensive, and mostly useful when dealing with unregistered
+ /// operations. Returns an empty attribute if no properties are present.
+ Attribute getPropertiesAsAttribute();
+
+ /// Set the properties from the provided attribute.
+ /// This is an expensive operation that can fail if the attribute does not
+ /// match the expectations of the properties for this operation. This is
+ /// mostly useful for unregistered operations or used when parsing the
+ /// generic format. An optional diagnostic can be passed in for richer errors.
+ LogicalResult setPropertiesFromAttribute(Attribute attr,
+ InFlightDiagnostic *diagnostic);
+
+ /// Copy properties from an existing other properties object. The two objects
+ /// must be the same type.
+ void copyProperties(OpaqueProperties rhs);
+
+ /// Compute a hash for the op properties (if any).
+ llvm::hash_code hashProperties();
+
private:
//===--------------------------------------------------------------------===//
// Ordering
@@ -758,7 +894,8 @@ class alignas(8) Operation final
private:
Operation(Location location, OperationName name, unsigned numResults,
unsigned numSuccessors, unsigned numRegions,
- DictionaryAttr attributes, bool hasOperandStorage);
+ int propertiesStorageSize, DictionaryAttr attributes,
+ OpaqueProperties properties, bool hasOperandStorage);
// Operations are deleted through the destroy() member because they are
// allocated with malloc.
@@ -845,13 +982,21 @@ class alignas(8) Operation final
const unsigned numResults;
const unsigned numSuccs;
- const unsigned numRegions : 31;
+ const unsigned numRegions : 23;
/// This bit signals whether this operation has an operand storage or not. The
/// operand storage may be elided for operations that are known to never have
/// operands.
bool hasOperandStorage : 1;
+ /// The size of the storage for properties (if any), divided by 8: since the
+ /// Properties storage will always be rounded up to the next multiple of 8 we
+ /// save some bits here.
+ unsigned char propertiesStorageSize : 8;
+ /// This is the maximum size we support to allocate properties inline with an
+ /// operation: 8 bytes times 255, the largest value the 8-bit field can hold.
+ static constexpr int64_t propertiesCapacity = 8 * 255;
+
/// This holds the name of the operation.
OperationName name;
@@ -871,8 +1016,9 @@ class alignas(8) Operation final
friend class llvm::ilist_node_with_parent<Operation, Block>;
// This stuff is used by the TrailingObjects template.
- friend llvm::TrailingObjects<Operation, detail::OperandStorage, BlockOperand,
- Region, OpOperand>;
+ friend llvm::TrailingObjects<Operation, detail::OperandStorage,
+ detail::OpProperties, BlockOperand, Region,
+ OpOperand>;
size_t numTrailingObjects(OverloadToken<detail::OperandStorage>) const {
return hasOperandStorage ? 1 : 0;
}
@@ -880,6 +1026,9 @@ class alignas(8) Operation final
return numSuccs;
}
size_t numTrailingObjects(OverloadToken<Region>) const { return numRegions; }
+ size_t numTrailingObjects(OverloadToken<detail::OpProperties>) const {
+ return getPropertiesStorageSize();
+ }
};
inline raw_ostream &operator<<(raw_ostream &os, const Operation &op) {
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 024c5d3bb4a04..3631e41c1234d 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -14,8 +14,10 @@
#ifndef MLIR_IR_OPERATIONSUPPORT_H
#define MLIR_IR_OPERATIONSUPPORT_H
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockSupport.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Types.h"
@@ -24,6 +26,7 @@
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
#include "llvm/Support/TrailingObjects.h"
#include <memory>
@@ -37,6 +40,7 @@ namespace mlir {
class Dialect;
class DictionaryAttr;
class ElementsAttr;
+struct EmptyProperties;
class MutableOperandRangeRange;
class NamedAttrList;
class Operation;
@@ -59,6 +63,25 @@ class ValueRange;
template <typename ValueRangeT>
class ValueTypeRange;
+//===----------------------------------------------------------------------===//
+// OpaqueProperties
+//===----------------------------------------------------------------------===//
+
+/// Simple wrapper around a void* in order to express generically how to pass
+/// in op properties through APIs. The wrapper does not own the storage.
+class OpaqueProperties {
+public:
+ OpaqueProperties(void *prop) : properties(prop) {}
+ operator bool() const { return properties != nullptr; }
+ template <typename Dest>
+ Dest as() const {
+ return static_cast<Dest>(properties);
+ }
+
+private:
+ void *properties;
+};
+
//===----------------------------------------------------------------------===//
// OperationName
//===----------------------------------------------------------------------===//
@@ -98,6 +121,26 @@ class OperationName {
virtual void printAssembly(Operation *, OpAsmPrinter &, StringRef) = 0;
virtual LogicalResult verifyInvariants(Operation *) = 0;
virtual LogicalResult verifyRegionInvariants(Operation *) = 0;
+ /// Implementation hooks for properties.
+ virtual std::optional<Attribute> getInherentAttr(Operation *,
+ StringRef name) = 0;
+ virtual void setInherentAttr(Operation *op, StringAttr name,
+ Attribute value) = 0;
+ virtual void populateInherentAttrs(Operation *op, NamedAttrList &attrs) = 0;
+ virtual LogicalResult
+ verifyInherentAttrs(OperationName opName, NamedAttrList &attributes,
+ function_ref<InFlightDiagnostic()> getDiag) = 0;
+ virtual int getOpPropertyByteSize() = 0;
+ virtual void initProperties(OperationName opName, OpaqueProperties storage,
+ OpaqueProperties init) = 0;
+ virtual void deleteProperties(OpaqueProperties) = 0;
+ virtual void populateDefaultProperties(OperationName opName,
+ OpaqueProperties properties) = 0;
+ virtual LogicalResult setPropertiesFromAttr(Operation *, Attribute,
+ InFlightDiagnostic *) = 0;
+ virtual Attribute getPropertiesAsAttr(Operation *) = 0;
+ virtual void copyProperties(OpaqueProperties, OpaqueProperties) = 0;
+ virtual llvm::hash_code hashProperties(OpaqueProperties) = 0;
};
public:
@@ -158,6 +201,25 @@ class OperationName {
void printAssembly(Operation *, OpAsmPrinter &, StringRef) final;
LogicalResult verifyInvariants(Operation *) final;
LogicalResult verifyRegionInvariants(Operation *) final;
+ /// Implementation hooks for properties.
+ std::optional<Attribute> getInherentAttr(Operation *op,
+ StringRef name) final;
+ void setInherentAttr(Operation *op, StringAttr name, Attribute value) final;
+ void populateInherentAttrs(Operation *op, NamedAttrList &attrs) final;
+ LogicalResult
+ verifyInherentAttrs(OperationName opName, NamedAttrList &attributes,
+ function_ref<InFlightDiagnostic()> getDiag) final;
+ int getOpPropertyByteSize() final;
+ void initProperties(OperationName opName, OpaqueProperties storage,
+ OpaqueProperties init) final;
+ void deleteProperties(OpaqueProperties) final;
+ void populateDefaultProperties(OperationName opName,
+ OpaqueProperties properties) final;
+ LogicalResult setPropertiesFromAttr(Operation *, Attribute,
+ InFlightDiagnostic *) final;
+ Attribute getPropertiesAsAttr(Operation *) final;
+ void copyProperties(OpaqueProperties, OpaqueProperties) final;
+ llvm::hash_code hashProperties(OpaqueProperties) final;
};
public:
@@ -309,6 +371,68 @@ class OperationName {
return !isRegistered() || hasInterface(interfaceID);
}
+ /// Lookup an inherent attribute by name; this method isn't recommended
+ /// and may be removed in the future.
+ std::optional<Attribute> getInherentAttr(Operation *op,
+ StringRef name) const {
+ return getImpl()->getInherentAttr(op, name);
+ }
+
+ void setInherentAttr(Operation *op, StringAttr name, Attribute value) const {
+ return getImpl()->setInherentAttr(op, name, value);
+ }
+
+ void populateInherentAttrs(Operation *op, NamedAttrList &attrs) const {
+ return getImpl()->populateInherentAttrs(op, attrs);
+ }
+ /// This method exists for backward compatibility purposes when using
+ /// properties to store inherent attributes: it enables validating the
+ /// attributes when parsed from the older generic syntax pre-Properties.
+ LogicalResult
+ verifyInherentAttrs(NamedAttrList &attributes,
+ function_ref<InFlightDiagnostic()> getDiag) const {
+ return getImpl()->verifyInherentAttrs(*this, attributes, getDiag);
+ }
+ /// This hook returns the number of bytes to allocate for the op properties.
+ int getOpPropertyByteSize() const {
+ return getImpl()->getOpPropertyByteSize();
+ }
+
+ /// This hook destroys the op properties.
+ void destroyOpProperties(OpaqueProperties properties) const {
+ getImpl()->deleteProperties(properties);
+ }
+
+ /// Initialize the op properties.
+ void initOpProperties(OpaqueProperties storage, OpaqueProperties init) const {
+ getImpl()->initProperties(*this, storage, init);
+ }
+
+ /// Set the default values of the ODS attributes in the properties.
+ void populateDefaultProperties(OpaqueProperties properties) const {
+ getImpl()->populateDefaultProperties(*this, properties);
+ }
+
+ /// Return the op properties converted to an Attribute.
+ Attribute getOpPropertiesAsAttribute(Operation *op) const {
+ return getImpl()->getPropertiesAsAttr(op);
+ }
+
+ /// Define the op properties from the provided Attribute.
+ LogicalResult
+ setOpPropertiesFromAttribute(Operation *op, Attribute properties,
+ InFlightDiagnostic *diagnostic) const {
+ return getImpl()->setPropertiesFromAttr(op, properties, diagnostic);
+ }
+
+ void copyOpProperties(OpaqueProperties lhs, OpaqueProperties rhs) const {
+ return getImpl()->copyProperties(lhs, rhs);
+ }
+
+ llvm::hash_code hashOpProperties(OpaqueProperties properties) const {
+ return getImpl()->hashProperties(properties);
+ }
+
/// Return the dialect this operation is registered to if the dialect is
/// loaded in the context, or nullptr if the dialect isn't loaded.
Dialect *getDialect() const {
@@ -413,6 +537,104 @@ class RegisteredOperationName : public OperationName {
LogicalResult verifyRegionInvariants(Operation *op) final {
return ConcreteOp::getVerifyRegionInvariantsFn()(op);
}
+
+ /// Implementation for "Properties"
+
+ using Properties = std::remove_reference_t<
+ decltype(std::declval<ConcreteOp>().getProperties())>;
+
+ std::optional<Attribute> getInherentAttr(Operation *op,
+ StringRef name) final {
+ if constexpr (hasProperties) {
+ auto concreteOp = cast<ConcreteOp>(op);
+ return ConcreteOp::getInherentAttr(concreteOp.getProperties(), name);
+ }
+ // If the op does not have support for properties, we dispatch back to the
+ // dictionary of discardable attributes for now.
+ return cast<ConcreteOp>(op)->getDiscardableAttr(name);
+ }
+ void setInherentAttr(Operation *op, StringAttr name,
+ Attribute value) final {
+ if constexpr (hasProperties) {
+ auto concreteOp = cast<ConcreteOp>(op);
+ return ConcreteOp::setInherentAttr(concreteOp.getProperties(), name,
+ value);
+ }
+ // If the op does not have support for properties, we dispatch back to the
+ // dictionary of discardable attributes for now.
+ return cast<ConcreteOp>(op)->setDiscardableAttr(name, value);
+ }
+ void populateInherentAttrs(Operation *op, NamedAttrList &attrs) final {
+ if constexpr (hasProperties) {
+ auto concreteOp = cast<ConcreteOp>(op);
+ ConcreteOp::populateInherentAttrs(concreteOp.getProperties(), attrs);
+ }
+ }
+ LogicalResult
+ verifyInherentAttrs(OperationName opName, NamedAttrList &attributes,
+ function_ref<InFlightDiagnostic()> getDiag) final {
+ if constexpr (hasProperties)
+ return ConcreteOp::verifyInherentAttrs(opName, attributes, getDiag);
+ return success();
+ }
+ // Detect if the concrete operation defines properties.
+ static constexpr bool hasProperties = !std::is_same_v<
+ typename ConcreteOp::template InferredProperties<ConcreteOp>,
+ EmptyProperties>;
+
+ int getOpPropertyByteSize() final {
+ if constexpr (hasProperties)
+ return sizeof(Properties);
+ return 0;
+ }
+ void initProperties(OperationName opName, OpaqueProperties storage,
+ OpaqueProperties init) final {
+ using Properties =
+ typename ConcreteOp::template InferredProperties<ConcreteOp>;
+ if (init)
+ new (storage.as<Properties *>()) Properties(*init.as<Properties *>());
+ else
+ new (storage.as<Properties *>()) Properties();
+ if constexpr (hasProperties)
+ ConcreteOp::populateDefaultProperties(opName,
+ *storage.as<Properties *>());
+ }
+ void deleteProperties(OpaqueProperties prop) final {
+ prop.as<Properties *>()->~Properties();
+ }
+ void populateDefaultProperties(OperationName opName,
+ OpaqueProperties properties) final {
+ if constexpr (hasProperties)
+ ConcreteOp::populateDefaultProperties(opName,
+ *properties.as<Properties *>());
+ }
+
+ LogicalResult setPropertiesFromAttr(Operation *op, Attribute attr,
+ InFlightDiagnostic *diag) final {
+ if constexpr (hasProperties)
+ return ConcreteOp::setPropertiesFromAttr(
+ cast<ConcreteOp>(op).getProperties(), attr, diag);
+ if (diag)
+ *diag << "This operation does not support properties";
+ return failure();
+ }
+ Attribute getPropertiesAsAttr(Operation *op) final {
+ if constexpr (hasProperties) {
+ auto concreteOp = cast<ConcreteOp>(op);
+ return ConcreteOp::getPropertiesAsAttr(concreteOp->getContext(),
+ concreteOp.getProperties());
+ }
+ return {};
+ }
+ void copyProperties(OpaqueProperties lhs, OpaqueProperties rhs) final {
+ *lhs.as<Properties *>() = *rhs.as<Properties *>();
+ }
+ llvm::hash_code hashProperties(OpaqueProperties prop) final {
+ if constexpr (hasProperties)
+ return ConcreteOp::computePropertiesHash(*prop.as<Properties *>());
+
+ return {};
+ }
};
/// Lookup the registered operation information for the given operation.
@@ -600,6 +822,11 @@ class NamedAttrList {
assign(range.begin(), range.end());
}
+ void clear() {
+ attrs.clear();
+ dictionarySorted.setPointerAndInt(nullptr, false);
+ }
+
bool empty() const { return attrs.empty(); }
void reserve(size_type N) { attrs.reserve(N); }
@@ -694,6 +921,19 @@ struct OperationState {
/// Regions that the op will hold.
SmallVector<std::unique_ptr<Region>, 1> regions;
+ // If we're creating an unregistered operation, this Attribute is used to
+ // build the properties. Otherwise it is ignored. For registered operations
+ // see the `getOrAddProperties` method.
+ Attribute propertiesAttr;
+
+private:
+ OpaqueProperties properties = nullptr;
+ TypeID propertiesId;
+ // Plain function pointers: a function_ref would dangle on a lambda temporary.
+ void (*propertiesDeleter)(OpaqueProperties) = nullptr;
+ void (*propertiesSetter)(OpaqueProperties, const OpaqueProperties) = nullptr;
+ friend class Operation;
+
public:
OperationState(Location location, StringRef name);
OperationState(Location location, OperationName name);
@@ -706,6 +946,37 @@ struct OperationState {
TypeRange types, ArrayRef<NamedAttribute> attributes = {},
BlockRange successors = {},
MutableArrayRef<std::unique_ptr<Region>> regions = {});
+ OperationState(OperationState &&other) = default;
+ OperationState(const OperationState &other) = default;
+ OperationState &operator=(OperationState &&other) = default;
+ OperationState &operator=(const OperationState &other) = default;
+ ~OperationState();
+
+ /// Get (or create) a properties object of the provided type to be set on the
+ /// operation on creation.
+ template <typename T>
+ T &getOrAddProperties() {
+ if (!properties) {
+ T *p = new T{};
+ properties = p;
+ propertiesDeleter = [](OpaqueProperties prop) {
+ delete prop.as<const T *>();
+ };
+ propertiesSetter = [](OpaqueProperties newProp,
+ const OpaqueProperties prop) {
+ *newProp.as<T *>() = *prop.as<const T *>();
+ };
+ propertiesId = TypeID::get<T>();
+ }
+ assert(propertiesId == TypeID::get<T>() && "Inconsistent properties");
+ return *properties.as<T *>();
+ }
+ OpaqueProperties getRawProperties() { return properties; }
+
+ // Set the properties defined on this OperationState on the given operation,
+ // optionally emit diagnostics on error through the provided diagnostic.
+ LogicalResult setProperties(Operation *op,
+ InFlightDiagnostic *diagnostic) const;
void addOperands(ValueRange newOperands);
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index b63d8b6be6739..acfe40484cf63 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -244,11 +244,11 @@ LogicalResult inferReturnTensorTypes(
function_ref<
LogicalResult(MLIRContext *, std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
- RegionRange regions,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &retComponents)>
componentTypeFn,
MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes);
/// Verifies that the inferred result types match the actual result types for
@@ -281,7 +281,7 @@ class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
static LogicalResult
inferReturnTypes(MLIRContext *context, std::optional<Location> location,
ValueRange operands, DictionaryAttr attributes,
- RegionRange regions,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
static_assert(
ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(),
@@ -291,7 +291,7 @@ class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
"requires InferTypeOpInterface to ensure succesful invocation");
return ::mlir::detail::inferReturnTensorTypes(
ConcreteType::inferReturnTypeComponents, context, location, operands,
- attributes, regions, inferredReturnTypes);
+ attributes, properties, regions, inferredReturnTypes);
}
};
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 9f7118172389f..6925e39b0a9c6 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -44,6 +44,7 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
"::std::optional<::mlir::Location>":$location,
"::mlir::ValueRange":$operands,
"::mlir::DictionaryAttr":$attributes,
+ "::mlir::OpaqueProperties":$properties,
"::mlir::RegionRange":$regions,
"::llvm::SmallVectorImpl<::mlir::Type>&":$inferredReturnTypes)
>,
@@ -75,13 +76,14 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
"::std::optional<::mlir::Location>":$location,
"::mlir::ValueRange":$operands,
"::mlir::DictionaryAttr":$attributes,
+ "::mlir::OpaqueProperties":$properties,
"::mlir::RegionRange":$regions,
"::llvm::SmallVectorImpl<::mlir::Type>&":$returnTypes),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
llvm::SmallVector<Type, 4> inferredReturnTypes;
if (failed(ConcreteOp::inferReturnTypes(context, location, operands,
- attributes, regions,
+ attributes, properties, regions,
inferredReturnTypes)))
return failure();
if (!ConcreteOp::isCompatibleReturnTypes(inferredReturnTypes,
@@ -147,6 +149,7 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
"::std::optional<::mlir::Location>":$location,
"::mlir::ValueShapeRange":$operands,
"::mlir::DictionaryAttr":$attributes,
+ "::mlir::OpaqueProperties":$properties,
"::mlir::RegionRange":$regions,
"::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents>&":
$inferredReturnShapes),
diff --git a/mlir/include/mlir/TableGen/Argument.h b/mlir/include/mlir/TableGen/Argument.h
index 5a8f7882c00d6..a2c129b4f28d8 100644
--- a/mlir/include/mlir/TableGen/Argument.h
+++ b/mlir/include/mlir/TableGen/Argument.h
@@ -22,6 +22,7 @@
#define MLIR_TABLEGEN_ARGUMENT_H_
#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Property.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/PointerUnion.h"
#include <string>
@@ -58,8 +59,9 @@ struct NamedTypeConstraint {
TypeConstraint constraint;
};
-// Operation argument: either attribute or operand
-using Argument = llvm::PointerUnion<NamedAttribute *, NamedTypeConstraint *>;
+// Operation argument: either attribute, property, or operand
+using Argument = llvm::PointerUnion<NamedAttribute *, NamedProperty *,
+ NamedTypeConstraint *>;
} // namespace tblgen
} // namespace mlir
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 10dfc01961a0f..27b5d5c57e913 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -576,7 +576,12 @@ class ExtraClassDeclaration
public:
/// Create an extra class declaration.
ExtraClassDeclaration(StringRef extraClassDeclaration,
- StringRef extraClassDefinition = "")
+ std::string extraClassDefinition = "")
+ : ExtraClassDeclaration(extraClassDeclaration.str(),
+ std::move(extraClassDefinition)) {}
+
+ ExtraClassDeclaration(std::string extraClassDeclaration,
+ std::string extraClassDefinition = "")
- : extraClassDeclaration(extraClassDeclaration),
- extraClassDefinition(extraClassDefinition) {}
+ : extraClassDeclaration(std::move(extraClassDeclaration)),
+ extraClassDefinition(std::move(extraClassDefinition)) {}
@@ -590,7 +595,7 @@ class ExtraClassDeclaration
private:
/// The string of the extra class declarations. It is re-indented before
/// printed.
- StringRef extraClassDeclaration;
+ std::string extraClassDeclaration;
/// The string of the extra class definitions. It is re-indented before
/// printed.
std::string extraClassDefinition;
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index d85342c742e2e..5337bd3beb5f9 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -86,6 +86,10 @@ class Dialect {
/// operations or types.
bool isExtensible() const;
+ /// Returns whether operations in this dialect default to using properties
+ /// for storing their inherent Attributes.
+ bool usePropertiesForAttributes() const;
+
// Returns whether two dialects are equal by checking the equality of the
// underlying record.
bool operator==(const Dialect &other) const;
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index 88dff220d6378..c0d483addeb48 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -18,6 +18,7 @@
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Builder.h"
#include "mlir/TableGen/Dialect.h"
+#include "mlir/TableGen/Property.h"
#include "mlir/TableGen/Region.h"
#include "mlir/TableGen/Successor.h"
#include "mlir/TableGen/Trait.h"
@@ -166,10 +167,14 @@ class Operator {
unsigned getNumVariableLengthResults() const;
/// Op attribute iterators.
- using attribute_iterator = const NamedAttribute *;
- attribute_iterator attribute_begin() const;
- attribute_iterator attribute_end() const;
- llvm::iterator_range<attribute_iterator> getAttributes() const;
+ using const_attribute_iterator = const NamedAttribute *;
+ const_attribute_iterator attribute_begin() const;
+ const_attribute_iterator attribute_end() const;
+ llvm::iterator_range<const_attribute_iterator> getAttributes() const;
+ using attribute_iterator = NamedAttribute *;
+ attribute_iterator attribute_begin();
+ attribute_iterator attribute_end();
+ llvm::iterator_range<attribute_iterator> getAttributes();
int getNumAttributes() const { return attributes.size(); }
int getNumNativeAttributes() const { return numNativeAttributes; }
@@ -185,6 +190,27 @@ class Operator {
const_value_iterator operand_end() const;
const_value_range getOperands() const;
+ // Op properties iterators.
+ using const_property_iterator = const NamedProperty *;
+ const_property_iterator properties_begin() const {
+ return properties.begin();
+ }
+ const_property_iterator properties_end() const { return properties.end(); }
+ llvm::iterator_range<const_property_iterator> getProperties() const {
+ return properties;
+ }
+ using property_iterator = NamedProperty *;
+ property_iterator properties_begin() { return properties.begin(); }
+ property_iterator properties_end() { return properties.end(); }
+ llvm::iterator_range<property_iterator> getProperties() { return properties; }
+ int getNumCoreAttributes() const { return properties.size(); }
+
+ // Op properties accessors (index into the op's property list).
+ NamedProperty &getProperty(int index) { return properties[index]; }
+ const NamedProperty &getProperty(int index) const {
+ return properties[index];
+ }
+
+
int getNumOperands() const { return operands.size(); }
NamedTypeConstraint &getOperand(int index) { return operands[index]; }
const NamedTypeConstraint &getOperand(int index) const {
@@ -353,6 +379,9 @@ class Operator {
/// computed upon request).
SmallVector<NamedAttribute, 4> attributes;
+ /// The properties of the op.
+ SmallVector<NamedProperty> properties;
+
/// The arguments of the op (operands and native attributes).
SmallVector<Argument, 4> arguments;
diff --git a/mlir/include/mlir/TableGen/Property.h b/mlir/include/mlir/TableGen/Property.h
new file mode 100644
index 0000000000000..7118ca496de07
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Property.h
@@ -0,0 +1,86 @@
+//===- Property.h - Property wrapper class --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Property wrapper to simplify using a TableGen Record that defines an MLIR
+// Property.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_PROPERTY_H_
+#define MLIR_TABLEGEN_PROPERTY_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Constraint.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class DefInit;
+class Record;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+class Dialect;
+class Type;
+
+// Wrapper class providing helper methods for accessing an MLIR Property
+// defined in TableGen. This class should closely reflect what is defined as
+// class `Property` in TableGen.
+class Property {
+public:
+ explicit Property(const llvm::Record *record);
+ explicit Property(const llvm::DefInit *init);
+
+ // Returns the storage type of this property.
+ StringRef getStorageType() const;
+
+ // Returns the interface type for this property.
+ StringRef getInterfaceType() const;
+
+ // Returns the template getter call that reads this property's storage and
+ // converts the stored value to the desired return type.
+ StringRef getConvertFromStorageCall() const;
+
+ // Returns the template setter call that takes a value of the provided
+ // interface type and assigns it to this property's storage.
+ StringRef getAssignToStorageCall() const;
+
+ // Returns the conversion call that reads this property's storage value
+ // and builds an Attribute from it.
+ StringRef getConvertToAttributeCall() const;
+
+ // Returns the conversion call that reads an Attribute and assigns the
+ // converted value to this property's storage.
+ StringRef getConvertFromAttributeCall() const;
+
+ // Returns the code to compute the hash for this property.
+ StringRef getHashPropertyCall() const;
+
+ // Returns whether this Property has a default value.
+ bool hasDefaultValue() const;
+ // Returns the default value for this Property.
+ StringRef getDefaultValue() const;
+
+ // Returns the TableGen definition this Property was constructed from.
+ const llvm::Record &getDef() const;
+
+private:
+ // The TableGen definition of this property.
+ const llvm::Record *def;
+};
+
+// A struct wrapping an op property and its name together.
+struct NamedProperty {
+ llvm::StringRef name;
+ Property prop;
+};
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_PROPERTY_H_
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 2821686655c2f..c31ac29a4cd84 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -526,9 +526,14 @@ class OpConversionPattern : public ConversionPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- return matchAndRewrite(cast<SourceOp>(op),
- OpAdaptor(operands, op->getAttrDictionary()),
- rewriter);
+ auto sourceOp = cast<SourceOp>(op);
+ if constexpr (SourceOp::hasProperties()) // Forward the op's inline properties.
+ return matchAndRewrite(sourceOp,
+ OpAdaptor(operands, op->getAttrDictionary(),
+ sourceOp.getProperties()), rewriter);
+ else // Discarded for property ops: only the matching adaptor is instantiated.
+ return matchAndRewrite(
+ sourceOp, OpAdaptor(operands, op->getAttrDictionary()), rewriter);
}
/// Rewrite and Match methods that operate on the SourceOp type. These must be
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index 80e61e7a817ca..47bd276eed61f 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -225,13 +225,16 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
public:
using RangeT = ArrayRef<ValueRange>;
using BaseT = typename SourceOp::template GenericAdaptor<RangeT>;
+ using Properties = typename SourceOp::template InferredProperties<SourceOp>;
OpAdaptor(const OneToNTypeMapping *operandMapping,
const OneToNTypeMapping *resultMapping,
const ValueRange *convertedOperands, RangeT values,
- DictionaryAttr attrs = nullptr, RegionRange regions = {})
- : BaseT(values, attrs, regions), operandMapping(operandMapping),
- resultMapping(resultMapping), convertedOperands(convertedOperands) {}
+ DictionaryAttr attrs = nullptr, const Properties &properties = {},
+ RegionRange regions = {}) // A `{}` default can only bind to a const reference.
+ : BaseT(values, attrs, properties, regions),
+ operandMapping(operandMapping), resultMapping(resultMapping),
+ convertedOperands(convertedOperands) {}
/// Get the type mapping of the original operands to the converted operands.
const OneToNTypeMapping &getOperandMapping() const {
@@ -271,7 +274,8 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
valueRanges.push_back(values);
}
OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands,
- valueRanges, op->getAttrDictionary(), op->getRegions());
+ valueRanges, op->getAttrDictionary(),
+ cast<SourceOp>(op).getProperties(), op->getRegions());
// Call overload implemented by the derived class.
return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 46c1ae976d9d4..3186b0aa420b9 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -540,6 +540,7 @@ class OperationParser : public Parser {
std::optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
std::nullopt,
std::optional<ArrayRef<NamedAttribute>> parsedAttributes = std::nullopt,
+ std::optional<Attribute> propertiesAttribute = std::nullopt,
std::optional<FunctionType> parsedFnType = std::nullopt);
/// Parse an operation instance that is in the generic form and insert it at
@@ -1075,7 +1076,8 @@ Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
auto name = OperationName("builtin.unrealized_conversion_cast", getContext());
auto *op = Operation::create(
getEncodedSourceLocation(loc), name, type, /*operands=*/{},
- /*attributes=*/std::nullopt, /*successors=*/{}, /*numRegions=*/0);
+ /*attributes=*/std::nullopt, /*properties=*/nullptr, /*successors=*/{},
+ /*numRegions=*/0);
forwardRefPlaceholders[op->getResult(0)] = loc;
return op->getResult(0);
}
@@ -1255,6 +1257,7 @@ ParseResult OperationParser::parseGenericOperationAfterOpName(
std::optional<ArrayRef<Block *>> parsedSuccessors,
std::optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions,
std::optional<ArrayRef<NamedAttribute>> parsedAttributes,
+ std::optional<Attribute> propertiesAttribute,
std::optional<FunctionType> parsedFnType) {
// Parse the operand list, if not explicitly provided.
@@ -1284,6 +1287,16 @@ ParseResult OperationParser::parseGenericOperationAfterOpName(
result.addSuccessors(*parsedSuccessors);
}
+ // Parse the properties, if not explicitly provided.
+ if (propertiesAttribute) {
+ result.propertiesAttr = *propertiesAttribute;
+ } else if (consumeIf(Token::less)) {
+ result.propertiesAttr = parseAttribute();
+ if (!result.propertiesAttr)
+ return failure();
+ if (parseToken(Token::greater, "expected '>' to close properties"))
+ return failure();
+ }
// Parse the region list, if not explicitly provided.
if (!parsedRegions) {
if (consumeIf(Token::l_paren)) {
@@ -1390,10 +1403,52 @@ Operation *OperationParser::parseGenericOperation() {
if (parseGenericOperationAfterOpName(result))
return nullptr;
+ // Operation::create() is not allowed to fail, however setting the properties
+ // from an attribute is a failable operation. So we save the attribute here
+ // and set it on the operation post-parsing.
+ Attribute properties;
+ std::swap(properties, result.propertiesAttr);
+
+ // If we don't have properties in the textual IR, but the operation now has
+ // support for properties, we support some backward-compatible generic syntax
+ // for the operation and as such we accept inherent attributes mixed in the
+ // dictionary of discardable attributes. We pre-validate these here because
+ // invalid attributes can't be casted to the properties storage and will be
+ // silently dropped. For example an attribute { foo = 0 : i32 } that is
+ // declared as F32Attr in ODS would have a C++ type of FloatAttr in the
+ // properties array. When setting it we would do something like:
+ //
+ // properties.foo = dyn_cast<FloatAttr>(fooAttr);
+ //
+ // which would end up with a null Attribute. The diagnostic from the verifier
+ // would be "missing foo attribute" instead of something like "expects a 32
+ // bits float attribute but got a 32 bits integer attribute".
+ if (!properties && !result.getRawProperties()) {
+ Optional<RegisteredOperationName> info = result.name.getRegisteredInfo();
+ if (info) {
+ if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
+ return mlir::emitError(srcLocation) << "'" << name << "' op ";
+ })))
+ return nullptr;
+ }
+ }
+
// Create the operation and try to parse a location for it.
Operation *op = opBuilder.create(result);
if (parseTrailingLocationSpecifier(op))
return nullptr;
+
+ // Try setting the properties for the operation, using a diagnostic to print
+ // errors.
+ if (properties) {
+ InFlightDiagnostic diagnostic =
+ mlir::emitError(srcLocation, "invalid properties ")
+ << properties << " for op " << name << ": ";
+ if (failed(op->setPropertiesFromAttribute(properties, &diagnostic)))
+ return nullptr;
+ diagnostic.abandon();
+ }
+
return op;
}
@@ -1461,10 +1516,11 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
std::optional<ArrayRef<Block *>> parsedSuccessors,
std::optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions,
std::optional<ArrayRef<NamedAttribute>> parsedAttributes,
+ std::optional<Attribute> parsedPropertiesAttribute,
std::optional<FunctionType> parsedFnType) final {
return parser.parseGenericOperationAfterOpName(
result, parsedUnresolvedOperands, parsedSuccessors, parsedRegions,
- parsedAttributes, parsedFnType);
+ parsedAttributes, parsedPropertiesAttribute, parsedFnType);
}
//===--------------------------------------------------------------------===//
// Utilities
@@ -1933,10 +1989,23 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
if (opAsmParser.didEmitError())
return nullptr;
+ Attribute properties = opState.propertiesAttr;
+ opState.propertiesAttr = Attribute{};
+
// Otherwise, create the operation and try to parse a location for it.
Operation *op = opBuilder.create(opState);
if (parseTrailingLocationSpecifier(op))
return nullptr;
+
+ // Try setting the properties for the operation.
+ if (properties) {
+ InFlightDiagnostic diagnostic =
+ mlir::emitError(srcLocation, "invalid properties ")
+ << properties << " for op " << op->getName().getStringRef() << ": ";
+ if (failed(op->setPropertiesFromAttribute(properties, &diagnostic)))
+ return nullptr;
+ diagnostic.abandon();
+ }
return op;
}
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 0069bf10263a2..6ed32e1ce8656 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -340,7 +340,8 @@ static LogicalResult inferOperationTypes(OperationState &state) {
if (succeeded(inferInterface->inferReturnTypes(
context, state.location, state.operands,
- state.attributes.getDictionary(context), state.regions, state.types)))
+ state.attributes.getDictionary(context), state.getRawProperties(),
+ state.regions, state.types)))
return success();
// Diagnostic emitted by interface.
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
index cb3cebef79730..0890bf2677626 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -113,6 +113,7 @@ struct GpuAsyncRegionPass::ThreadTokenCallback {
resultTypes.push_back(tokenType);
auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes,
op->getOperands(), op->getAttrDictionary(),
+ op->getPropertiesStorage(),
op->getSuccessors(), op->getNumRegions());
// Clone regions into new op.
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a828fb6a7a679..6b0d2b5d014d5 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
@@ -1354,9 +1355,10 @@ void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
/// shape of the source.
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
- ExtractStridedMetadataOpAdaptor extractAdaptor(operands, attributes, regions);
+ ExtractStridedMetadataOpAdaptor extractAdaptor(
+ operands, attributes, *properties.as<EmptyProperties *>(), regions);
auto sourceType = extractAdaptor.getSource().getType().dyn_cast<MemRefType>();
if (!sourceType)
return failure();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index fcddb1ea14120..ea372bffbc0b7 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -833,7 +833,8 @@ class ExtractStridedMetadataOpReinterpretCastFolder
SmallVector<Type> inferredReturnTypes;
if (failed(extractStridedMetadataOp.inferReturnTypes(
rewriter.getContext(), loc, {reinterpretCastOp.getSource()},
- /*attributes=*/{}, /*regions=*/{}, inferredReturnTypes)))
+ /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
+ inferredReturnTypes)))
return rewriter.notifyMatchFailure(
reinterpretCastOp, "reinterpret_cast source's type is incompatible");
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 76b3589b15f49..cbd98f39b4068 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1779,7 +1779,7 @@ bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
LogicalResult
IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
ValueRange operands, DictionaryAttr attrs,
- RegionRange regions,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (regions.empty())
return failure();
@@ -1872,7 +1872,8 @@ void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
MLIRContext *ctx = builder.getContext();
auto attrDict = DictionaryAttr::get(ctx, result.attributes);
if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
- result.regions, inferredReturnTypes))) {
+ /*properties=*/nullptr, result.regions,
+ inferredReturnTypes))) {
result.addTypes(inferredReturnTypes);
}
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 3417388d0bb92..d198b00b9caba 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -393,7 +393,7 @@ void AssumingOp::build(
LogicalResult mlir::shape::AddOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType().isa<SizeType>() ||
operands[1].getType().isa<SizeType>())
@@ -911,7 +911,7 @@ void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
Builder b(context);
auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
@@ -1092,7 +1092,7 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
LogicalResult mlir::shape::DimOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
DimOpAdaptor dimOp(operands);
inferredReturnTypes.assign({dimOp.getIndex().getType()});
@@ -1140,7 +1140,7 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
LogicalResult mlir::shape::DivOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType().isa<SizeType>() ||
operands[1].getType().isa<SizeType>())
@@ -1361,7 +1361,7 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.assign({IndexType::get(context)});
return success();
@@ -1399,7 +1399,7 @@ OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands.empty())
return failure();
@@ -1535,7 +1535,7 @@ void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
LogicalResult mlir::shape::RankOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType().isa<ShapeType>())
inferredReturnTypes.assign({SizeType::get(context)});
@@ -1571,7 +1571,7 @@ OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType().isa<ShapeType>())
inferredReturnTypes.assign({SizeType::get(context)});
@@ -1603,7 +1603,7 @@ OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
LogicalResult mlir::shape::MaxOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType() == operands[1].getType())
inferredReturnTypes.assign({operands[0].getType()});
@@ -1635,7 +1635,7 @@ OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
LogicalResult mlir::shape::MinOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType() == operands[1].getType())
inferredReturnTypes.assign({operands[0].getType()});
@@ -1672,7 +1672,7 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
LogicalResult mlir::shape::MulOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType().isa<SizeType>() ||
operands[1].getType().isa<SizeType>())
@@ -1759,7 +1759,7 @@ void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType().isa<ValueShapeType>())
inferredReturnTypes.assign({ShapeType::get(context)});
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 2da687b81663c..903200f5f5cd5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -364,7 +364,8 @@ static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
IntegerAttr axis = attributes.get("axis").cast<IntegerAttr>();
@@ -389,7 +390,8 @@ LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
@@ -415,7 +417,8 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(0)));
inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(1)));
@@ -424,7 +427,8 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Infer all dimension sizes by reducing based on inputs.
int32_t axis =
@@ -484,7 +488,8 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outShape;
if (resolveBroadcastShape(operands, outShape).failed()) {
@@ -505,7 +510,8 @@ bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
ShapeAdaptor weightShape = operands.getShape(1);
@@ -536,7 +542,8 @@ LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }
LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor lhsShape = operands.getShape(0);
ShapeAdaptor rhsShape = operands.getShape(1);
@@ -562,7 +569,8 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
LogicalResult tosa::PadOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
ShapeAdaptor paddingShape = operands.getShape(1);
@@ -624,7 +632,8 @@ static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
inferredReturnShapes.push_back(ShapedTypeComponents(
convertToMlirShape(SliceOpAdaptor(operands, attributes).getSize())));
@@ -633,7 +642,8 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
LogicalResult tosa::TableOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
@@ -649,7 +659,8 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents(
LogicalResult tosa::TileOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
TileOpAdaptor adaptor(operands, attributes);
ArrayRef<int64_t> multiples = adaptor.getMultiples();
@@ -682,7 +693,8 @@ bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ReshapeOpAdaptor adaptor(operands, attributes);
ShapeAdaptor inputShape = operands.getShape(0);
@@ -751,7 +763,8 @@ LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
ShapeAdaptor permsShape = operands.getShape(1);
@@ -818,7 +831,8 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape;
outputShape.resize(3, ShapedType::kDynamic);
@@ -843,7 +857,8 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ResizeOpAdaptor adaptor(operands, attributes);
llvm::SmallVector<int64_t, 4> outputShape;
@@ -883,7 +898,8 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape;
outputShape.resize(3, ShapedType::kDynamic);
@@ -942,7 +958,7 @@ static LogicalResult ReduceInferReturnTypes(
LogicalResult OP::inferReturnTypeComponents( \
MLIRContext *context, ::std::optional<Location> location, \
ValueShapeRange operands, DictionaryAttr attributes, \
- RegionRange regions, \
+ OpaqueProperties properties, RegionRange regions, \
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
Type inputType = \
operands.getType()[0].cast<TensorType>().getElementType(); \
@@ -978,7 +994,7 @@ static LogicalResult NAryInferReturnTypes(
LogicalResult OP::inferReturnTypeComponents( \
MLIRContext *context, ::std::optional<Location> location, \
ValueShapeRange operands, DictionaryAttr attributes, \
- RegionRange regions, \
+ OpaqueProperties properties, RegionRange regions, \
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
return NAryInferReturnTypes(operands, inferredReturnShapes); \
}
@@ -1062,7 +1078,8 @@ static LogicalResult poolingInferReturnTypes(
LogicalResult Conv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);
@@ -1125,7 +1142,8 @@ LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
LogicalResult Conv3DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
Conv3DOp::Adaptor adaptor(operands.getValues(), attributes);
@@ -1198,21 +1216,24 @@ LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
}
LogicalResult MaxPool2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
}
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
DepthwiseConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
@@ -1288,7 +1309,8 @@ LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
TransposeConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
// outputShape is mutable.
@@ -1353,7 +1375,8 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
LogicalResult IfOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<tosa::YieldOp> yieldOps;
for (Region *region : regions) {
@@ -1397,7 +1420,8 @@ LogicalResult IfOp::inferReturnTypeComponents(
LogicalResult WhileOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<tosa::YieldOp> yieldOps;
for (auto &block : *regions[1])
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 0cebdd960a8a6..74533defd055f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -39,6 +39,7 @@ TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
if (shapeInterface
.inferReturnTypeComponents(op.getContext(), op.getLoc(),
op->getOperands(), op->getAttrDictionary(),
+ op->getPropertiesStorage(),
op->getRegions(), returnedShapes)
.failed())
return op;
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index bd6069468d7f7..0c03cecf61bc4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -218,9 +218,9 @@ void propagateShapesInRegion(Region ®ion) {
ValueShapeRange range(op.getOperands(), operandShape);
if (shapeInterface
- .inferReturnTypeComponents(op.getContext(), op.getLoc(), range,
- op.getAttrDictionary(),
- op.getRegions(), returnedShapes)
+ .inferReturnTypeComponents(
+ op.getContext(), op.getLoc(), range, op.getAttrDictionary(),
+ op.getPropertiesStorage(), op.getRegions(), returnedShapes)
.succeeded()) {
for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
Value result = std::get<0>(it);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 439d3780fb8ed..99eed540afb0e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1152,7 +1152,7 @@ void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
ValueRange operands, DictionaryAttr attributes,
- RegionRange,
+ OpaqueProperties properties, RegionRange,
SmallVectorImpl<Type> &inferredReturnTypes) {
ExtractOp::Adaptor op(operands, attributes);
auto vectorType = op.getVector().getType().cast<VectorType>();
@@ -2084,7 +2084,7 @@ LogicalResult ShuffleOp::verify() {
LogicalResult
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
ValueRange operands, DictionaryAttr attributes,
- RegionRange,
+ OpaqueProperties properties, RegionRange,
SmallVectorImpl<Type> &inferredReturnTypes) {
ShuffleOp::Adaptor op(operands, attributes);
auto v1Type = op.getV1().getType().cast<VectorType>();
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f0fa86e7e49e2..ceb4fd814a1e2 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2575,7 +2575,6 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
if (!filteredAttrs.empty())
printFilteredAttributesFn(filteredAttrs);
}
-
void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
// Print the name without quotes if possible.
::printKeywordOrString(attr.getName().strref(), os);
@@ -3355,6 +3354,10 @@ void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
os << ']';
}
+ // Print the properties.
+ if (Attribute prop = op->getPropertiesAsAttribute())
+ os << " <" << prop << '>';
+
// Print regions.
if (op->getNumRegions() != 0) {
os << " (";
@@ -3365,7 +3368,7 @@ void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
os << ')';
}
- auto attrs = op->getAttrs();
+ auto attrs = op->getDiscardableAttrs();
printOptionalAttrDict(attrs);
// Print the type signature of the operation.
@@ -3509,6 +3512,10 @@ void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs,
void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
ValueRange operands) {
+ if (!mapAttr) {
+ os << "<<NULL AFFINE MAP>>";
+ return;
+ }
AffineMap map = mapAttr.getValue();
unsigned numDims = map.getNumDims();
auto printValueName = [&](unsigned pos, bool isSymbol) {
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index b729282e627d7..c093fa78a9540 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_library(MLIRIR
IntegerSet.cpp
Location.cpp
MLIRContext.cpp
+ ODSSupport.cpp
Operation.cpp
OperationSupport.cpp
PatternMatch.cpp
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index e64babf35dac0..5e0f9581e97f6 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -16,6 +16,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
@@ -801,6 +802,64 @@ OperationName::UnregisteredOpModel::verifyRegionInvariants(Operation *) {
return success();
}
+Optional<Attribute>
+OperationName::UnregisteredOpModel::getInherentAttr(Operation *op,
+ StringRef name) {
+ auto dict = dyn_cast_or_null<DictionaryAttr>(getPropertiesAsAttr(op));
+ if (!dict)
+ return std::nullopt;
+ if (Attribute attr = dict.get(name))
+ return attr;
+ return std::nullopt;
+}
+void OperationName::UnregisteredOpModel::setInherentAttr(Operation *op,
+ StringAttr name,
+ Attribute value) {
+ auto dict = dyn_cast_or_null<DictionaryAttr>(getPropertiesAsAttr(op));
+ assert(dict);
+ NamedAttrList attrs(dict);
+ attrs.set(name, value);
+ *op->getPropertiesStorage().as<Attribute *>() =
+ attrs.getDictionary(op->getContext());
+}
+void OperationName::UnregisteredOpModel::populateInherentAttrs(
+ Operation *op, NamedAttrList &attrs) {}
+LogicalResult OperationName::UnregisteredOpModel::verifyInherentAttrs(
+ OperationName opName, NamedAttrList &attributes,
+ function_ref<InFlightDiagnostic()> getDiag) {
+ return success();
+}
+int OperationName::UnregisteredOpModel::getOpPropertyByteSize() {
+ return sizeof(Attribute);
+}
+void OperationName::UnregisteredOpModel::initProperties(
+ OperationName opName, OpaqueProperties storage, OpaqueProperties init) {
+ new (storage.as<Attribute *>()) Attribute();
+}
+void OperationName::UnregisteredOpModel::deleteProperties(
+ OpaqueProperties prop) {
+ prop.as<Attribute *>()->~Attribute();
+}
+void OperationName::UnregisteredOpModel::populateDefaultProperties(
+ OperationName opName, OpaqueProperties properties) {}
+LogicalResult OperationName::UnregisteredOpModel::setPropertiesFromAttr(
+ Operation *op, Attribute attr, InFlightDiagnostic *diag) {
+ *op->getPropertiesStorage().as<Attribute *>() = attr;
+ return success();
+}
+Attribute
+OperationName::UnregisteredOpModel::getPropertiesAsAttr(Operation *op) {
+ return *op->getPropertiesStorage().as<Attribute *>();
+}
+void OperationName::UnregisteredOpModel::copyProperties(OpaqueProperties lhs,
+ OpaqueProperties rhs) {
+ *lhs.as<Attribute *>() = *rhs.as<Attribute *>();
+}
+llvm::hash_code
+OperationName::UnregisteredOpModel::hashProperties(OpaqueProperties prop) {
+ return llvm::hash_combine(*prop.as<Attribute *>());
+}
+
//===----------------------------------------------------------------------===//
// RegisteredOperationName
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/ODSSupport.cpp b/mlir/lib/IR/ODSSupport.cpp
new file mode 100644
index 0000000000000..ffb84f0f718d3
--- /dev/null
+++ b/mlir/lib/IR/ODSSupport.cpp
@@ -0,0 +1,57 @@
+//===- ODSSupport.cpp -----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains out-of-line implementations of the support types that
+// Operation and related classes build on top of.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/ODSSupport.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+
+using namespace mlir;
+
+LogicalResult mlir::convertFromAttribute(int64_t &storage,
+ ::mlir::Attribute attr,
+ ::mlir::InFlightDiagnostic *diag) {
+ auto valueAttr = dyn_cast<IntegerAttr>(attr);
+ if (!valueAttr) {
+ if (diag)
+ *diag << "expected IntegerAttr for key `value`";
+ return failure();
+ }
+ storage = valueAttr.getValue().getSExtValue();
+ return success();
+}
+Attribute mlir::convertToAttribute(MLIRContext *ctx, int64_t storage) {
+ return IntegerAttr::get(IntegerType::get(ctx, 64), storage);
+}
+LogicalResult mlir::convertFromAttribute(MutableArrayRef<int64_t> storage,
+ ::mlir::Attribute attr,
+ ::mlir::InFlightDiagnostic *diag) {
+ auto valueAttr = dyn_cast<DenseI64ArrayAttr>(attr);
+ if (!valueAttr) {
+ if (diag)
+ *diag << "expected DenseI64ArrayAttr for key `value`";
+ return failure();
+ }
+ if (valueAttr.size() != static_cast<int64_t>(storage.size())) {
+ if (diag)
+ *diag << "Size mismatch in attribute conversion: " << valueAttr.size()
+ << " vs " << storage.size();
+ return failure();
+ }
+ llvm::copy(valueAttr.asArrayRef(), storage.begin());
+ return success();
+}
+Attribute mlir::convertToAttribute(MLIRContext *ctx,
+ ArrayRef<int64_t> storage) {
+ return DenseI64ArrayAttr::get(ctx, storage);
+}
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 5a7594e0f88b7..d9a3abad74ed3 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -7,13 +7,17 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/Operation.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FoldInterfaces.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include <numeric>
@@ -25,19 +29,31 @@ using namespace mlir;
/// Create a new Operation from operation state.
Operation *Operation::create(const OperationState &state) {
- return create(state.location, state.name, state.types, state.operands,
- state.attributes.getDictionary(state.getContext()),
- state.successors, state.regions);
+ Operation *op =
+ create(state.location, state.name, state.types, state.operands,
+ state.attributes.getDictionary(state.getContext()),
+ state.properties, state.successors, state.regions);
+ if (LLVM_UNLIKELY(state.propertiesAttr)) {
+ assert(!state.properties);
+ LogicalResult result =
+ op->setPropertiesFromAttribute(state.propertiesAttr,
+ /*diagnostic=*/nullptr);
+ assert(result.succeeded() && "invalid properties in op creation");
+ (void)result;
+ }
+ return op;
}
/// Create a new Operation with the specific fields.
Operation *Operation::create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
- NamedAttrList &&attributes, BlockRange successors,
+ NamedAttrList &&attributes,
+ OpaqueProperties properties, BlockRange successors,
RegionRange regions) {
unsigned numRegions = regions.size();
- Operation *op = create(location, name, resultTypes, operands,
- std::move(attributes), successors, numRegions);
+ Operation *op =
+ create(location, name, resultTypes, operands, std::move(attributes),
+ properties, successors, numRegions);
for (unsigned i = 0; i < numRegions; ++i)
if (regions[i])
op->getRegion(i).takeBody(*regions[i]);
@@ -47,21 +63,23 @@ Operation *Operation::create(Location location, OperationName name,
/// Create a new Operation with the specific fields.
Operation *Operation::create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
- NamedAttrList &&attributes, BlockRange successors,
+ NamedAttrList &&attributes,
+ OpaqueProperties properties, BlockRange successors,
unsigned numRegions) {
// Populate default attributes.
name.populateDefaultAttrs(attributes);
return create(location, name, resultTypes, operands,
- attributes.getDictionary(location.getContext()), successors,
- numRegions);
+ attributes.getDictionary(location.getContext()), properties,
+ successors, numRegions);
}
/// Overload of create that takes an existing DictionaryAttr to avoid
/// unnecessarily uniquing a list of attributes.
Operation *Operation::create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
- DictionaryAttr attributes, BlockRange successors,
+ DictionaryAttr attributes,
+ OpaqueProperties properties, BlockRange successors,
unsigned numRegions) {
assert(llvm::all_of(resultTypes, [](Type t) { return t; }) &&
"unexpected null result type");
@@ -72,6 +90,7 @@ Operation *Operation::create(Location location, OperationName name,
unsigned numSuccessors = successors.size();
unsigned numOperands = operands.size();
unsigned numResults = resultTypes.size();
+ int opPropertiesAllocSize = name.getOpPropertyByteSize();
// If the operation is known to have no operands, don't allocate an operand
// storage.
@@ -82,8 +101,10 @@ Operation *Operation::create(Location location, OperationName name,
// into account the size of the operation, its trailing objects, and its
// prefixed objects.
size_t byteSize =
- totalSizeToAlloc<detail::OperandStorage, BlockOperand, Region, OpOperand>(
- needsOperandStorage ? 1 : 0, numSuccessors, numRegions, numOperands);
+ totalSizeToAlloc<detail::OperandStorage, detail::OpProperties,
+ BlockOperand, Region, OpOperand>(
+ needsOperandStorage ? 1 : 0, opPropertiesAllocSize, numSuccessors,
+ numRegions, numOperands);
size_t prefixByteSize = llvm::alignTo(
Operation::prefixAllocSize(numTrailingResults, numInlineResults),
alignof(Operation));
@@ -91,9 +112,9 @@ Operation *Operation::create(Location location, OperationName name,
void *rawMem = mallocMem + prefixByteSize;
// Create the new Operation.
- Operation *op =
- ::new (rawMem) Operation(location, name, numResults, numSuccessors,
- numRegions, attributes, needsOperandStorage);
+ Operation *op = ::new (rawMem) Operation(
+ location, name, numResults, numSuccessors, numRegions,
+ opPropertiesAllocSize, attributes, properties, needsOperandStorage);
assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) &&
"unexpected successors in a non-terminator operation");
@@ -122,16 +143,22 @@ Operation *Operation::create(Location location, OperationName name,
for (unsigned i = 0; i != numSuccessors; ++i)
new (&blockOperands[i]) BlockOperand(op, successors[i]);
+ // This must be done after properties are initalized.
+ op->setAttrs(attributes);
+
return op;
}
Operation::Operation(Location location, OperationName name, unsigned numResults,
unsigned numSuccessors, unsigned numRegions,
- DictionaryAttr attributes, bool hasOperandStorage)
+ int fullPropertiesStorageSize, DictionaryAttr attributes,
+ OpaqueProperties properties, bool hasOperandStorage)
: location(location), numResults(numResults), numSuccs(numSuccessors),
- numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name),
- attrs(attributes) {
+ numRegions(numRegions), hasOperandStorage(hasOperandStorage),
+ propertiesStorageSize((fullPropertiesStorageSize + 7) / 8), name(name) {
assert(attributes && "unexpected null attribute dictionary");
+ assert(fullPropertiesStorageSize <= propertiesCapacity &&
+ "Properties size overflow");
#ifndef NDEBUG
if (!getDialect() && !getContext()->allowsUnregisteredDialects())
llvm::report_fatal_error(
@@ -140,6 +167,8 @@ Operation::Operation(Location location, OperationName name, unsigned numResults,
"allowUnregisteredDialects() on the MLIRContext, or use "
"-allow-unregistered-dialect with the MLIR tool used.");
#endif
+ if (fullPropertiesStorageSize)
+ name.initOpProperties(getPropertiesStorage(), properties);
}
// Operations are deleted through the destroy() member because they are
@@ -168,6 +197,8 @@ Operation::~Operation() {
// Explicitly destroy the regions.
for (auto ®ion : getRegions())
region.~Region();
+ if (propertiesStorageSize)
+ name.destroyOpProperties(getPropertiesStorage());
}
/// Destroy this operation or one of its subclasses.
@@ -259,6 +290,68 @@ InFlightDiagnostic Operation::emitRemark(const Twine &message) {
return diag;
}
+DictionaryAttr Operation::getAttrDictionary() {
+ if (getPropertiesStorageSize()) {
+ NamedAttrList attrsList = attrs;
+ getName().populateInherentAttrs(this, attrsList);
+ return attrsList.getDictionary(getContext());
+ }
+ return attrs;
+}
+
+void Operation::setAttrs(DictionaryAttr newAttrs) {
+ assert(newAttrs && "expected valid attribute dictionary");
+ if (getPropertiesStorageSize()) {
+ attrs = DictionaryAttr::get(getContext(), {});
+ for (const NamedAttribute &attr : newAttrs)
+ setAttr(attr.getName(), attr.getValue());
+ return;
+ }
+ attrs = newAttrs;
+}
+void Operation::setAttrs(ArrayRef<NamedAttribute> newAttrs) {
+ if (getPropertiesStorageSize()) {
+ setAttrs(DictionaryAttr::get(getContext(), {}));
+ for (const NamedAttribute &attr : newAttrs)
+ setAttr(attr.getName(), attr.getValue());
+ return;
+ }
+ attrs = DictionaryAttr::get(getContext(), newAttrs);
+}
+
+std::optional<Attribute> Operation::getInherentAttr(StringRef name) {
+ return getName().getInherentAttr(this, name);
+}
+
+void Operation::setInherentAttr(StringAttr name, Attribute value) {
+ getName().setInherentAttr(this, name, value);
+}
+
+Attribute Operation::getPropertiesAsAttribute() {
+ Optional<RegisteredOperationName> info = getRegisteredInfo();
+ if (LLVM_UNLIKELY(!info))
+ return *getPropertiesStorage().as<Attribute *>();
+ return info->getOpPropertiesAsAttribute(this);
+}
+LogicalResult
+Operation::setPropertiesFromAttribute(Attribute attr,
+ InFlightDiagnostic *diagnostic) {
+ Optional<RegisteredOperationName> info = getRegisteredInfo();
+ if (LLVM_UNLIKELY(!info)) {
+ *getPropertiesStorage().as<Attribute *>() = attr;
+ return success();
+ }
+ return info->setOpPropertiesFromAttribute(this, attr, diagnostic);
+}
+
+void Operation::copyProperties(OpaqueProperties rhs) {
+ name.copyOpProperties(getPropertiesStorage(), rhs);
+}
+
+llvm::hash_code Operation::hashProperties() {
+ return name.hashOpProperties(getPropertiesStorage());
+}
+
//===----------------------------------------------------------------------===//
// Operation Ordering
//===----------------------------------------------------------------------===//
@@ -581,7 +674,7 @@ Operation *Operation::clone(IRMapping &mapper, CloneOptions options) {
// Create the new operation.
auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs,
- successors, getNumRegions());
+ getPropertiesStorage(), successors, getNumRegions());
mapper.map(this, newOp);
// Clone the regions.
@@ -636,6 +729,20 @@ void OpState::printOpName(Operation *op, OpAsmPrinter &p,
p.getStream() << name;
}
+/// Parse properties as a Attribute.
+ParseResult OpState::genericParseProperties(OpAsmParser &parser,
+ Attribute &result) {
+ if (parser.parseLess() || parser.parseAttribute(result) ||
+ parser.parseGreater())
+ return failure();
+ return success();
+}
+
+/// Print the properties as a Attribute.
+void OpState::genericPrintProperties(OpAsmPrinter &p, Attribute properties) {
+ p << "<" << properties << ">";
+}
+
/// Emit an error about fatal conditions with this operation, reporting up to
/// any diagnostic handlers that may be listening.
InFlightDiagnostic OpState::emitError(const Twine &message) {
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index f2a2fc2400444..a0a86e776ce37 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -193,6 +193,23 @@ OperationState::OperationState(Location location, StringRef name,
: OperationState(location, OperationName(name, location.getContext()),
operands, types, attributes, successors, regions) {}
+OperationState::~OperationState() {
+ if (properties)
+ propertiesDeleter(properties);
+}
+
+LogicalResult
+OperationState::setProperties(Operation *op,
+ InFlightDiagnostic *diagnostic) const {
+ if (LLVM_UNLIKELY(propertiesAttr)) {
+ assert(!properties);
+ return op->setPropertiesFromAttribute(propertiesAttr, diagnostic);
+ }
+ if (properties)
+ propertiesSetter(op->getPropertiesStorage(), properties);
+ return success();
+}
+
void OperationState::addOperands(ValueRange newOperands) {
operands.append(newOperands.begin(), newOperands.end());
}
@@ -633,8 +650,9 @@ llvm::hash_code OperationEquivalence::computeHash(
// - Operation Name
// - Attributes
// - Result Types
- llvm::hash_code hash = llvm::hash_combine(
- op->getName(), op->getAttrDictionary(), op->getResultTypes());
+ llvm::hash_code hash =
+ llvm::hash_combine(op->getName(), op->getAttrDictionary(),
+ op->getResultTypes(), op->hashProperties());
// - Operands
ValueRange operands = op->getOperands();
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 7d464af78023e..ebb10e07cebd2 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -220,15 +220,15 @@ LogicalResult mlir::detail::inferReturnTensorTypes(
function_ref<
LogicalResult(MLIRContext *, std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
- RegionRange regions,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &retComponents)>
componentTypeFn,
MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
SmallVector<ShapedTypeComponents, 2> retComponents;
- if (failed(componentTypeFn(context, location, operands, attributes, regions,
- retComponents)))
+ if (failed(componentTypeFn(context, location, operands, attributes,
+ properties, regions, retComponents)))
return failure();
for (const auto &shapeAndType : retComponents) {
Type elementTy = shapeAndType.getElementType();
@@ -249,7 +249,12 @@ LogicalResult mlir::detail::inferReturnTensorTypes(
LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
auto retTypeFn = cast<InferTypeOpInterface>(op);
- return retTypeFn.refineReturnTypes(op->getContext(), op->getLoc(),
- op->getOperands(), op->getAttrDictionary(),
- op->getRegions(), inferredReturnTypes);
+ auto result = retTypeFn.refineReturnTypes(
+ op->getContext(), op->getLoc(), op->getOperands(),
+ op->getAttrDictionary(), op->getPropertiesStorage(), op->getRegions(),
+ inferredReturnTypes);
+ if (failed(result))
+ op->emitOpError() << "failed to infer returned types";
+
+ return result;
}
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 1ea4bef5402a7..eca0297733e7d 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -1613,8 +1613,8 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
// TODO: Handle failure.
if (failed(inferInterface->inferReturnTypes(
state.getContext(), state.location, state.operands,
- state.attributes.getDictionary(state.getContext()), state.regions,
- state.types)))
+ state.attributes.getDictionary(state.getContext()),
+ state.getRawProperties(), state.regions, state.types)))
return;
} else {
// Otherwise, this is a fixed number of results.
diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index 74c6cdc834abd..2736d95cc9742 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -24,6 +24,7 @@ llvm_add_library(MLIRTableGen STATIC
Pass.cpp
Pattern.cpp
Predicate.cpp
+ Property.cpp
Region.cpp
SideEffects.cpp
Successor.cpp
diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp
index 193e8c1ce374b..afd02b1a64ac9 100644
--- a/mlir/lib/TableGen/CodeGenHelpers.cpp
+++ b/mlir/lib/TableGen/CodeGenHelpers.cpp
@@ -133,13 +133,18 @@ static ::mlir::LogicalResult {0}(
/// functions are stripped anyways.
static const char *const attrConstraintCode = R"(
static ::mlir::LogicalResult {0}(
- ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {
- if (attr && !({1})) {
- return op->emitOpError("attribute '") << attrName
+ ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> getDiag) {{
+ if (attr && !({1}))
+ return getDiag() << "attribute '" << attrName
<< "' failed to satisfy constraint: {2}";
- }
return ::mlir::success();
}
+static ::mlir::LogicalResult {0}(
+ ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {{
+ return {0}(attr, attrName, [op]() {{
+ return op->emitOpError();
+ });
+}
)";
/// Code for a successor constraint.
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index ec2c31dab440c..74f3981d91947 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -103,6 +103,10 @@ bool Dialect::isExtensible() const {
return def->getValueAsBit("isExtensible");
}
+bool Dialect::usePropertiesForAttributes() const {
+ return def->getValueAsBit("usePropertiesForAttributes");
+}
+
bool Dialect::operator==(const Dialect &other) const {
return def == other.def;
}
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 291777198eafa..03557d9d283a3 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Predicate.h"
#include "mlir/TableGen/Trait.h"
#include "mlir/TableGen/Type.h"
@@ -322,14 +323,23 @@ auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> {
return {trait_begin(), trait_end()};
}
-auto Operator::attribute_begin() const -> attribute_iterator {
+auto Operator::attribute_begin() const -> const_attribute_iterator {
return attributes.begin();
}
-auto Operator::attribute_end() const -> attribute_iterator {
+auto Operator::attribute_end() const -> const_attribute_iterator {
return attributes.end();
}
auto Operator::getAttributes() const
- -> llvm::iterator_range<attribute_iterator> {
+ -> llvm::iterator_range<const_attribute_iterator> {
+ return {attribute_begin(), attribute_end()};
+}
+auto Operator::attribute_begin() -> attribute_iterator {
+ return attributes.begin();
+}
+auto Operator::attribute_end() -> attribute_iterator {
+ return attributes.end();
+}
+auto Operator::getAttributes() -> llvm::iterator_range<attribute_iterator> {
return {attribute_begin(), attribute_end()};
}
@@ -542,6 +552,7 @@ void Operator::populateOpStructure() {
auto &recordKeeper = def.getRecords();
auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint");
auto *attrClass = recordKeeper.getClass("Attr");
+ auto *propertyClass = recordKeeper.getClass("Property");
auto *derivedAttrClass = recordKeeper.getClass("DerivedAttr");
auto *opVarClass = recordKeeper.getClass("OpVariable");
numNativeAttributes = 0;
@@ -576,9 +587,14 @@ void Operator::populateOpStructure() {
"derived attributes not allowed in argument list");
attributes.push_back({givenName, Attribute(argDef)});
++numNativeAttributes;
+ } else if (argDef->isSubClassOf(propertyClass)) {
+ if (givenName.empty())
+ PrintFatalError(argDef->getLoc(), "properties must be named");
+ properties.push_back({givenName, Property(argDef)});
} else {
- PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving "
- "from TypeConstraint or Attr are allowed");
+ PrintFatalError(def.getLoc(),
+ "unexpected def type; only defs deriving "
+ "from TypeConstraint or Attr or Property are allowed");
}
if (!givenName.empty())
argumentsAndResultsIndex[givenName] = i;
@@ -608,7 +624,7 @@ void Operator::populateOpStructure() {
// `attributes` because we will put their elements' pointers in `arguments`.
// SmallVector may perform re-allocation under the hood when adding new
// elements.
- int operandIndex = 0, attrIndex = 0;
+ int operandIndex = 0, attrIndex = 0, propIndex = 0;
for (unsigned i = 0; i != numArgs; ++i) {
Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
if (argDef->isSubClassOf(opVarClass))
@@ -618,11 +634,13 @@ void Operator::populateOpStructure() {
attrOrOperandMapping.push_back(
{OperandOrAttribute::Kind::Operand, operandIndex});
arguments.emplace_back(&operands[operandIndex++]);
- } else {
- assert(argDef->isSubClassOf(attrClass));
+ } else if (argDef->isSubClassOf(attrClass)) {
attrOrOperandMapping.push_back(
{OperandOrAttribute::Kind::Attribute, attrIndex});
arguments.emplace_back(&attributes[attrIndex++]);
+ } else {
+ assert(argDef->isSubClassOf(propertyClass));
+ arguments.emplace_back(&properties[propIndex++]);
}
}
diff --git a/mlir/lib/TableGen/Property.cpp b/mlir/lib/TableGen/Property.cpp
new file mode 100644
index 0000000000000..966c65da16788
--- /dev/null
+++ b/mlir/lib/TableGen/Property.cpp
@@ -0,0 +1,94 @@
+//===- Property.cpp - Property wrapper class ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Property wrapper to simplify using a TableGen Record defining an MLIR
+// Property.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Property.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/Operator.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+using llvm::DefInit;
+using llvm::Init;
+using llvm::Record;
+using llvm::StringInit;
+
+// Returns the initializer's value as string if the given TableGen initializer
+// is a code or string initializer, with leading and trailing whitespace
+// trimmed. Returns the empty StringRef otherwise.
+static StringRef getValueAsString(const Init *init) {
+  if (const auto *str = dyn_cast<StringInit>(init))
+    return str->getValue().trim();
+  return {};
+}
+
+// Wraps a TableGen record that derives from either `Property` or `Attr`.
+// NOTE(review): the accessors below read fields via `getValueInit`; confirm
+// that `Attr` records define every field they query.
+Property::Property(const Record *record) : def(record) {
+  assert((record->isSubClassOf("Property") || record->isSubClassOf("Attr")) &&
+         "must be subclass of TableGen 'Property' or 'Attr' class");
+}
+
+Property::Property(const DefInit *init) : Property(init->getDef()) {}
+
+// Returns the `storageType` field, or "Property" when that field's trimmed
+// string value is empty or the field does not hold a string.
+StringRef Property::getStorageType() const {
+  StringRef type = getValueAsString(def->getValueInit("storageType"));
+  if (type.empty())
+    return "Property";
+  return type;
+}
+
+// Returns the `interfaceType` field, or "" if it does not hold a string.
+StringRef Property::getInterfaceType() const {
+  return getValueAsString(def->getValueInit("interfaceType"));
+}
+
+// Returns the `convertFromStorage` field, or "" if it does not hold a string.
+StringRef Property::getConvertFromStorageCall() const {
+  return getValueAsString(def->getValueInit("convertFromStorage"));
+}
+
+// Returns the `assignToStorage` field, or "" if it does not hold a string.
+StringRef Property::getAssignToStorageCall() const {
+  return getValueAsString(def->getValueInit("assignToStorage"));
+}
+
+// Returns the `convertToAttribute` field, or "" if it does not hold a string.
+StringRef Property::getConvertToAttributeCall() const {
+  return getValueAsString(def->getValueInit("convertToAttribute"));
+}
+
+// Returns the `convertFromAttribute` field, or "" if it does not hold a
+// string.
+StringRef Property::getConvertFromAttributeCall() const {
+  return getValueAsString(def->getValueInit("convertFromAttribute"));
+}
+
+// Returns the `hashProperty` field, or "" if it does not hold a string.
+StringRef Property::getHashPropertyCall() const {
+  return getValueAsString(def->getValueInit("hashProperty"));
+}
+
+// True iff `defaultValue` holds a string that is non-empty once trimmed.
+bool Property::hasDefaultValue() const { return !getDefaultValue().empty(); }
+
+// Returns the `defaultValue` field, or "" if it does not hold a string.
+StringRef Property::getDefaultValue() const {
+  return getValueAsString(def->getValueInit("defaultValue"));
+}
+
+const llvm::Record &Property::getDef() const { return *def; }
diff --git a/mlir/test/Bytecode/versioning/versioned_attr.mlir b/mlir/test/Bytecode/versioning/versioned_attr.mlir
index 98756cc66a99c..0fd6c3c2dfe9e 100644
--- a/mlir/test/Bytecode/versioning/versioned_attr.mlir
+++ b/mlir/test/Bytecode/versioning/versioned_attr.mlir
@@ -11,10 +11,10 @@
// COM: bytecode contains
// COM: module {
// COM: version: 1.12
-// COM: "test.versionedB"() {attribute = #test.attr_params<24, 42>} : () -> ()
+// COM: "test.versionedB"() <{attribute = #test.attr_params<24, 42>}> : () -> ()
// COM: }
// RUN: mlir-opt %S/versioned-attr-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1
-// CHECK1: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> ()
+// CHECK1: "test.versionedB"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
//===--------------------------------------------------------------------===//
// Test attribute upgrade
@@ -23,7 +23,7 @@
// COM: bytecode contains
// COM: module {
// COM: version: 2.0
-// COM: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> ()
+// COM: "test.versionedB"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
// COM: }
// RUN: mlir-opt %S/versioned-attr-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK2
-// CHECK2: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> ()
+// CHECK2: "test.versionedB"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
diff --git a/mlir/test/Bytecode/versioning/versioned_op.mlir b/mlir/test/Bytecode/versioning/versioned_op.mlir
index b3141de67823f..5fa170bc24904 100644
--- a/mlir/test/Bytecode/versioning/versioned_op.mlir
+++ b/mlir/test/Bytecode/versioning/versioned_op.mlir
@@ -11,10 +11,10 @@
// COM: bytecode contains
// COM: module {
// COM: version: 2.0
-// COM: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+// COM: "test.versionedA"() <{dims = 123 : i64, modifier = false}> : () -> ()
// COM: }
// RUN: mlir-opt %S/versioned-op-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1
-// CHECK1: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+// CHECK1: "test.versionedA"() <{dims = 123 : i64, modifier = false}> : () -> ()
//===--------------------------------------------------------------------===//
// Test upgrade
@@ -23,10 +23,10 @@
// COM: bytecode contains
// COM: module {
// COM: version: 1.12
-// COM: "test.versionedA"() {dimensions = 123 : i64} : () -> ()
+// COM: "test.versionedA"() <{dimensions = 123 : i64}> : () -> ()
// COM: }
// RUN: mlir-opt %S/versioned-op-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK2
-// CHECK2: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+// CHECK2: "test.versionedA"() <{dims = 123 : i64, modifier = false}> : () -> ()
//===--------------------------------------------------------------------===//
// Test forbidden downgrade
@@ -35,7 +35,7 @@
// COM: bytecode contains
// COM: module {
// COM: version: 2.2
-// COM: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+// COM: "test.versionedA"() <{dims = 123 : i64, modifier = false}> : () -> ()
// COM: }
// RUN: not mlir-opt %S/versioned-op-2.2.mlirbc 2>&1 | FileCheck %s --check-prefix=ERR_NEW_VERSION
// ERR_NEW_VERSION: current test dialect version is 2.0, can't parse version: 2.2
diff --git a/mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir
index be2365215dce7..611ec0265cd37 100644
--- a/mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir
+++ b/mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir
@@ -41,16 +41,16 @@ func.func @pack_unpack(%arg0: i1, %arg1: i2) -> (i1, i2) {
//
// CHECK-TUP-LABEL: func.func @materializations_tuple_args(
// CHECK-TUP-SAME: %[[ARG0:.*]]: tuple<tuple<>, i1, tuple<tuple<i2>>>) -> (i1, i2) {
-// CHECK-TUP-DAG: %[[V0:.*]] = "test.get_tuple_element"(%[[ARG0]]) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
-// CHECK-TUP-DAG: %[[V1:.*]] = "test.get_tuple_element"(%[[ARG0]]) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
-// CHECK-TUP-DAG: %[[V2:.*]] = "test.get_tuple_element"(%[[ARG0]]) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
-// CHECK-TUP-DAG: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
-// CHECK-TUP-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<i2>) -> i2
-// CHECK-TUP-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[ARG0]]) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
-// CHECK-TUP-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[ARG0]]) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
-// CHECK-TUP-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[ARG0]]) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
-// CHECK-TUP-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
-// CHECK-TUP-DAG: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) {index = 0 : i32} : (tuple<i2>) -> i2
+// CHECK-TUP-DAG: %[[V0:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 0 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
+// CHECK-TUP-DAG: %[[V1:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 1 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
+// CHECK-TUP-DAG: %[[V2:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 2 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
+// CHECK-TUP-DAG: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
+// CHECK-TUP-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
+// CHECK-TUP-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 0 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
+// CHECK-TUP-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 1 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
+// CHECK-TUP-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 2 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
+// CHECK-TUP-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
+// CHECK-TUP-DAG: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
// CHECK-TUP-DAG: return %[[V1]], %[[V9]] : i1, i2
// If we only convert the func ops, argument materializations are created from
@@ -64,11 +64,11 @@ func.func @pack_unpack(%arg0: i1, %arg1: i2) -> (i1, i2) {
// CHECK-FUNC-DAG: %[[V1:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
// CHECK-FUNC-DAG: %[[V2:.*]] = "test.make_tuple"(%[[V1]]) : (tuple<i2>) -> tuple<tuple<i2>>
// CHECK-FUNC-DAG: %[[V3:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]], %[[V2]]) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>>
-// CHECK-FUNC-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
-// CHECK-FUNC-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
-// CHECK-FUNC-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
-// CHECK-FUNC-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
-// CHECK-FUNC-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) {index = 0 : i32} : (tuple<i2>) -> i2
+// CHECK-FUNC-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
+// CHECK-FUNC-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
+// CHECK-FUNC-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 2 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
+// CHECK-FUNC-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
+// CHECK-FUNC-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
// CHECK-FUNC-DAG: return %[[V5]], %[[V8]] : i1, i2
// If we convert both tuple and func ops, basically everything disappears.
@@ -117,11 +117,11 @@ func.func @materializations_tuple_args(%arg0: tuple<tuple<>, i1, tuple<tuple<i2>
// CHECK-FUNC-DAG: %[[V1:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
// CHECK-FUNC-DAG: %[[V2:.*]] = "test.make_tuple"(%[[V1]]) : (tuple<i2>) -> tuple<tuple<i2>>
// CHECK-FUNC-DAG: %[[V3:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]], %[[V2]]) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>>
-// CHECK-FUNC-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
-// CHECK-FUNC-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
-// CHECK-FUNC-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
-// CHECK-FUNC-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
-// CHECK-FUNC-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) {index = 0 : i32} : (tuple<i2>) -> i2
+// CHECK-FUNC-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
+// CHECK-FUNC-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
+// CHECK-FUNC-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 2 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
+// CHECK-FUNC-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
+// CHECK-FUNC-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
// CHECK-FUNC-DAG: return %[[V5]], %[[V8]] : i1, i2
// If we convert both tuple and func ops, basically everything disappears.
diff --git a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
index dd2013c9a7368..263711674a6ec 100644
--- a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
+++ b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
@@ -36,13 +36,13 @@ func.func @if_result(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i1) -> tuple<t
// CHECK-NEXT: %[[V1:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
// CHECK-NEXT: %[[V2:.*]] = scf.if %[[ARG1]] -> (i1) {
// CHECK-NEXT: %[[V3:.*]] = "test.op"(%[[V1]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
-// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
-// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
// CHECK-NEXT: scf.yield %[[V5]] : i1
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1>
-// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
-// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
// CHECK-NEXT: scf.yield %[[V8]] : i1
// CHECK-NEXT: }
// CHECK-NEXT: return %[[V2]] : i1
@@ -94,14 +94,14 @@ func.func @while_operands_results(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i
// CHECK-NEXT: %[[V1:.*]] = "test.make_tuple"() : () -> tuple<>
// CHECK-NEXT: %[[V2:.*]] = "test.make_tuple"(%[[V1]], %[[ARG2]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
// CHECK-NEXT: %[[V3:.*]] = "test.op"(%[[V2]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
-// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
-// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
// CHECK-NEXT: scf.condition(%[[ARG1]]) %[[V5]] : i1
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%[[ARG3:.*]]: i1):
// CHECK-NEXT: %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1>
-// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
-// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
// CHECK-NEXT: scf.yield %[[V8]] : i1
// CHECK-NEXT: }
// CHECK-NEXT: return %[[V0]] : i1
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index daed6a49a0e82..ce2a438d70311 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -98,6 +98,7 @@ func.func @shape_of(%value_arg : !shape.value_shape,
// -----
func.func @shape_of_incompatible_return_types(%value_arg : tensor<1x2xindex>) {
+ // expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xindex>'}}
%0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xindex>
return
@@ -268,6 +269,7 @@ func.func @fn(%arg: !shape.shape) -> !shape.witness {
// Test that type inference flags the wrong return type.
func.func @const_shape() {
+ // expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tensor<3xindex>' are incompatible with return type(s) of operation 'tensor<2xindex>'}}
%0 = shape.const_shape [4, 5, 6] : tensor<2xindex>
return
@@ -276,6 +278,7 @@ func.func @const_shape() {
// -----
func.func @invalid_meet(%arg0 : !shape.shape, %arg1 : index) -> index {
+ // expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{requires all sizes or shapes}}
%result = shape.meet %arg0, %arg1 : !shape.shape, index -> index
return %result : index
@@ -284,6 +287,7 @@ func.func @invalid_meet(%arg0 : !shape.shape, %arg1 : index) -> index {
// -----
func.func @invalid_meet(%arg0 : tensor<2xindex>, %arg1 : tensor<3xindex>) -> tensor<?xindex> {
+ // expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{unequal shape cardinality}}
%result = shape.meet %arg0, %arg1 : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
return %result : tensor<?xindex>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 27661f4c57847..edb4bb0a873ec 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -39,6 +39,7 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>,
// -----
func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
+ // expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}}
%0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
@@ -47,6 +48,7 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
// -----
func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
+ // expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tosa.concat' op inferred type(s) 'tensor<3x2xf32>' are incompatible with return type(s) of operation 'tensor<?x?xi8>}}
%0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
return %0 : tensor<?x?xi8>
@@ -100,6 +102,7 @@ func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: ten
// -----
func.func @test_reduce_sum_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
+ // expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tosa.reduce_sum' op inferred type(s) 'tensor<1x3x4x5xf32>' are incompatible with return type(s) of operation 'tensor<1x3x4x5xi32>'}}
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<2x3x4x5xf32>) -> tensor<1x3x4x5xi32>
return
@@ -108,6 +111,7 @@ func.func @test_reduce_sum_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// -----
func.func @test_reduce_max_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
+ // expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tosa.reduce_max' op inferred type(s) 'tensor<2x3x4x1xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x1xi32>'}}
%0 = "tosa.reduce_max"(%arg0) {axis = 3 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x1xi32>
return
@@ -116,6 +120,7 @@ func.func @test_reduce_max_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// -----
func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
+ // expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tosa.reduce_min' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x1x4x5xi32>'}}
%0 = "tosa.reduce_min"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x1x4x5xi32>
return
@@ -124,6 +129,7 @@ func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// -----
func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
+ // expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tosa.reduce_prod' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x5xf32>'}}
%0 = "tosa.reduce_prod"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
return
@@ -132,6 +138,7 @@ func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// -----
func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
+ // expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3, 1>} : (tensor<13x21x3xf32>) -> tensor<13x21x3x1xi32>
return
diff --git a/mlir/test/IR/greedy-pattern-rewriter-driver.mlir b/mlir/test/IR/greedy-pattern-rewriter-driver.mlir
index 6f4923a9f4f75..a4fb56d584a34 100644
--- a/mlir/test/IR/greedy-pattern-rewriter-driver.mlir
+++ b/mlir/test/IR/greedy-pattern-rewriter-driver.mlir
@@ -7,7 +7,7 @@ func.func @add_to_worklist_after_inplace_update() {
// worklist of the GreedyPatternRewriteDriver (regardless of the value of
// config.max_iterations).
- // CHECK: "test.any_attr_of_i32_str"() {attr = 3 : i32} : () -> ()
+ // CHECK: "test.any_attr_of_i32_str"() <{attr = 3 : i32}> : () -> ()
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
return
}
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 9bf7efd8ae5d8..861f4ef6c020d 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -587,11 +587,6 @@ func.func @bad_arrow(%arg : !unreg.ptr<(i32)->)
// -----
-// expected-error @+1 {{attribute 'attr' occurs more than once in the attribute list}}
-test.format_symbol_name_attr_op @name { attr = "xx" }
-
-// -----
-
func.func @forward_reference_type_check() -> (i8) {
cf.br ^bb2
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 75ebc1a7536c3..2041ca90e8116 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1437,3 +1437,4 @@ test.dialect_custom_format_fallback custom_format_fallback
// Check that an op with an optional result parses f80 as type.
// CHECK: test.format_optional_result_d_op : f80
test.format_optional_result_d_op : f80
+
diff --git a/mlir/test/IR/properties.mlir b/mlir/test/IR/properties.mlir
new file mode 100644
index 0000000000000..b073698b40aff
--- /dev/null
+++ b/mlir/test/IR/properties.mlir
@@ -0,0 +1,20 @@
+// # RUN: mlir-opt %s -split-input-file | mlir-opt |FileCheck %s
+// # RUN: mlir-opt %s -mlir-print-op-generic -split-input-file | mlir-opt -mlir-print-op-generic | FileCheck %s --check-prefix=GENERIC
+
+// CHECK: test.with_properties
+// CHECK-SAME: <{a = 32 : i64, array = array<i64: 1, 2, 3, 4>, b = "foo"}>
+// GENERIC: "test.with_properties"()
+// GENERIC-SAME: <{a = 32 : i64, array = array<i64: 1, 2, 3, 4>, b = "foo"}> : () -> ()
+test.with_properties <{a = 32 : i64, array = array<i64: 1, 2, 3, 4>, b = "foo"}>
+
+// CHECK: test.with_nice_properties
+// CHECK-SAME: "foo bar" is -3
+// GENERIC: "test.with_nice_properties"()
+// GENERIC-SAME: <{prop = {label = "foo bar", value = -3 : i32}}> : () -> ()
+test.with_nice_properties "foo bar" is -3
+
+// CHECK: test.with_wrapped_properties
+// CHECK-SAME: "content for properties"
+// GENERIC: "test.with_wrapped_properties"()
+// GENERIC-SAME: <{prop = "content for properties"}> : () -> ()
+test.with_wrapped_properties <{prop = "content for properties"}>
diff --git a/mlir/test/IR/test-fold-adaptor.mlir b/mlir/test/IR/test-fold-adaptor.mlir
index 7815e729f55ad..a7dec9bb871a9 100644
--- a/mlir/test/IR/test-fold-adaptor.mlir
+++ b/mlir/test/IR/test-fold-adaptor.mlir
@@ -12,5 +12,5 @@ func.func @test() -> i32 {
}
// CHECK-LABEL: func.func @test
-// CHECK-NEXT: %[[C:.*]] = "test.constant"() {value = 33 : i32}
+// CHECK-NEXT: %[[C:.*]] = "test.constant"() <{value = 33 : i32}>
// CHECK-NEXT: return %[[C]]
diff --git a/mlir/test/IR/test-manual-cpp-fold.mlir b/mlir/test/IR/test-manual-cpp-fold.mlir
index 592b949f0a139..7364b7ccb5762 100644
--- a/mlir/test/IR/test-manual-cpp-fold.mlir
+++ b/mlir/test/IR/test-manual-cpp-fold.mlir
@@ -7,5 +7,5 @@ func.func @test() -> i32 {
}
// CHECK-LABEL: func.func @test
-// CHECK-NEXT: %[[C:.*]] = "test.constant"() {value = 5 : i32}
+// CHECK-NEXT: %[[C:.*]] = "test.constant"() <{value = 5 : i32}>
// CHECK-NEXT: return %[[C]]
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index ddba1171649c9..307918defab0e 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -345,8 +345,8 @@ func.func @failedSingleBlockImplicitTerminator_missing_terminator() {
// -----
-// expected-error@+1 {{op attribute 'sym_visibility' failed to satisfy constraint: string attribute}}
-"test.symbol"() {sym_name = "foo_2", sym_visibility} : () -> ()
+// expected-error@+1 {{invalid properties {sym_name = "foo_2", sym_visibility} for op test.symbol: Invalid attribute `sym_visibility` in property conversion: unit}}
+"test.symbol"() <{sym_name = "foo_2", sym_visibility}> : () -> ()
// -----
@@ -390,14 +390,14 @@ func.func @failedMissingOperandSizeAttr(%arg: i32) {
// -----
func.func @failedOperandSizeAttrWrongType(%arg: i32) {
- // expected-error @+1 {{requires dense i32 array attribute 'operand_segment_sizes'}}
+ // expected-error @+1 {{attribute 'operand_segment_sizes' failed to satisfy constraint: i32 dense array attribute}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = 10} : (i32, i32, i32, i32) -> ()
}
// -----
func.func @failedOperandSizeAttrWrongElementType(%arg: i32) {
- // expected-error @+1 {{requires dense i32 array attribute 'operand_segment_sizes'}}
+ // expected-error @+1 {{attribute 'operand_segment_sizes' failed to satisfy constraint: i32 dense array attribute}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = array<i64: 1, 1, 1, 1>} : (i32, i32, i32, i32) -> ()
}
@@ -655,7 +655,7 @@ func.func @failed_type_traits() {
// Check that we can query traits in attributes
func.func @succeeded_attr_traits() {
- // CHECK: "test.attr_with_trait"() {attr = #test.attr_with_trait} : () -> ()
+ // CHECK: "test.attr_with_trait"() <{attr = #test.attr_with_trait}> : () -> ()
"test.attr_with_trait"() {attr = #test.attr_with_trait} : () -> ()
return
}
diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index 81c1531484d43..c74af447d1b1f 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
// CHECK-LABEL: func @constant
-// CHECK: %[[cst:.*]] = "test.constant"() {value = 3 : index}
+// CHECK: %[[cst:.*]] = "test.constant"() <{value = 3 : index}
// CHECK: return %[[cst]]
func.func @constant() -> index {
%0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
@@ -10,7 +10,7 @@ func.func @constant() -> index {
}
// CHECK-LABEL: func @increment
-// CHECK: %[[cst:.*]] = "test.constant"() {value = 4 : index}
+// CHECK: %[[cst:.*]] = "test.constant"() <{value = 4 : index}
// CHECK: return %[[cst]]
func.func @increment() -> index {
%0 = test.with_bounds { umin = 3 : index, umax = 3 : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
@@ -103,8 +103,8 @@ func.func @func_args_unbound(%arg0 : index) -> index {
// CHECK-LABEL: func @propagate_across_while_loop_false()
func.func @propagate_across_while_loop_false() -> index {
- // CHECK-DAG: %[[C0:.*]] = "test.constant"() {value = 0
- // CHECK-DAG: %[[C1:.*]] = "test.constant"() {value = 1
+ // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
+ // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
smin = 0 : index, smax = 0 : index }
%1 = scf.while : () -> index {
@@ -122,8 +122,8 @@ func.func @propagate_across_while_loop_false() -> index {
// CHECK-LABEL: func @propagate_across_while_loop
func.func @propagate_across_while_loop(%arg0 : i1) -> index {
- // CHECK-DAG: %[[C0:.*]] = "test.constant"() {value = 0
- // CHECK-DAG: %[[C1:.*]] = "test.constant"() {value = 1
+ // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
+ // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
smin = 0 : index, smax = 0 : index }
%1 = scf.while : () -> index {
@@ -140,7 +140,7 @@ func.func @propagate_across_while_loop(%arg0 : i1) -> index {
// CHECK-LABEL: func @dont_propagate_across_infinite_loop()
func.func @dont_propagate_across_infinite_loop() -> index {
- // CHECK: %[[C0:.*]] = "test.constant"() {value = 0
+ // CHECK: %[[C0:.*]] = "test.constant"() <{value = 0
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
smin = 0 : index, smax = 0 : index }
// CHECK: %[[loopRes:.*]] = scf.while
diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir
index 51b63ba4c0ad9..b8fad63eb4de6 100644
--- a/mlir/test/Transforms/decompose-call-graph-types.mlir
+++ b/mlir/test/Transforms/decompose-call-graph-types.mlir
@@ -10,8 +10,8 @@
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
-// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
-// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
+// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
+// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
// CHECK-12N-LABEL: func @identity(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
@@ -61,12 +61,12 @@ func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tupl
// CHECK: %[[V2:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
// CHECK: %[[V3:.*]] = "test.make_tuple"(%[[V2]]) : (tuple<i2>) -> tuple<tuple<i2>>
// CHECK: %[[V4:.*]] = "test.make_tuple"(%[[V0]], %[[V1]], %[[V3]]) : (tuple<>, tuple<i1>, tuple<tuple<i2>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>
-// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<>
-// CHECK: %[[V6:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 1 : i32} : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<i1>
-// CHECK: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<i1>) -> i1
-// CHECK: %[[V8:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 2 : i32} : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
-// CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
-// CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) {index = 0 : i32} : (tuple<i2>) -> i2
+// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<>
+// CHECK: %[[V6:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 1 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<i1>
+// CHECK: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<i1>) -> i1
+// CHECK: %[[V8:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 2 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
+// CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
+// CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
// CHECK: return %[[V7]], %[[V10]] : i1, i2
// CHECK-12N-LABEL: func @mixed_recursive_decomposition(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
@@ -88,12 +88,12 @@ func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32>
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
-// CHECK: %[[CALL_ARG0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
-// CHECK: %[[CALL_ARG1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
+// CHECK: %[[CALL_ARG0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
+// CHECK: %[[CALL_ARG1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
// CHECK: %[[DECOMPOSED:.*]]:2 = call @callee(%[[CALL_ARG0]], %[[CALL_ARG1]]) : (i1, i32) -> (i1, i32)
// CHECK: %[[CALL_RESULT_RECOMPOSED:.*]] = "test.make_tuple"(%[[DECOMPOSED]]#0, %[[DECOMPOSED]]#1) : (i1, i32) -> tuple<i1, i32>
-// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
-// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
+// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
+// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
// CHECK-12N-LABEL: func @caller(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
@@ -131,13 +131,13 @@ func.func @caller(%arg0: tuple<>) -> tuple<> {
// CHECK-LABEL: func @unconverted_op_result() -> (i1, i32) {
// CHECK: %[[UNCONVERTED_VALUE:.*]] = "test.source"() : () -> tuple<i1, i32>
-// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
-// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
+// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
+// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
// CHECK-12N-LABEL: func @unconverted_op_result() -> (i1, i32) {
// CHECK-12N: %[[UNCONVERTED_VALUE:.*]] = "test.source"() : () -> tuple<i1, i32>
-// CHECK-12N: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
-// CHECK-12N: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
+// CHECK-12N: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
+// CHECK-12N: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
// CHECK-12N: return %[[RET0]], %[[RET1]] : i1, i32
func.func @unconverted_op_result() -> tuple<i1, i32> {
%0 = "test.source"() : () -> (tuple<i1, i32>)
@@ -155,9 +155,9 @@ func.func @unconverted_op_result() -> tuple<i1, i32> {
// CHECK: %[[V0:.*]] = "test.make_tuple"(%[[ARG1]]) : (i32) -> tuple<i32>
// CHECK: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]], %[[V0]]) : (i1, tuple<i32>) -> tuple<i1, tuple<i32>>
// CHECK: %[[V2:.*]] = "test.op"(%[[V1]]) : (tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>>
-// CHECK: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 0 : i32} : (tuple<i1, tuple<i32>>) -> i1
-// CHECK: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple<i1, tuple<i32>>) -> tuple<i32>
-// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple<i32>) -> i32
+// CHECK: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 0 : i32}> : (tuple<i1, tuple<i32>>) -> i1
+// CHECK: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 1 : i32}> : (tuple<i1, tuple<i32>>) -> tuple<i32>
+// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<i32>) -> i32
// CHECK: return %[[V3]], %[[V5]] : i1, i32
// CHECK-12N-LABEL: func @nested_unconverted_op_result(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
@@ -165,9 +165,9 @@ func.func @unconverted_op_result() -> tuple<i1, i32> {
// CHECK-12N: %[[V0:.*]] = "test.make_tuple"(%[[ARG1]]) : (i32) -> tuple<i32>
// CHECK-12N: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]], %[[V0]]) : (i1, tuple<i32>) -> tuple<i1, tuple<i32>>
// CHECK-12N: %[[V2:.*]] = "test.op"(%[[V1]]) : (tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>>
-// CHECK-12N: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 0 : i32} : (tuple<i1, tuple<i32>>) -> i1
-// CHECK-12N: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple<i1, tuple<i32>>) -> tuple<i32>
-// CHECK-12N: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple<i32>) -> i32
+// CHECK-12N: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 0 : i32}> : (tuple<i1, tuple<i32>>) -> i1
+// CHECK-12N: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 1 : i32}> : (tuple<i1, tuple<i32>>) -> tuple<i32>
+// CHECK-12N: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<i32>) -> i32
// CHECK-12N: return %[[V3]], %[[V5]] : i1, i32
func.func @nested_unconverted_op_result(%arg: tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>> {
%0 = "test.op"(%arg) : (tuple<i1, tuple<i32>>) -> (tuple<i1, tuple<i32>>)
@@ -191,12 +191,12 @@ func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tup
// CHECK-SAME: %[[I5:.*]]: i5,
// CHECK-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) {
// CHECK: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[I4]], %[[I5]]) : (i4, i5) -> tuple<i4, i5>
-// CHECK: %[[ARG_TUPLE_0:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 0 : i32} : (tuple<i4, i5>) -> i4
-// CHECK: %[[ARG_TUPLE_1:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 1 : i32} : (tuple<i4, i5>) -> i5
+// CHECK: %[[ARG_TUPLE_0:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4
+// CHECK: %[[ARG_TUPLE_1:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5
// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[ARG_TUPLE_0]], %[[ARG_TUPLE_1]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
// CHECK: %[[RET_TUPLE:.*]] = "test.make_tuple"(%[[CALL]]#3, %[[CALL]]#4) : (i4, i5) -> tuple<i4, i5>
-// CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 0 : i32} : (tuple<i4, i5>) -> i4
-// CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 1 : i32} : (tuple<i4, i5>) -> i5
+// CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4
+// CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5
// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
// CHECK-12N-LABEL: func @caller(
// CHECK-12N-SAME: %[[I1:.*]]: i1,
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 06bd5b6ae4984..6897b6f95f0d0 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -2,7 +2,7 @@
// CHECK-LABEL: verifyDirectPattern
func.func @verifyDirectPattern() -> i32 {
- // CHECK-NEXT: "test.legal_op_a"() {status = "Success"}
+ // CHECK-NEXT: "test.legal_op_a"() <{status = "Success"}
%result = "test.illegal_op_a"() : () -> (i32)
// expected-remark@+1 {{op 'func.return' is not legalizable}}
return %result : i32
@@ -10,7 +10,7 @@ func.func @verifyDirectPattern() -> i32 {
// CHECK-LABEL: verifyLargerBenefit
func.func @verifyLargerBenefit() -> i32 {
- // CHECK-NEXT: "test.legal_op_a"() {status = "Success"}
+ // CHECK-NEXT: "test.legal_op_a"() <{status = "Success"}
%result = "test.illegal_op_c"() : () -> (i32)
// expected-remark@+1 {{op 'func.return' is not legalizable}}
return %result : i32
diff --git a/mlir/test/Transforms/test-operation-folder.mlir b/mlir/test/Transforms/test-operation-folder.mlir
index 670ec232a3922..46ee07af993cc 100644
--- a/mlir/test/Transforms/test-operation-folder.mlir
+++ b/mlir/test/Transforms/test-operation-folder.mlir
@@ -7,7 +7,7 @@ func.func @foo() -> i32 {
// The new operation should be present in the output and contain an attribute
// with value "42" that results from folding.
- // CHECK: "test.op_in_place_fold"(%{{.*}}) {attr = 42 : i32}
+ // CHECK: "test.op_in_place_fold"(%{{.*}}) <{attr = 42 : i32}
%0 = "test.op_in_place_fold_anchor"(%c42) : (i32) -> (i32)
return %0 : i32
}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 40de4374c6a35..83226a5f3c245 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/ODSSupport.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
@@ -43,6 +44,36 @@
using namespace mlir;
using namespace test;
+/// Materializes the wrapped string as a StringAttr uniqued in `context`.
+Attribute MyPropStruct::asAttribute(MLIRContext *context) const {
+  StringRef text = content;
+  return StringAttr::get(context, text);
+}
+/// Fills `prop` from `attr`, which is expected to be a StringAttr. On a type
+/// mismatch the (optional) diagnostic is annotated and failure is returned.
+LogicalResult MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
+                                        InFlightDiagnostic *diag) {
+  auto stringAttr = dyn_cast<StringAttr>(attr);
+  if (stringAttr) {
+    prop.content = stringAttr.getValue().str();
+    return success();
+  }
+  if (diag)
+    *diag << "Expect StringAttr but got " << attr;
+  return failure();
+}
+/// Hashes the wrapped string by its contents.
+llvm::hash_code MyPropStruct::hash() const {
+  StringRef view(content);
+  return hash_value(view);
+}
+
+// Helpers backing the `PropertiesWithCustomPrint` property and the
+// `TestOpWithNiceProperties` op defined in TestOps.td: the code generated from
+// those TableGen definitions refers to these functions by name, so they are
+// declared up front. Their definitions appear later in this file.
+static LogicalResult setPropertiesFromAttribute(PropertiesWithCustomPrint &prop,
+ Attribute attr,
+ InFlightDiagnostic *diagnostic);
+static DictionaryAttr
+getPropertiesAsAttribute(MLIRContext *ctx,
+ const PropertiesWithCustomPrint &prop);
+static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop);
+static void customPrintProperties(OpAsmPrinter &p,
+ const PropertiesWithCustomPrint &prop);
+static ParseResult customParseProperties(OpAsmParser &parser,
+ PropertiesWithCustomPrint &prop);
+
void test::registerTestDialect(DialectRegistry ®istry) {
registry.insert<TestDialect>();
}
@@ -514,7 +545,7 @@ Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
- ::mlir::RegionRange regions,
+ OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
return ::mlir::success();
@@ -1264,7 +1295,7 @@ LogicalResult TestOpWithVariadicResultsAndFolder::fold(
}
OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
- if (adaptor.getOp() && !(*this)->hasAttr("attr")) {
+ if (adaptor.getOp() && !(*this)->getAttr("attr")) {
// The folder adds "attr" if not present.
(*this)->setAttr("attr", adaptor.getOp());
return getResult();
@@ -1297,7 +1328,7 @@ OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
MLIRContext *, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType() != operands[1].getType()) {
return emitOptionalError(location, "operand type mismatch ",
@@ -1312,16 +1343,17 @@ LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
// refineReturnType, currently only refineReturnType can be omitted.
LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &returnTypes) {
+ // Inference is delegated to refineReturnTypes, starting from an empty list;
+ // the op's properties are forwarded along with the attributes.
returnTypes.clear();
return OpWithRefineTypeInterfaceOp::refineReturnTypes(
- context, location, operands, attributes, regions, returnTypes);
+ context, location, operands, attributes, properties, regions,
+ returnTypes);
}
LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
MLIRContext *, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, RegionRange regions,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &returnTypes) {
if (operands[0].getType() != operands[1].getType()) {
return emitOptionalError(location, "operand type mismatch ",
@@ -1340,7 +1372,8 @@ LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
MLIRContext *context, std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Create return type consisting of the last element of the first operand.
auto operandType = operands.front().getType();
@@ -1797,6 +1830,59 @@ OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
return nullptr;
}
+/// Populates `prop` from a DictionaryAttr shaped like
+/// `{label = "<string>", value = <integer>}` (the form produced by
+/// getPropertiesAsAttribute). Fails, annotating `diagnostic` when it is
+/// non-null, if `attr` is not a dictionary, if a key is missing or has the
+/// wrong attribute kind, or if `value` does not fit in the 32-bit `int` field.
+static LogicalResult
+setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr,
+                           InFlightDiagnostic *diagnostic) {
+  DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
+  if (!dict) {
+    if (diagnostic)
+      *diagnostic << "expected DictionaryAttr to set TestProperties";
+    return failure();
+  }
+  auto label = dict.getAs<mlir::StringAttr>("label");
+  if (!label) {
+    if (diagnostic)
+      *diagnostic << "expected StringAttr for key `label`";
+    return failure();
+  }
+  auto valueAttr = dict.getAs<IntegerAttr>("value");
+  if (!valueAttr) {
+    if (diagnostic)
+      *diagnostic << "expected IntegerAttr for key `value`";
+    return failure();
+  }
+  // `prop.value` is an `int` and is serialized back as an i32: reject integers
+  // that would otherwise be silently truncated (getSExtValue would also assert
+  // on values needing more than 64 bits).
+  APInt rawValue = valueAttr.getValue();
+  if (!rawValue.isSignedIntN(32)) {
+    if (diagnostic)
+      *diagnostic << "expected `value` to fit in a 32-bit signed integer";
+    return failure();
+  }
+
+  prop.label = std::make_shared<std::string>(label.getValue());
+  prop.value = static_cast<int>(rawValue.getSExtValue());
+  return success();
+}
+/// Serializes `prop` as `{label = "<string>", value = <i32>}`, the inverse of
+/// setPropertiesFromAttribute. A null label is encoded as an empty string.
+static DictionaryAttr
+getPropertiesAsAttribute(MLIRContext *ctx,
+                         const PropertiesWithCustomPrint &prop) {
+  Builder b{ctx};
+  // `label` is a shared_ptr and can be null (e.g. default-constructed or
+  // moved-from properties), so never dereference it unconditionally.
+  StringRef label = prop.label ? StringRef(*prop.label) : StringRef();
+  SmallVector<NamedAttribute> attrs;
+  attrs.push_back(b.getNamedAttr("label", b.getStringAttr(label)));
+  attrs.push_back(b.getNamedAttr("value", b.getI32IntegerAttr(prop.value)));
+  return b.getDictionaryAttr(attrs);
+}
+/// Hashes `prop` by value: the integer plus the label's contents (not the
+/// shared_ptr identity). A null label hashes like an empty string.
+static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop) {
+  StringRef label = prop.label ? StringRef(*prop.label) : StringRef();
+  return llvm::hash_combine(prop.value, label);
+}
+/// Prints `prop` as `<label> is <value>`, the syntax read back by
+/// customParseProperties. A null label is printed as an empty string.
+static void customPrintProperties(OpAsmPrinter &p,
+                                  const PropertiesWithCustomPrint &prop) {
+  p.printKeywordOrString(prop.label ? StringRef(*prop.label) : StringRef());
+  p << " is " << prop.value;
+}
+/// Parses `<label> is <integer>` into `prop`. The label is only replaced once
+/// the whole construct has been parsed successfully.
+static ParseResult customParseProperties(OpAsmParser &parser,
+                                         PropertiesWithCustomPrint &prop) {
+  std::string parsedLabel;
+  if (parser.parseKeywordOrString(&parsedLabel))
+    return failure();
+  if (parser.parseKeyword("is"))
+    return failure();
+  if (parser.parseInteger(prop.value))
+    return failure();
+  prop.label = std::make_shared<std::string>(std::move(parsedLabel));
+  return success();
+}
+
#include "TestOpEnums.cpp.inc"
#include "TestOpInterfaces.cpp.inc"
#include "TestTypeInterfaces.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index ad3ef2a9f1cdd..dcca76d7dc388 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -42,6 +42,8 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
+#include <memory>
+
namespace mlir {
class DLTIDialect;
class RewritePatternSet;
@@ -54,6 +56,30 @@ class RewritePatternSet;
#include "TestOpInterfaces.h.inc"
#include "TestOpsDialect.h.inc"
+namespace test {
+// Define some classes to exercise the Properties feature.
+
+struct PropertiesWithCustomPrint {
+  /// A shared_ptr to a const object is safe: it is equivalent to a value-based
+  /// member. Here the label will be deallocated when the last operation
+  /// referring to it is destroyed. However there is no pool-allocation: this is
+  /// offloaded to the client.
+  /// Defaults to an empty string so that a freshly constructed properties
+  /// object can be hashed, printed, or converted without a null dereference.
+  std::shared_ptr<const std::string> label = std::make_shared<std::string>();
+  /// Explicitly initialized: the hash/print/convert helpers read it even when
+  /// it was never set from an attribute or parsed.
+  int value = 0;
+};
+class MyPropStruct {
+public:
+  std::string content;
+  // These three methods are invoked through the `MyStructProperty` wrapper
+  // defined in TestOps.td
+  mlir::Attribute asAttribute(mlir::MLIRContext *ctx) const;
+  static mlir::LogicalResult setFromAttr(MyPropStruct &prop,
+                                         mlir::Attribute attr,
+                                         mlir::InFlightDiagnostic *diag);
+  llvm::hash_code hash() const;
+};
+} // namespace test
+
#define GET_OP_CLASSES
#include "TestOps.h.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index 9ec12749bfcc3..01116a29367b2 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -24,6 +24,7 @@ def Test_Dialect : Dialect {
let useDefaultTypePrinterParser = 0;
let useDefaultAttributePrinterParser = 1;
let isExtensible = 1;
+ let usePropertiesForAttributes = 1;
let dependentDialects = ["::mlir::DLTIDialect"];
let extraClassDeclaration = [{
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 887031e6ff434..29017581b1f68 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -430,7 +430,7 @@ def VariadicRegionInferredTypesOp : TEST_Op<"variadic_region_inferred",
let extraClassDeclaration = [{
static mlir::LogicalResult inferReturnTypes(mlir::MLIRContext *context,
std::optional<::mlir::Location> location, mlir::ValueRange operands,
- mlir::DictionaryAttr attributes, mlir::RegionRange regions,
+ mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, mlir::RegionRange regions,
llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
inferredReturnTypes.assign({mlir::IntegerType::get(context, 16)});
return mlir::success();
@@ -2524,7 +2524,7 @@ def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> {
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
- ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+ ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
return ::mlir::success();
@@ -2547,7 +2547,7 @@ class FormatInferAllTypesBaseOp<string mnemonic, list<Trait> traits = []>
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
- ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+ ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
::mlir::TypeRange operandTypes = operands.getTypes();
inferredReturnTypes.assign(operandTypes.begin(), operandTypes.end());
@@ -2594,7 +2594,7 @@ def FormatInferTypeRegionsOp
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
- ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+ ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
if (regions.empty())
return ::mlir::failure();
@@ -2615,9 +2615,10 @@ def FormatInferTypeVariadicOperandsOp
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
- ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+ ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
- FormatInferTypeVariadicOperandsOpAdaptor adaptor(operands, attributes);
+ FormatInferTypeVariadicOperandsOpAdaptor adaptor(
+ operands, attributes, *properties.as<Properties *>(), {});
auto aTypes = adaptor.getA().getTypes();
auto bTypes = adaptor.getB().getTypes();
inferredReturnTypes.append(aTypes.begin(), aTypes.end());
@@ -2823,7 +2824,7 @@ class TableGenBuildInferReturnTypeBaseOp<string mnemonic,
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *,
::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
- ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+ ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
inferredReturnTypes.assign({operands[0].getType()});
return ::mlir::success();
@@ -3280,4 +3281,77 @@ def TestVersionedOpB : TEST_Op<"versionedB"> {
);
}
+//===----------------------------------------------------------------------===//
+// Test Properties
+//===----------------------------------------------------------------------===//
+
+
+// Op whose properties struct is generated inline from its argument list,
+// mixing native C++ properties with an ordinary attribute.
+def TestOpWithProperties : TEST_Op<"with_properties"> {
+  let arguments = (ins
+    Property<"int64_t">:$a,
+    StrAttr:$b, // Attributes can be listed right next to properties.
+    ArrayProperty<"int64_t", 4>:$array // A fixed-size array property.
+  );
+  let assemblyFormat = "prop-dict attr-dict";
+}
+
+// Wraps the pre-existing C++ class `MyPropStruct` as a property by forwarding
+// the attribute conversions and the hashing to its own member functions.
+def MyStructProperty : Property<"MyPropStruct"> {
+  let hashProperty = [{ $_storage.hash(); }];
+  let convertFromAttribute = [{
+    return MyPropStruct::setFromAttr($_storage, $_attr, $_diag);
+  }];
+  let convertToAttribute = [{ $_storage.asAttribute($_ctxt) }];
+}
+
+// Op whose single property is the wrapped `MyPropStruct`.
+def TestOpWithWrappedProperties : TEST_Op<"with_wrapped_properties"> {
+  let arguments = (ins MyStructProperty:$prop);
+  let assemblyFormat = "prop-dict attr-dict";
+}
+
+// Property backed by a struct defined out-of-line in C++ (TestDialect.h).
+// Conversion and hashing are delegated to free functions in TestDialect.cpp;
+// the op below also gives it a custom printer/parser.
+
+def PropertiesWithCustomPrint : Property<"PropertiesWithCustomPrint"> {
+  let hashProperty = "computeHash($_storage);";
+  let convertFromAttribute =
+      "return setPropertiesFromAttribute($_storage, $_attr, $_diag);";
+  let convertToAttribute = "getPropertiesAsAttribute($_ctxt, $_storage)";
+}
+
+// Op using PropertiesWithCustomPrint, with hand-written hooks that print and
+// parse the property as `<label> is <value>`.
+def TestOpWithNiceProperties : TEST_Op<"with_nice_properties"> {
+  let arguments = (ins PropertiesWithCustomPrint:$prop);
+  let assemblyFormat = "prop-dict attr-dict";
+  let extraClassDeclaration = [{
+    void printProperties(::mlir::MLIRContext *ctx, ::mlir::OpAsmPrinter &p,
+                         const Properties &prop);
+    static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser,
+                                               ::mlir::OperationState &result);
+  }];
+  let extraClassDefinition = [{
+    // Delegates to the free helper defined in TestDialect.cpp.
+    void TestOpWithNiceProperties::printProperties(::mlir::MLIRContext *ctx,
+        ::mlir::OpAsmPrinter &p, const Properties &prop) {
+      customPrintProperties(p, prop.prop);
+    }
+    // Parses into the Properties obtained from `result` via getOrAddProperties.
+    ::mlir::ParseResult TestOpWithNiceProperties::parseProperties(
+        ::mlir::OpAsmParser &parser, ::mlir::OperationState &result) {
+      Properties &props = result.getOrAddProperties<Properties>();
+      return customParseProperties(parser, props.prop);
+    }
+  }];
+}
+
+
+
#endif // TEST_OPS
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 20d3d486c921f..adaa6e1558999 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -434,7 +434,8 @@ static void invokeCreateWithInferredReturnType(Operation *op) {
SmallVector<Type, 2> inferredReturnTypes;
if (succeeded(OpTy::inferReturnTypes(
context, std::nullopt, values, op->getAttrDictionary(),
- op->getRegions(), inferredReturnTypes))) {
+ op->getPropertiesStorage(), op->getRegions(),
+ inferredReturnTypes))) {
OperationState state(location, OpTy::getOperationName());
// TODO: Expand to regions.
OpTy::build(b, state, values, op->getAttrs());
diff --git a/mlir/test/mlir-tblgen/constraint-unique.td b/mlir/test/mlir-tblgen/constraint-unique.td
index 0d377a454d6f5..ab191cbe7e13b 100644
--- a/mlir/test/mlir-tblgen/constraint-unique.td
+++ b/mlir/test/mlir-tblgen/constraint-unique.td
@@ -69,8 +69,8 @@ def OpC : NS_Op<"op_c"> {
/// Test that an attribute contraint was generated.
// CHECK: static ::mlir::LogicalResult [[$A_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
-// CHECK: if (attr && !((attrPred(attr, *op)))) {
-// CHECK-NEXT: return op->emitOpError("attribute '") << attrName
+// CHECK: if (attr && !((attrPred(attr, *op))))
+// CHECK-NEXT: return getDiag() << "attribute '" << attrName
// CHECK-NEXT: << "' failed to satisfy constraint: an attribute";
/// Test that duplicate attribute constraint was not generated.
@@ -78,8 +78,9 @@ def OpC : NS_Op<"op_c"> {
/// Test that a attribute constraint with a different description was generated.
// CHECK: static ::mlir::LogicalResult [[$O_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
-// CHECK: if (attr && !((attrPred(attr, *op)))) {
-// CHECK-NEXT: return op->emitOpError("attribute '") << attrName
+// CHECK: static ::mlir::LogicalResult [[$O_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
+// CHECK: if (attr && !((attrPred(attr, *op))))
+// CHECK-NEXT: return getDiag() << "attribute '" << attrName
// CHECK-NEXT: << "' failed to satisfy constraint: another attribute";
/// Test that a successor contraint was generated.
diff --git a/mlir/test/mlir-tblgen/interfaces-as-constraints.td b/mlir/test/mlir-tblgen/interfaces-as-constraints.td
index 093868000ff92..6468e60c70d29 100644
--- a/mlir/test/mlir-tblgen/interfaces-as-constraints.td
+++ b/mlir/test/mlir-tblgen/interfaces-as-constraints.td
@@ -34,13 +34,13 @@ def OpUsingAllOfThose : Op<Test_Dialect, "OpUsingAllOfThose"> {
// CHECK-NEXT: << " must be TypeInterfaceInNamespace instance, but got " << type;
// CHECK: static ::mlir::LogicalResult {{__mlir_ods_local_attr_constraint.*}}(
-// CHECK: if (attr && !((attr.isa<TopLevelAttrInterface>()))) {
-// CHECK-NEXT: return op->emitOpError("attribute '") << attrName
+// CHECK: if (attr && !((attr.isa<TopLevelAttrInterface>())))
+// CHECK-NEXT: return getDiag() << "attribute '" << attrName
// CHECK-NEXT: << "' failed to satisfy constraint: TopLevelAttrInterface instance";
// CHECK: static ::mlir::LogicalResult {{__mlir_ods_local_attr_constraint.*}}(
-// CHECK: if (attr && !((attr.isa<test::AttrInterfaceInNamespace>()))) {
-// CHECK-NEXT: return op->emitOpError("attribute '") << attrName
+// CHECK: if (attr && !((attr.isa<test::AttrInterfaceInNamespace>())))
+// CHECK-NEXT: return getDiag() << "attribute '" << attrName
// CHECK-NEXT: << "' failed to satisfy constraint: AttrInterfaceInNamespace instance";
// CHECK: TopLevelAttrInterface OpUsingAllOfThose::getAttr1()
diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index af1f62221fc07..a10cb0536ebc8 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -68,7 +68,7 @@ def AOp : NS_Op<"a_op", []> {
// DEF: ::mlir::LogicalResult AOpAdaptor::verify
// DEF: ::mlir::Attribute tblgen_aAttr;
-// DEF-NEXT: while (true) {
+// DEF: while (true) {
// DEF-NEXT: if (namedAttrIt == namedAttrRange.end())
// DEF-NEXT: return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'");
// DEF-NEXT: if (namedAttrIt->getName() == AOp::getAAttrAttrName(*odsOpName)) {
@@ -217,10 +217,10 @@ def AgetOp : Op<Test2_Dialect, "a_get_op", []> {
// DEF: ::mlir::LogicalResult AgetOpAdaptor::verify
// DEF: ::mlir::Attribute tblgen_aAttr;
-// DEF-NEXT: while (true)
+// DEF: while (true)
// DEF: ::mlir::Attribute tblgen_bAttr;
// DEF-NEXT: ::mlir::Attribute tblgen_cAttr;
-// DEF-NEXT: while (true)
+// DEF: while (true)
// DEF: if (tblgen_aAttr && !((some-condition)))
// DEF-NEXT: return emitError(loc, "'test2.a_get_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind");
// DEF: if (tblgen_bAttr && !((some-condition)))
diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index ed54ef2238019..618197c76c2a0 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -126,7 +126,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// DEFS-LABEL: NS::AOp definitions
-// DEFS: AOpGenericAdaptorBase::AOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions)
+// DEFS: AOpGenericAdaptorBase::AOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, ::mlir::EmptyProperties properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions)
// DEFS: ::mlir::RegionRange AOpGenericAdaptorBase::getSomeRegions()
// DEFS-NEXT: return odsRegions.drop_front(1);
// DEFS: ::mlir::RegionRange AOpGenericAdaptorBase::getRegions()
diff --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td
index 91c0011ae660c..7083d70be45de 100644
--- a/mlir/test/mlir-tblgen/op-format.td
+++ b/mlir/test/mlir-tblgen/op-format.td
@@ -64,7 +64,7 @@ def OptionalGroupA : TestFormat_Op<[{
// CHECK-NEXT: result.addAttribute("a", parser.getBuilder().getUnitAttr())
// CHECK: parser.parseKeyword("bar")
// CHECK-LABEL: OptionalGroupB::print
-// CHECK: if (!(*this)->getAttr("a"))
+// CHECK: if (!getAAttr())
// CHECK-NEXT: odsPrinter << ' ' << "foo"
// CHECK-NEXT: else
// CHECK-NEXT: odsPrinter << ' ' << "bar"
@@ -74,7 +74,8 @@ def OptionalGroupB : TestFormat_Op<[{
// Optional group anchored on a default-valued attribute:
// CHECK-LABEL: OptionalGroupC::parse
-// CHECK: if ((*this)->getAttr("a") != ::mlir::OpBuilder((*this)->getContext()).getStringAttr("default")) {
+
+// CHECK: if (getAAttr() && getAAttr() != ::mlir::OpBuilder((*this)->getContext()).getStringAttr("default")) {
// CHECK-NEXT: odsPrinter << ' ';
// CHECK-NEXT: odsPrinter.printAttributeWithoutType(getAAttr());
// CHECK-NEXT: }
diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index d49bffa9cb441..606d4c250d88b 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -152,7 +152,7 @@ def OpL3 : NS_Op<"op_with_all_types_constraint",
// CHECK-LABEL: LogicalResult OpL3::inferReturnTypes
// CHECK-NOT: }
-// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").cast<::mlir::TypedAttr>().getType();
+// CHECK: ::mlir::Type odsInferredType0 = odsInferredTypeAttr0.getType();
// CHECK: inferredReturnTypes[0] = odsInferredType0;
def OpL4 : NS_Op<"two_inference_edges", [
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 4aedb5e83f669..d20ffbe46caaa 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -5,8 +5,8 @@ func.func @verifyFusedLocs(%arg0 : i32) -> i32 {
%0 = "test.op_a"(%arg0) {attr = 10 : i32} : (i32) -> i32 loc("a")
%result = "test.op_a"(%0) {attr = 20 : i32} : (i32) -> i32 loc("b")
- // CHECK: "test.op_b"(%arg0) {attr = 10 : i32} : (i32) -> i32 loc("a")
- // CHECK: "test.op_b"(%arg0) {attr = 20 : i32} : (i32) -> i32 loc(fused["b", "a"])
+ // CHECK: "test.op_b"(%arg0) <{attr = 10 : i32}> : (i32) -> i32 loc("a")
+ // CHECK: "test.op_b"(%arg0) <{attr = 20 : i32}> : (i32) -> i32 loc(fused["b", "a"])
return %result : i32
}
@@ -41,7 +41,7 @@ func.func @verifyZeroArg() -> i32 {
// CHECK-LABEL: testIgnoreArgMatch
// CHECK-SAME: (%{{[a-z0-9]*}}: i32 loc({{[^)]*}}), %[[ARG1:[a-z0-9]*]]: i32 loc({{[^)]*}}),
func.func @testIgnoreArgMatch(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: f32) {
- // CHECK: "test.ignore_arg_match_dst"(%[[ARG1]]) {f = 15 : i64}
+ // CHECK: "test.ignore_arg_match_dst"(%[[ARG1]]) <{f = 15 : i64}>
"test.ignore_arg_match_src"(%arg0, %arg1, %arg2) {d = 42, e = 24, f = 15} : (i32, i32, i32) -> ()
// CHECK: test.ignore_arg_match_src
@@ -57,7 +57,7 @@ func.func @testIgnoreArgMatch(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: f32) {
// CHECK-LABEL: verifyInterleavedOperandAttribute
// CHECK-SAME: %[[ARG0:.*]]: i32 loc({{[^)]*}}), %[[ARG1:.*]]: i32 loc({{[^)]*}})
func.func @verifyInterleavedOperandAttribute(%arg0: i32, %arg1: i32) {
- // CHECK: "test.interleaved_operand_attr2"(%[[ARG0]], %[[ARG1]]) {attr1 = 15 : i64, attr2 = 42 : i64}
+ // CHECK: "test.interleaved_operand_attr2"(%[[ARG0]], %[[ARG1]]) <{attr1 = 15 : i64, attr2 = 42 : i64}>
"test.interleaved_operand_attr1"(%arg0, %arg1) {attr1 = 15, attr2 = 42} : (i32, i32) -> ()
return
}
@@ -69,13 +69,13 @@ func.func @verifyBenefit(%arg0 : i32) -> i32 {
%2 = "test.op_g"(%1) : (i32) -> i32
// CHECK: "test.op_f"(%arg0)
- // CHECK: "test.op_b"(%arg0) {attr = 34 : i32}
+ // CHECK: "test.op_b"(%arg0) <{attr = 34 : i32}>
return %0 : i32
}
// CHECK-LABEL: verifyNativeCodeCall
func.func @verifyNativeCodeCall(%arg0: i32, %arg1: i32) -> (i32, i32) {
- // CHECK: %0 = "test.native_code_call2"(%arg0) {attr = [42, 24]} : (i32) -> i32
+ // CHECK: %0 = "test.native_code_call2"(%arg0) <{attr = [42, 24]}> : (i32) -> i32
// CHECK: return %0, %arg1
%0 = "test.native_code_call1"(%arg0, %arg1) {choice = true, attr1 = 42, attr2 = 24} : (i32, i32) -> (i32)
%1 = "test.native_code_call1"(%arg0, %arg1) {choice = false, attr1 = 42, attr2 = 24} : (i32, i32) -> (i32)
@@ -215,7 +215,7 @@ func.func @symbolBinding(%arg0: i32) -> i32 {
// An op with one use is matched.
// CHECK: %0 = "test.symbol_binding_b"(%arg0)
// CHECK: %1 = "test.symbol_binding_c"(%0)
- // CHECK: %2 = "test.symbol_binding_d"(%0, %1) {attr = 42 : i64}
+ // CHECK: %2 = "test.symbol_binding_d"(%0, %1) <{attr = 42 : i64}>
%0 = "test.symbol_binding_a"(%arg0) {attr = 42} : (i32) -> (i32)
// An op without any use is not matched.
@@ -239,21 +239,21 @@ func.func @symbolBindingNoResult(%arg0: i32) {
// CHECK-LABEL: succeedMatchOpAttr
func.func @succeedMatchOpAttr() -> i32 {
- // CHECK: "test.match_op_attribute2"() {default_valued_attr = 3 : i32, more_attr = 4 : i32, optional_attr = 2 : i32, required_attr = 1 : i32}
+ // CHECK: "test.match_op_attribute2"() <{default_valued_attr = 3 : i32, more_attr = 4 : i32, optional_attr = 2 : i32, required_attr = 1 : i32}>
%0 = "test.match_op_attribute1"() {required_attr = 1: i32, optional_attr = 2: i32, default_valued_attr = 3: i32, more_attr = 4: i32} : () -> (i32)
return %0: i32
}
// CHECK-LABEL: succeedMatchMissingOptionalAttr
func.func @succeedMatchMissingOptionalAttr() -> i32 {
- // CHECK: "test.match_op_attribute2"() {default_valued_attr = 3 : i32, more_attr = 4 : i32, required_attr = 1 : i32}
+ // CHECK: "test.match_op_attribute2"() <{default_valued_attr = 3 : i32, more_attr = 4 : i32, required_attr = 1 : i32}>
%0 = "test.match_op_attribute1"() {required_attr = 1: i32, default_valued_attr = 3: i32, more_attr = 4: i32} : () -> (i32)
return %0: i32
}
// CHECK-LABEL: succeedMatchMissingDefaultValuedAttr
func.func @succeedMatchMissingDefaultValuedAttr() -> i32 {
- // CHECK: "test.match_op_attribute2"() {default_valued_attr = 42 : i32, more_attr = 4 : i32, optional_attr = 2 : i32, required_attr = 1 : i32}
+ // CHECK: "test.match_op_attribute2"() <{default_valued_attr = 42 : i32, more_attr = 4 : i32, optional_attr = 2 : i32, required_attr = 1 : i32}>
%0 = "test.match_op_attribute1"() {required_attr = 1: i32, optional_attr = 2: i32, more_attr = 4: i32} : () -> (i32)
return %0: i32
}
@@ -267,7 +267,7 @@ func.func @failedMatchAdditionalConstraintNotSatisfied() -> i32 {
// CHECK-LABEL: verifyConstantAttr
func.func @verifyConstantAttr(%arg0 : i32) -> i32 {
- // CHECK: "test.op_b"(%arg0) {attr = 17 : i32} : (i32) -> i32 loc("a")
+ // CHECK: "test.op_b"(%arg0) <{attr = 17 : i32}> : (i32) -> i32 loc("a")
%0 = "test.op_c"(%arg0) : (i32) -> i32 loc("a")
return %0 : i32
}
@@ -275,12 +275,12 @@ func.func @verifyConstantAttr(%arg0 : i32) -> i32 {
// CHECK-LABEL: verifyUnitAttr
func.func @verifyUnitAttr() -> (i32, i32) {
// Unit attribute present in the matched op is propagated as attr2.
- // CHECK: "test.match_op_attribute4"() {attr1, attr2} : () -> i32
+ // CHECK: "test.match_op_attribute4"() <{attr1, attr2}> : () -> i32
%0 = "test.match_op_attribute3"() {attr} : () -> i32
// Since the original op doesn't have the unit attribute, the new op
// only has the constant-constructed unit attribute attr1.
- // CHECK: "test.match_op_attribute4"() {attr1} : () -> i32
+ // CHECK: "test.match_op_attribute4"() <{attr1}> : () -> i32
%1 = "test.match_op_attribute3"() : () -> i32
return %0, %1 : i32, i32
}
@@ -291,7 +291,7 @@ func.func @verifyUnitAttr() -> (i32, i32) {
// CHECK-LABEL: testConstOp
func.func @testConstOp() -> (i32) {
- // CHECK-NEXT: [[C0:%.+]] = "test.constant"() {value = 1
+ // CHECK-NEXT: [[C0:%.+]] = "test.constant"() <{value = 1
%0 = "test.constant"() {value = 1 : i32} : () -> i32
// CHECK-NEXT: return [[C0]]
@@ -300,7 +300,7 @@ func.func @testConstOp() -> (i32) {
// CHECK-LABEL: testConstOpUsed
func.func @testConstOpUsed() -> (i32) {
- // CHECK-NEXT: [[C0:%.+]] = "test.constant"() {value = 1
+ // CHECK-NEXT: [[C0:%.+]] = "test.constant"() <{value = 1
%0 = "test.constant"() {value = 1 : i32} : () -> i32
// CHECK-NEXT: [[V0:%.+]] = "test.op_s"([[C0]])
@@ -312,11 +312,11 @@ func.func @testConstOpUsed() -> (i32) {
// CHECK-LABEL: testConstOpReplaced
func.func @testConstOpReplaced() -> (i32) {
- // CHECK-NEXT: [[C0:%.+]] = "test.constant"() {value = 1
+ // CHECK-NEXT: [[C0:%.+]] = "test.constant"() <{value = 1
%0 = "test.constant"() {value = 1 : i32} : () -> i32
%1 = "test.constant"() {value = 2 : i32} : () -> i32
- // CHECK: [[V0:%.+]] = "test.op_s"([[C0]]) {value = 2 : i32}
+ // CHECK: [[V0:%.+]] = "test.op_s"([[C0]]) <{value = 2 : i32}
%2 = "test.op_r"(%0, %1) : (i32, i32) -> i32
// CHECK: [[V0]]
@@ -325,10 +325,10 @@ func.func @testConstOpReplaced() -> (i32) {
// CHECK-LABEL: testConstOpMatchFailure
func.func @testConstOpMatchFailure() -> (i64) {
- // CHECK-DAG: [[C0:%.+]] = "test.constant"() {value = 1
+ // CHECK-DAG: [[C0:%.+]] = "test.constant"() <{value = 1
%0 = "test.constant"() {value = 1 : i64} : () -> i64
- // CHECK-DAG: [[C1:%.+]] = "test.constant"() {value = 2
+ // CHECK-DAG: [[C1:%.+]] = "test.constant"() <{value = 2
%1 = "test.constant"() {value = 2 : i64} : () -> i64
// CHECK: [[V0:%.+]] = "test.op_r"([[C0]], [[C1]])
@@ -340,7 +340,7 @@ func.func @testConstOpMatchFailure() -> (i64) {
// CHECK-LABEL: testConstOpMatchNonConst
func.func @testConstOpMatchNonConst(%arg0 : i32) -> (i32) {
- // CHECK-DAG: [[C0:%.+]] = "test.constant"() {value = 1
+ // CHECK-DAG: [[C0:%.+]] = "test.constant"() <{value = 1
%0 = "test.constant"() {value = 1 : i32} : () -> i32
// CHECK: [[V0:%.+]] = "test.op_r"([[C0]], %arg0)
@@ -358,14 +358,14 @@ func.func @testConstOpMatchNonConst(%arg0 : i32) -> (i32) {
// CHECK-LABEL: verifyI32EnumAttr
func.func @verifyI32EnumAttr() -> i32 {
- // CHECK: "test.i32_enum_attr"() {attr = 10 : i32}
+ // CHECK: "test.i32_enum_attr"() <{attr = 10 : i32}
%0 = "test.i32_enum_attr"() {attr = 5: i32} : () -> i32
return %0 : i32
}
// CHECK-LABEL: verifyI64EnumAttr
func.func @verifyI64EnumAttr() -> i32 {
- // CHECK: "test.i64_enum_attr"() {attr = 10 : i64}
+ // CHECK: "test.i64_enum_attr"() <{attr = 10 : i64}
%0 = "test.i64_enum_attr"() {attr = 5: i64} : () -> i32
return %0 : i32
}
@@ -522,7 +522,7 @@ func.func @generateVariadicOutputOpInNestedPattern() -> (i32) {
// CHECK-LABEL: redundantTest
func.func @redundantTest(%arg0: i32) -> i32 {
%0 = "test.op_m"(%arg0) : (i32) -> i32
- // CHECK: "test.op_m"(%arg0) {optional_attr = 314159265 : i32} : (i32) -> i32
+ // CHECK: "test.op_m"(%arg0) <{optional_attr = 314159265 : i32}> : (i32) -> i32
return %0 : i32
}
diff --git a/mlir/test/mlir-tblgen/return-types.mlir b/mlir/test/mlir-tblgen/return-types.mlir
index 39fb44f27695d..24555034d2677 100644
--- a/mlir/test/mlir-tblgen/return-types.mlir
+++ b/mlir/test/mlir-tblgen/return-types.mlir
@@ -24,6 +24,7 @@ func.func @testCreateFunctions(%arg0 : tensor<10xf32, !test.smpla>, %arg1 : tens
// -----
func.func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
+ // expected-error at +2 {{failed to infer returned types}}
// expected-error at +1 {{incompatible with return type}}
%bad = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
return
@@ -32,6 +33,7 @@ func.func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
// -----
func.func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
+ // expected-error at +2 {{failed to infer returned types}}
// expected-error at +1 {{operand type mismatch}}
%bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32>
return
@@ -40,6 +42,7 @@ func.func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : ten
// -----
func.func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
+ // expected-error at +2 {{failed to infer returned types}}
// expected-error at +1 {{required first operand and result to match}}
%bad = "test.op_with_refine_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
return
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index 7d2e03ecfe278..b4f71fb45b376 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -177,6 +177,7 @@ FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
StringSwitch<FormatToken::Kind>(str)
.Case("attr-dict", FormatToken::kw_attr_dict)
.Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword)
+ .Case("prop-dict", FormatToken::kw_prop_dict)
.Case("custom", FormatToken::kw_custom)
.Case("functional-type", FormatToken::kw_functional_type)
.Case("oilist", FormatToken::kw_oilist)
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
index e5fd04a24b2f5..da30e35f13532 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -60,6 +60,7 @@ class FormatToken {
keyword_start,
kw_attr_dict,
kw_attr_dict_w_keyword,
+ kw_prop_dict,
kw_custom,
kw_functional_type,
kw_oilist,
@@ -287,6 +288,7 @@ class DirectiveElement : public FormatElementBase<FormatElement::Directive> {
/// These are the kinds of directives.
enum Kind {
AttrDict,
+ PropDict,
Custom,
FunctionalType,
OIList,
diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp
index 40b688f2b96ca..698569c790e93 100644
--- a/mlir/tools/mlir-tblgen/OpClass.cpp
+++ b/mlir/tools/mlir-tblgen/OpClass.cpp
@@ -37,5 +37,6 @@ OpClass::OpClass(StringRef name, StringRef extraClassDeclaration,
void OpClass::finalize() {
Class::finalize();
declare<VisibilityDeclaration>(Visibility::Public);
- declare<ExtraClassDeclaration>(extraClassDeclaration, extraClassDefinition);
+ declare<ExtraClassDeclaration>(extraClassDeclaration.str(),
+ extraClassDefinition);
}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index e3e7ae087085a..dc257fd97b0e0 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -14,20 +14,25 @@
#include "OpClass.h"
#include "OpFormatGen.h"
#include "OpGenHelpers.h"
+#include "mlir/TableGen/Argument.h"
+#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/Property.h"
#include "mlir/TableGen/SideEffects.h"
#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@@ -43,6 +48,10 @@ static const char *const tblgenNamePrefix = "tblgen_";
static const char *const generatedArgName = "odsArg";
static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState";
+static const char *const propertyStorage = "propStorage";
+static const char *const propertyValue = "propValue";
+static const char *const propertyAttr = "propAttr";
+static const char *const propertyDiag = "propDiag";
/// The names of the implicit attributes that contain variadic operand and
/// result segment sizes.
@@ -103,7 +112,7 @@ static const char *const attrSizedSegmentValueRangeCalcCode = R"(
///
/// {0}: The code to get the attribute.
static const char *const adapterSegmentSizeAttrInitCode = R"(
- assert(odsAttrs && "missing segment size attribute for op");
+ assert({0} && "missing segment size attribute for op");
auto sizeAttr = {0}.cast<::mlir::DenseI32ArrayAttr>();
)";
/// The code snippet to initialize the sizes for the value range calculation.
@@ -260,6 +269,10 @@ class OpOrAdaptorHelper {
assert(attrMetadata.count(attrName) && "expected attribute metadata");
return [this, attrName, isNamed](raw_ostream &os) -> raw_ostream & {
const AttributeMetadata &attr = attrMetadata.find(attrName)->second;
+ if (hasProperties()) {
+ assert(!isNamed);
+ return os << "getProperties()." << attrName;
+ }
return os << formatv(subrangeGetAttr, getAttrName(attrName),
attr.lowerBound, attr.upperBound, getAttrRange(),
isNamed ? "Named" : "");
@@ -324,6 +337,19 @@ class OpOrAdaptorHelper {
return attrMetadata;
}
+ /// Returns whether to emit a `Properties` struct for this operation or not.
+ bool hasProperties() const {
+ if (!op.getProperties().empty())
+ return true;
+ if (!op.getDialect().usePropertiesForAttributes())
+ return false;
+ return llvm::any_of(getAttrMetadata(),
+ [](const std::pair<StringRef, AttributeMetadata> &it) {
+ return !it.second.constraint ||
+ !it.second.constraint->isDerivedAttr();
+ });
+ }
+
private:
// Compute the attribute metadata.
void computeAttrMetadata();
@@ -418,6 +444,9 @@ class OpEmitter {
// Generates the `getOperationName` method for this op.
void genOpNameGetter();
+ // Generates code to manage the properties, if any!
+ void genPropertiesSupport();
+
// Generates getters for the attributes.
void genAttrGetters();
@@ -642,6 +671,20 @@ static void genNativeTraitAttrVerifier(MethodBody &body,
}
}
+// Return true if a verifier can be emitted for the attribute: it is not a
+// derived attribute, it has a predicate, its condition is not empty, and, for
+// adaptors, the condition does not reference the op.
+static bool canEmitAttrVerifier(Attribute attr, bool isEmittingForOp) {
+ if (attr.isDerivedAttr())
+ return false;
+ Pred pred = attr.getPredicate();
+ if (pred.isNull())
+ return false;
+ std::string condition = pred.getCondition();
+ return !condition.empty() &&
+ (!StringRef(condition).contains("$_op") || isEmittingForOp);
+}
+
// Generate attribute verification. If an op instance is not available, then
// attribute checks that require one will not be emitted.
//
@@ -654,9 +697,11 @@ static void genNativeTraitAttrVerifier(MethodBody &body,
// that depend on the validity of these attributes, e.g. segment size attributes
// and operand or result getters.
// 3. Verify the constraints on all present attributes.
-static void genAttributeVerifier(
- const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body,
- const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
+static void
+genAttributeVerifier(const OpOrAdaptorHelper &emitHelper, FmtContext &ctx,
+ MethodBody &body,
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ bool useProperties) {
if (emitHelper.getAttrMetadata().empty())
return;
@@ -691,7 +736,8 @@ static void genAttributeVerifier(
// {0}: Code to get the name of the attribute.
// {1}: The emit error prefix.
// {2}: The name of the attribute.
- const char *const findRequiredAttr = R"(while (true) {{
+ const char *const findRequiredAttr = R"(
+while (true) {{
if (namedAttrIt == namedAttrRange.end())
return {1}"requires attribute '{2}'");
if (namedAttrIt->getName() == {0}) {{
@@ -714,20 +760,6 @@ static void genAttributeVerifier(
break;
})";
- // Return true if a verifier can be emitted for the attribute: it is not a
- // derived attribute, it has a predicate, its condition is not empty, and, for
- // adaptors, the condition does not reference the op.
- const auto canEmitVerifier = [&](Attribute attr) {
- if (attr.isDerivedAttr())
- return false;
- Pred pred = attr.getPredicate();
- if (pred.isNull())
- return false;
- std::string condition = pred.getCondition();
- return !condition.empty() && (!StringRef(condition).contains("$_op") ||
- emitHelper.isEmittingForOp());
- };
-
// Emit the verifier for the attribute.
const auto emitVerifier = [&](Attribute attr, StringRef attrName,
StringRef varName) {
@@ -750,58 +782,74 @@ static void genAttributeVerifier(
return (tblgenNamePrefix + attrName).str();
};
- body.indent() << formatv("auto namedAttrRange = {0};\n",
- emitHelper.getAttrRange());
- body << "auto namedAttrIt = namedAttrRange.begin();\n";
-
- // Iterate over the attributes in sorted order. Keep track of the optional
- // attributes that may be encountered along the way.
- SmallVector<const AttributeMetadata *> optionalAttrs;
- for (const std::pair<StringRef, AttributeMetadata> &it :
- emitHelper.getAttrMetadata()) {
- const AttributeMetadata &metadata = it.second;
- if (!metadata.isRequired) {
- optionalAttrs.push_back(&metadata);
- continue;
+ body.indent();
+ if (useProperties) {
+ for (const std::pair<StringRef, AttributeMetadata> &it :
+ emitHelper.getAttrMetadata()) {
+ body << formatv(
+ "auto tblgen_{0} = getProperties().{0}; (void)tblgen_{0};\n",
+ it.first);
+ const AttributeMetadata &metadata = it.second;
+ if (metadata.isRequired)
+ body << formatv(
+ "if (!tblgen_{0}) return {1}\"requires attribute '{0}'\");\n",
+ it.first, emitHelper.emitErrorPrefix());
}
+ } else {
+ body << formatv("auto namedAttrRange = {0};\n", emitHelper.getAttrRange());
+ body << "auto namedAttrIt = namedAttrRange.begin();\n";
+
+ // Iterate over the attributes in sorted order. Keep track of the optional
+ // attributes that may be encountered along the way.
+ SmallVector<const AttributeMetadata *> optionalAttrs;
+
+ for (const std::pair<StringRef, AttributeMetadata> &it :
+ emitHelper.getAttrMetadata()) {
+ const AttributeMetadata &metadata = it.second;
+ if (!metadata.isRequired) {
+ optionalAttrs.push_back(&metadata);
+ continue;
+ }
- body << formatv("::mlir::Attribute {0};\n", getVarName(it.first));
- for (const AttributeMetadata *optional : optionalAttrs) {
- body << formatv("::mlir::Attribute {0};\n",
- getVarName(optional->attrName));
- }
- body << formatv(findRequiredAttr, emitHelper.getAttrName(it.first),
- emitHelper.emitErrorPrefix(), it.first);
- for (const AttributeMetadata *optional : optionalAttrs) {
- body << formatv(checkOptionalAttr,
- emitHelper.getAttrName(optional->attrName),
- optional->attrName);
- }
- body << "\n ++namedAttrIt;\n}\n";
- optionalAttrs.clear();
- }
- // Get trailing optional attributes.
- if (!optionalAttrs.empty()) {
- for (const AttributeMetadata *optional : optionalAttrs) {
- body << formatv("::mlir::Attribute {0};\n",
- getVarName(optional->attrName));
+ body << formatv("::mlir::Attribute {0};\n", getVarName(it.first));
+ for (const AttributeMetadata *optional : optionalAttrs) {
+ body << formatv("::mlir::Attribute {0};\n",
+ getVarName(optional->attrName));
+ }
+ body << formatv(findRequiredAttr, emitHelper.getAttrName(it.first),
+ emitHelper.emitErrorPrefix(), it.first);
+ for (const AttributeMetadata *optional : optionalAttrs) {
+ body << formatv(checkOptionalAttr,
+ emitHelper.getAttrName(optional->attrName),
+ optional->attrName);
+ }
+ body << "\n ++namedAttrIt;\n}\n";
+ optionalAttrs.clear();
}
- body << checkTrailingAttrs;
- for (const AttributeMetadata *optional : optionalAttrs) {
- body << formatv(checkOptionalAttr,
- emitHelper.getAttrName(optional->attrName),
- optional->attrName);
+ // Get trailing optional attributes.
+ if (!optionalAttrs.empty()) {
+ for (const AttributeMetadata *optional : optionalAttrs) {
+ body << formatv("::mlir::Attribute {0};\n",
+ getVarName(optional->attrName));
+ }
+ body << checkTrailingAttrs;
+ for (const AttributeMetadata *optional : optionalAttrs) {
+ body << formatv(checkOptionalAttr,
+ emitHelper.getAttrName(optional->attrName),
+ optional->attrName);
+ }
+ body << "\n ++namedAttrIt;\n}\n";
}
- body << "\n ++namedAttrIt;\n}\n";
}
body.unindent();
- // Emit the checks for segment attributes first so that the other constraints
- // can call operand and result getters.
+ // Emit the checks for segment attributes first so that the other
+ // constraints can call operand and result getters.
genNativeTraitAttrVerifier(body, emitHelper);
+ bool isEmittingForOp = emitHelper.isEmittingForOp();
for (const auto &namedAttr : emitHelper.getOp().getAttributes())
- if (canEmitVerifier(namedAttr.attr))
+ if (canEmitAttrVerifier(namedAttr.attr, isEmittingForOp))
emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name));
}
@@ -834,6 +882,7 @@ OpEmitter::OpEmitter(const Operator &op,
genNamedResultGetters();
genNamedRegionGetters();
genNamedSuccessorGetters();
+ genPropertiesSupport();
genAttrGetters();
genAttrSetters();
genOptionalAttrRemovers();
@@ -989,6 +1038,274 @@ static void emitAttrGetterWithReturnType(FmtContext &fctx,
<< ";\n";
}
+void OpEmitter::genPropertiesSupport() {
+ if (!emitHelper.hasProperties())
+ return;
+ using ConstArgument =
+ llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>;
+
+ SmallVector<ConstArgument> attrOrProperties;
+ for (const std::pair<StringRef, AttributeMetadata> &it :
+ emitHelper.getAttrMetadata()) {
+ if (!it.second.constraint || !it.second.constraint->isDerivedAttr())
+ attrOrProperties.push_back(&it.second);
+ }
+ for (const NamedProperty &prop : op.getProperties())
+ attrOrProperties.push_back(&prop);
+ if (attrOrProperties.empty())
+ return;
+ auto &setPropMethod =
+ opClass
+ .addStaticMethod(
+ "::mlir::LogicalResult", "setPropertiesFromAttr",
+ MethodParameter("Properties &", "prop"),
+ MethodParameter("::mlir::Attribute", "attr"),
+ MethodParameter("::mlir::InFlightDiagnostic *", "diag"))
+ ->body();
+ auto &getPropMethod =
+ opClass
+ .addStaticMethod("::mlir::Attribute", "getPropertiesAsAttr",
+ MethodParameter("::mlir::MLIRContext *", "ctx"),
+ MethodParameter("const Properties &", "prop"))
+ ->body();
+ auto &hashMethod =
+ opClass
+ .addStaticMethod("llvm::hash_code", "computePropertiesHash",
+ MethodParameter("const Properties &", "prop"))
+ ->body();
+ auto &getInherentAttrMethod =
+ opClass
+ .addStaticMethod("std::optional<mlir::Attribute>", "getInherentAttr",
+ MethodParameter("const Properties &", "prop"),
+ MethodParameter("llvm::StringRef", "name"))
+ ->body();
+ auto &setInherentAttrMethod =
+ opClass
+ .addStaticMethod("void", "setInherentAttr",
+ MethodParameter("Properties &", "prop"),
+ MethodParameter("llvm::StringRef", "name"),
+ MethodParameter("mlir::Attribute", "value"))
+ ->body();
+ auto &populateInherentAttrsMethod =
+ opClass
+ .addStaticMethod("void", "populateInherentAttrs",
+ MethodParameter("const Properties &", "prop"),
+ MethodParameter("::mlir::NamedAttrList &", "attrs"))
+ ->body();
+ auto &verifyInherentAttrsMethod =
+ opClass
+ .addStaticMethod(
+ "::mlir::LogicalResult", "verifyInherentAttrs",
+ MethodParameter("::mlir::OperationName", "opName"),
+ MethodParameter("::mlir::NamedAttrList &", "attrs"),
+ MethodParameter(
+ "llvm::function_ref<::mlir::InFlightDiagnostic()>",
+ "getDiag"))
+ ->body();
+
+ opClass.declare<UsingDeclaration>("Properties", "FoldAdaptor::Properties");
+
+ // Convert the property to the attribute form.
+
+ setPropMethod << R"decl(
+ ::mlir::DictionaryAttr dict = dyn_cast<::mlir::DictionaryAttr>(attr);
+ if (!dict) {
+ if (diag)
+ *diag << "expected DictionaryAttr to set properties";
+ return failure();
+ }
+ )decl";
+ // TODO: properties might be optional as well.
+ const char *propFromAttrFmt = R"decl(;
+ {{
+ auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
+ ::mlir::InFlightDiagnostic *propDiag) {{
+ {0};
+ };
+ auto attr = dict.get("{1}");
+ if (!attr) {{
+ if (diag)
+ *diag << "expected key entry for {1} in DictionaryAttr to set "
+ "Properties.";
+ return failure();
+ }
+ if (failed(setFromAttr(prop.{1}, attr, diag))) return ::mlir::failure();
+ }
+)decl";
+ for (const auto &attrOrProp : attrOrProperties) {
+ if (const auto *namedProperty =
+ attrOrProp.dyn_cast<const NamedProperty *>()) {
+ StringRef name = namedProperty->name;
+ auto &prop = namedProperty->prop;
+ FmtContext fctx;
+ setPropMethod << formatv(propFromAttrFmt,
+ tgfmt(prop.getConvertFromAttributeCall(),
+ &fctx.addSubst("_attr", propertyAttr)
+ .addSubst("_storage", propertyStorage)
+ .addSubst("_diag", propertyDiag)),
+ name);
+ } else {
+ const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+ StringRef name = namedAttr->attrName;
+ setPropMethod << formatv(R"decl(
+ {{
+ auto &propStorage = prop.{0};
+ auto attr = dict.get("{0}");
+ if (attr || /*isRequired=*/{1}) {{
+ if (!attr) {{
+ if (diag)
+ *diag << "expected key entry for {0} in DictionaryAttr to set "
+ "Properties.";
+ return failure();
+ }
+ auto convertedAttr = dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
+ if (convertedAttr) {{
+ propStorage = convertedAttr;
+ } else {{
+ if (diag)
+ *diag << "Invalid attribute `{0}` in property conversion: " << attr;
+ return failure();
+ }
+ }
+ }
+)decl",
+ name, namedAttr->isRequired);
+ }
+ }
+ setPropMethod << " return ::mlir::success();\n";
+
+ // Convert the attribute form to the property.
+
+ getPropMethod << " ::mlir::SmallVector<::mlir::NamedAttribute> attrs;\n"
+ << " ::mlir::Builder odsBuilder{ctx};\n";
+ const char *propToAttrFmt = R"decl(
+ {
+ const auto &propStorage = prop.{0};
+ attrs.push_back(odsBuilder.getNamedAttr("{0}",
+ {1}));
+ }
+)decl";
+ for (const auto &attrOrProp : attrOrProperties) {
+ if (const auto *namedProperty =
+ attrOrProp.dyn_cast<const NamedProperty *>()) {
+ StringRef name = namedProperty->name;
+ auto &prop = namedProperty->prop;
+ FmtContext fctx;
+ getPropMethod << formatv(
+ propToAttrFmt, name,
+ tgfmt(prop.getConvertToAttributeCall(),
+ &fctx.addSubst("_ctxt", "ctx")
+ .addSubst("_storage", propertyStorage)));
+ continue;
+ }
+ const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+ StringRef name = namedAttr->attrName;
+ getPropMethod << formatv(R"decl(
+ {{
+ const auto &propStorage = prop.{0};
+ if (propStorage)
+ attrs.push_back(odsBuilder.getNamedAttr("{0}",
+ propStorage));
+ }
+)decl",
+ name);
+ }
+ getPropMethod << R"decl(
+ if (!attrs.empty())
+ return odsBuilder.getDictionaryAttr(attrs);
+ return {};
+)decl";
+
+ // Hashing for the property
+
+ const char *propHashFmt = R"decl(
+ auto hash_{0} = [] (const auto &propStorage) -> llvm::hash_code {
+ return {1};
+ };
+)decl";
+ for (const auto &attrOrProp : attrOrProperties) {
+ if (const auto *namedProperty =
+ attrOrProp.dyn_cast<const NamedProperty *>()) {
+ StringRef name = namedProperty->name;
+ auto &prop = namedProperty->prop;
+ FmtContext fctx;
+ hashMethod << formatv(propHashFmt, name,
+ tgfmt(prop.getHashPropertyCall(),
+ &fctx.addSubst("_storage", propertyStorage)));
+ }
+ }
+ hashMethod << " return llvm::hash_combine(";
+ llvm::interleaveComma(
+ attrOrProperties, hashMethod, [&](const ConstArgument &attrOrProp) {
+ if (const auto *namedProperty =
+ attrOrProp.dyn_cast<const NamedProperty *>()) {
+ hashMethod << "\n hash_" << namedProperty->name << "(prop."
+ << namedProperty->name << ")";
+ return;
+ }
+ const auto *namedAttr =
+ attrOrProp.dyn_cast<const AttributeMetadata *>();
+ StringRef name = namedAttr->attrName;
+ hashMethod << "\n llvm::hash_value(prop." << name
+ << ".getAsOpaquePointer())";
+ });
+ hashMethod << ");\n";
+
+ const char *getInherentAttrMethodFmt = R"decl(
+ if (name == "{0}")
+ return prop.{0};
+)decl";
+ const char *setInherentAttrMethodFmt = R"decl(
+ if (name == "{0}") {{
+ prop.{0} = dyn_cast_or_null<std::remove_reference_t<decltype(prop.{0})>>(value);
+ return;
+ }
+)decl";
+ const char *populateInherentAttrsMethodFmt = R"decl(
+ if (prop.{0}) attrs.append("{0}", prop.{0});
+)decl";
+ for (const auto &attrOrProp : attrOrProperties) {
+ if (const auto *namedAttr =
+ attrOrProp.dyn_cast<const AttributeMetadata *>()) {
+ StringRef name = namedAttr->attrName;
+ getInherentAttrMethod << formatv(getInherentAttrMethodFmt, name);
+ setInherentAttrMethod << formatv(setInherentAttrMethodFmt, name);
+ populateInherentAttrsMethod
+ << formatv(populateInherentAttrsMethodFmt, name);
+ continue;
+ }
+ }
+ getInherentAttrMethod << " return std::nullopt;\n";
+
+ // Emit the verifiers method for backward compatibility with the generic
+ // syntax. This method verifies the constraint on the properties attributes
+ // before they are set, since dyn_cast<> will silently omit failures.
+ for (const auto &attrOrProp : attrOrProperties) {
+ const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+ if (!namedAttr || !namedAttr->constraint)
+ continue;
+ Attribute attr = *namedAttr->constraint;
+ std::optional<StringRef> constraintFn =
+ staticVerifierEmitter.getAttrConstraintFn(attr);
+ if (!constraintFn)
+ continue;
+ if (canEmitAttrVerifier(attr,
+ /*isEmittingForOp=*/false)) {
+ std::string name = op.getGetterName(namedAttr->attrName);
+ verifyInherentAttrsMethod
+ << formatv(R"(
+ {{
+ ::mlir::Attribute attr = attrs.get({0}AttrName(opName));
+ if (attr && ::mlir::failed({1}(attr, "{2}", getDiag)))
+ return ::mlir::failure();
+ }
+)",
+ name, constraintFn, namedAttr->attrName);
+ }
+ }
+ verifyInherentAttrsMethod << " return ::mlir::success();";
+}
+
void OpEmitter::genAttrGetters() {
FmtContext fctx;
fctx.withBuilder("::mlir::Builder((*this)->getContext())");
@@ -999,9 +1316,9 @@ void OpEmitter::genAttrGetters() {
method->body() << " " << attr.getDerivedCodeBody() << "\n";
};
- // Generate named accessor with Attribute return type. This is a wrapper class
- // that allows referring to the attributes via accessors instead of having to
- // use the string interface for better compile time verification.
+ // Generate named accessor with Attribute return type. This is a wrapper
+ // class that allows referring to the attributes via accessors instead of
+ // having to use the string interface for better compile time verification.
auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName,
Attribute attr) {
auto *method = opClass.addMethod(attr.getStorageType(), name + "Attr");
@@ -1086,7 +1403,8 @@ void OpEmitter::genAttrGetters() {
body << " {" << name << "AttrName(),\n"
<< tgfmt(tmpl, &fctx.withSelf(name + "()")
.withBuilder("odsBuilder")
- .addSubst("_ctxt", "ctx"))
+ .addSubst("_ctxt", "ctx")
+ .addSubst("_storage", "ctx"))
<< "}";
},
",\n");
@@ -1175,19 +1493,29 @@ void OpEmitter::genAttrSetters() {
void OpEmitter::genOptionalAttrRemovers() {
// Generate methods for removing optional attributes, instead of having to
// use the string interface. Enables better compile time verification.
- auto emitRemoveAttr = [&](StringRef name) {
+ auto emitRemoveAttr = [&](StringRef name, bool useProperties) {
auto upperInitial = name.take_front().upper();
auto *method = opClass.addMethod("::mlir::Attribute",
op.getRemoverName(name) + "Attr");
if (!method)
return;
- method->body() << formatv(" return (*this)->removeAttr({0}AttrName());",
+ if (useProperties) { // Mirror removeAttr(): clear the property, return its old value.
+ method->body() << formatv(R"(
+ auto attr = getProperties().{0};
+ getProperties().{0} = {{};
+ return attr;
+)",
+ name);
+ return;
+ }
+ method->body() << formatv("  return (*this)->removeAttr({0}AttrName());",
op.getGetterName(name));
};
for (const NamedAttribute &namedAttr : op.getAttributes())
if (namedAttr.attr.isOptional())
- emitRemoveAttr(namedAttr.name);
+ emitRemoveAttr(namedAttr.name,
+ op.getDialect().usePropertiesForAttributes());
}
// Generates the code to compute the start and end index of an operand or result
@@ -1417,9 +1745,15 @@ void OpEmitter::genNamedOperandSetters() {
"::mlir::MutableOperandRange(getOperation(), "
"range.first, range.second";
if (attrSizedOperands) {
- body << formatv(
- ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
- emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
+ if (emitHelper.hasProperties())
+ body << formatv(
+ ", ::mlir::MutableOperandRange::OperandSegment({0}u, "
+ "{getOperandSegmentSizesAttrName(), getProperties().{1}})",
+ i, operandSegmentAttrName);
+ else
+ body << formatv(
+ ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
+ emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
}
body << ");\n";
@@ -1623,6 +1957,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
{1}.location, {1}.operands,
{1}.attributes.getDictionary({1}.getContext()),
+ {1}.getRawProperties(),
{1}.regions, inferredReturnTypes)))
{1}.addTypes(inferredReturnTypes);
else
@@ -1645,11 +1980,17 @@ void OpEmitter::genSeparateArgParamBuilder() {
// Automatically create the 'result_segment_sizes' attribute using
// the length of the type ranges.
if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
- std::string getterName = op.getGetterName(resultSegmentAttrName);
- body << " " << builderOpState << ".addAttribute(" << getterName
- << "AttrName(" << builderOpState << ".name), "
- << "odsBuilder.getDenseI32ArrayAttr({";
-
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << " (" << builderOpState
+ << ".getOrAddProperties<Properties>()." << resultSegmentAttrName
+ << " = \n"
+ " odsBuilder.getDenseI32ArrayAttr({";
+ } else {
+ std::string getterName = op.getGetterName(resultSegmentAttrName);
+ body << " " << builderOpState << ".addAttribute(" << getterName
+ << "AttrName(" << builderOpState << ".name), "
+ << "odsBuilder.getDenseI32ArrayAttr({";
+ }
interleaveComma(
llvm::seq<int>(0, op.getNumResults()), body, [&](int i) {
const NamedTypeConstraint &result = op.getResult(i);
@@ -1748,6 +2089,32 @@ void OpEmitter::genPopulateDefaultAttributes() {
}))
return;
+ if (op.getDialect().usePropertiesForAttributes()) {
+ SmallVector<MethodParameter> paramList;
+ paramList.emplace_back("::mlir::OperationName", "opName");
+ paramList.emplace_back("Properties &", "properties");
+ auto *m =
+ opClass.addStaticMethod("void", "populateDefaultProperties", paramList);
+ ERROR_IF_PRUNED(m, "populateDefaultProperties", op);
+ auto &body = m->body();
+ body.indent();
+ body << "::mlir::Builder " << odsBuilder << "(opName.getContext());\n";
+ for (const NamedAttribute &namedAttr : op.getAttributes()) {
+ auto &attr = namedAttr.attr;
+ if (!attr.hasDefaultValue() || attr.isOptional())
+ continue;
+ StringRef name = namedAttr.name;
+ FmtContext fctx;
+ fctx.withBuilder(odsBuilder);
+ body << "if (!properties." << name << ")\n"
+ << " properties." << name << " = "
+ << std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
+ tgfmt(attr.getDefaultValue(), &fctx)))
+ << ";\n";
+ }
+ return;
+ }
+
SmallVector<MethodParameter> paramList;
paramList.emplace_back("const ::mlir::OperationName &", "opName");
paramList.emplace_back("::mlir::NamedAttrList &", "attributes");
@@ -1830,6 +2197,7 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
{1}.location, operands,
{1}.attributes.getDictionary({1}.getContext()),
+ {1}.getRawProperties(),
{1}.regions, inferredReturnTypes))) {{)",
opClass.getClassName(), builderOpState);
if (numVariadicResults == 0 || numNonVariadicResults != 0)
@@ -2147,6 +2515,10 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> ¶mList,
operand->isOptional());
continue;
}
+ if (const auto *operand = arg.dyn_cast<NamedProperty *>()) {
+ // TODO: properties are not yet exposed as builder parameters; skip them.
+ continue;
+ }
const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
const Attribute &attr = namedAttr.attr;
@@ -2207,12 +2579,19 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
<< " ::llvm::SmallVector<int32_t> rangeSegments;\n"
<< " for (::mlir::ValueRange range : " << argName << ")\n"
<< " rangeSegments.push_back(range.size());\n"
- << " " << builderOpState << ".addAttribute("
- << op.getGetterName(
- operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
- << "AttrName(" << builderOpState << ".name), " << odsBuilder
- << ".getDenseI32ArrayAttr(rangeSegments));"
- << " }\n";
+ << " auto rangeAttr = " << odsBuilder
+ << ".getDenseI32ArrayAttr(rangeSegments);\n";
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << " " << builderOpState << ".getOrAddProperties<Properties>()."
+ << operand.constraint.getVariadicOfVariadicSegmentSizeAttr()
+ << " = rangeAttr;";
+ } else {
+ body << " " << builderOpState << ".addAttribute("
+ << op.getGetterName(
+ operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
+ << "AttrName(" << builderOpState << ".name), rangeAttr);";
+ }
+ body << " }\n";
continue;
}
@@ -2224,9 +2603,15 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
// If the operation has the operand segment size attribute, add it here.
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
std::string sizes = op.getGetterName(operandSegmentAttrName);
- body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName("
- << builderOpState << ".name), "
- << "odsBuilder.getDenseI32ArrayAttr({";
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << " (" << builderOpState << ".getOrAddProperties<Properties>()."
+ << operandSegmentAttrName << "= "
+ << "odsBuilder.getDenseI32ArrayAttr({";
+ } else {
+ body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName("
+ << builderOpState << ".name), "
+ << "odsBuilder.getDenseI32ArrayAttr({";
+ }
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
const NamedTypeConstraint &operand = op.getOperand(i);
if (!operand.isVariableLength()) {
@@ -2272,13 +2657,24 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
// instance.
FmtContext fctx;
fctx.withBuilder("odsBuilder");
- body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
- builderOpState, op.getGetterName(namedAttr.name),
- constBuildAttrFromParam(attr, fctx, namedAttr.name));
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << formatv(" {0}.getOrAddProperties<Properties>().{1} = {2};\n",
+ builderOpState, namedAttr.name,
+ constBuildAttrFromParam(attr, fctx, namedAttr.name));
+ } else {
+ body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
+ builderOpState, op.getGetterName(namedAttr.name),
+ constBuildAttrFromParam(attr, fctx, namedAttr.name));
+ }
} else {
- body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
- builderOpState, op.getGetterName(namedAttr.name),
- namedAttr.name);
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << formatv(" {0}.getOrAddProperties<Properties>().{1} = {1};\n",
+ builderOpState, namedAttr.name);
+ } else {
+ body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
+ builderOpState, op.getGetterName(namedAttr.name),
+ namedAttr.name);
+ }
}
if (emitNotNullCheck)
body.unindent() << " }\n";
@@ -2448,6 +2844,8 @@ void OpEmitter::genSideEffectInterfaceMethods() {
++operandIt;
continue;
}
+ if (arg.is<NamedProperty *>())
+ continue;
const NamedAttribute *attr = arg.get<NamedAttribute *>();
if (attr->attr.getBaseAttr().isSymbolRefAttr())
resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
@@ -2544,7 +2942,6 @@ void OpEmitter::genTypeInterfaceMethods() {
continue;
const InferredResultType &infer = op.getInferredResultType(i);
std::string typeStr;
- body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = ";
if (infer.isArg()) {
// If this is an operand, just index into operand list to access the
// type.
@@ -2558,9 +2955,22 @@ void OpEmitter::genTypeInterfaceMethods() {
} else {
auto *attr =
op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
- typeStr = ("attributes.get(\"" + attr->name +
- "\").cast<::mlir::TypedAttr>().getType()")
- .str();
+ body << " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx
+ << " = ";
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << "(properties ? properties.as<Properties *>()->"
+ << attr->name
+ << " : attributes.get(\"" + attr->name +
+ "\").dyn_cast_or_null<::mlir::TypedAttr>());\n";
+ } else {
+ body << "attributes.get(\"" + attr->name +
+ "\").dyn_cast_or_null<::mlir::TypedAttr>();\n";
+ }
+ body << " if (!odsInferredTypeAttr" << inferredTypeIdx
+ << ") return ::mlir::failure();\n";
+ typeStr =
+ ("odsInferredTypeAttr" + Twine(inferredTypeIdx) + ".getType()")
+ .str();
}
} else if (std::optional<StringRef> builder =
op.getResult(infer.getResultIndex())
@@ -2572,7 +2982,8 @@ void OpEmitter::genTypeInterfaceMethods() {
} else {
continue;
}
- body << tgfmt(infer.getTransformer(), &fctx.withSelf(typeStr)) << ";\n";
+ body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
+ << tgfmt(infer.getTransformer(), &fctx.withSelf(typeStr)) << ";\n";
constructedIndices[i] = inferredTypeIdx - 1;
}
}
@@ -2615,9 +3026,11 @@ void OpEmitter::genVerifier() {
opClass.addMethod("::mlir::LogicalResult", "verifyInvariantsImpl");
ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op);
auto &implBody = implMethod->body();
+ bool useProperties = emitHelper.hasProperties();
populateSubstitutions(emitHelper, verifyCtx);
- genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter);
+ genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter,
+ useProperties);
genOperandResultVerifier(implBody, op.getOperands(), "operand");
genOperandResultVerifier(implBody, op.getResults(), "result");
@@ -3003,11 +3416,110 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
staticVerifierEmitter(staticVerifierEmitter),
emitHelper(op, /*emitForOp=*/false) {
+ genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Public);
+ bool useProperties = emitHelper.hasProperties();
+ if (useProperties) {
+ // Define the properties struct with multiple members.
+ using ConstArgument =
+ llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>;
+ SmallVector<ConstArgument> attrOrProperties;
+ for (const std::pair<StringRef, AttributeMetadata> &it :
+ emitHelper.getAttrMetadata()) {
+ if (!it.second.constraint || !it.second.constraint->isDerivedAttr())
+ attrOrProperties.push_back(&it.second);
+ }
+ for (const NamedProperty &prop : op.getProperties())
+ attrOrProperties.push_back(&prop);
+ assert(!attrOrProperties.empty());
+ std::string declarations = " struct Properties {\n";
+ llvm::raw_string_ostream os(declarations);
+ for (const auto &attrOrProp : attrOrProperties) {
+ if (const auto *namedProperty =
+ attrOrProp.dyn_cast<const NamedProperty *>()) {
+ StringRef name = namedProperty->name;
+ if (name.empty())
+ report_fatal_error("missing name for property");
+ std::string camelName =
+ convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
+ auto &prop = namedProperty->prop;
+ // Generate the data member using the storage type.
+ os << " using " << name << "Ty = " << prop.getStorageType() << ";\n"
+ << " " << name << "Ty " << name;
+ if (prop.hasDefaultValue())
+ os << " = " << prop.getDefaultValue();
+
+ // Emit accessors using the interface type.
+ const char *accessorFmt = R"decl(;
+ {0} get{1}() {
+ auto &propStorage = this->{2};
+ return {3};
+ }
+ void set{1}(const {0} &propValue) {
+ auto &propStorage = this->{2};
+ {4};
+ }
+)decl";
+ FmtContext fctx;
+ os << formatv(accessorFmt, prop.getInterfaceType(), camelName, name,
+ tgfmt(prop.getConvertFromStorageCall(),
+ &fctx.addSubst("_storage", propertyStorage)),
+ tgfmt(prop.getAssignToStorageCall(),
+ &fctx.addSubst("_value", propertyValue)
+ .addSubst("_storage", propertyStorage)));
+ continue;
+ }
+ const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+ const Attribute *attr = nullptr;
+ if (namedAttr->constraint)
+ attr = &*namedAttr->constraint;
+ StringRef name = namedAttr->attrName;
+ if (name.empty())
+ report_fatal_error("missing name for property attr");
+ std::string camelName =
+ convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
+ // Generate the data member using the storage type.
+ StringRef storageType;
+ if (attr) {
+ storageType = attr->getStorageType();
+ } else {
+ if (name != operandSegmentAttrName && name != resultSegmentAttrName) {
+ report_fatal_error("unexpected AttributeMetadata");
+ }
+ // TODO: update to use native integers.
+ storageType = "::mlir::DenseI32ArrayAttr";
+ }
+ os << " using " << name << "Ty = " << storageType << ";\n"
+ << " " << name << "Ty " << name << ";\n";
+
+ // Emit accessors using the attribute storage type.
+ if (attr) {
+ const char *accessorFmt = R"decl(
+ auto get{0}() {
+ auto &propStorage = this->{1};
+ return propStorage.{2}<{3}>();
+ }
+ void set{0}(const {3} &propValue) {
+ this->{1} = propValue;
+ }
+)decl";
+ os << formatv(accessorFmt, camelName, name,
+ attr->isOptional() || attr->hasDefaultValue()
+ ? "dyn_cast_or_null"
+ : "cast",
+ storageType);
+ }
+ }
+ os << " };\n";
+ os.flush();
+ genericAdaptorBase.declare<ExtraClassDeclaration>(std::move(declarations));
+ }
genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Protected);
genericAdaptorBase.declare<Field>("::mlir::DictionaryAttr", "odsAttrs");
- genericAdaptorBase.declare<Field>("::mlir::RegionRange", "odsRegions");
genericAdaptorBase.declare<Field>("::std::optional<::mlir::OperationName>",
"odsOpName");
+ if (useProperties)
+ genericAdaptorBase.declare<Field>("Properties", "properties");
+ genericAdaptorBase.declare<Field>("::mlir::RegionRange", "odsRegions");
genericAdaptor.addTemplateParam("RangeT");
genericAdaptor.addField("RangeT", "odsOperands");
@@ -3024,9 +3536,15 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
attrSizedOperands ? "" : "nullptr");
+ if (useProperties)
+ paramList.emplace_back("const Properties &", "properties", "{}");
+ else
+ paramList.emplace_back("::mlir::EmptyProperties", "properties", "{}");
paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
auto *baseConstructor = genericAdaptorBase.addConstructor(paramList);
baseConstructor->addMemberInitializer("odsAttrs", "attrs");
+ if (useProperties)
+ baseConstructor->addMemberInitializer("properties", "properties");
baseConstructor->addMemberInitializer("odsRegions", "regions");
MethodBody &body = baseConstructor->body();
@@ -3037,7 +3555,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
paramList.insert(paramList.begin(), MethodParameter("RangeT", "values"));
auto *constructor = genericAdaptor.addConstructor(std::move(paramList));
- constructor->addMemberInitializer("Base", "attrs, regions");
+ constructor->addMemberInitializer("Base", "attrs, properties, regions");
constructor->addMemberInitializer("odsOperands", "values");
}
@@ -3055,8 +3573,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
/*rangeSizeCall=*/"odsOperands.size()",
/*getOperandCallPattern=*/"odsOperands[{0}]");
- // Any invalid overlap for `getOperands` will have been diagnosed before here
- // already.
+ // Any invalid overlap for `getOperands` will have been diagnosed before
+ // here already.
if (auto *m = genericAdaptor.addMethod("RangeT", "getOperands"))
m->body() << " return odsOperands;";
@@ -3070,8 +3588,10 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
genericAdaptorBase.addMethod(attr.getStorageType(), emitName + "Attr");
ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op);
auto &body = method->body().indent();
- body << "assert(odsAttrs && \"no attributes when constructing adapter\");\n"
- << formatv("auto attr = {0}.{1}<{2}>();\n", emitHelper.getAttr(name),
+ if (!useProperties)
+ body << "assert(odsAttrs && \"no attributes when constructing "
+ "adapter\");\n";
+ body << formatv("auto attr = {0}.{1}<{2}>();\n", emitHelper.getAttr(name),
attr.hasDefaultValue() || attr.isOptional()
? "dyn_cast_or_null"
: "cast",
@@ -3088,6 +3608,12 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
body << "return attr;\n";
};
+ if (useProperties) {
+ auto *m = genericAdaptorBase.addInlineMethod("const Properties &",
+ "getProperties");
+ ERROR_IF_PRUNED(m, "Adaptor::getProperties", op);
+ m->body() << " return properties;";
+ }
{
auto *m =
genericAdaptorBase.addMethod("::mlir::DictionaryAttr", "getAttributes");
@@ -3124,8 +3650,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
m->body() << formatv(" return *odsRegions[{0}];", i);
}
if (numRegions > 0) {
- // Any invalid overlap for `getRegions` will have been diagnosed before here
- // already.
+ // Any invalid overlap for `getRegions` will have been diagnosed before
+ // here already.
if (auto *m =
genericAdaptorBase.addMethod("::mlir::RegionRange", "getRegions"))
m->body() << " return odsRegions;";
@@ -3142,8 +3668,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
auto *constructor =
adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
constructor->addMemberInitializer(
- adaptor.getClassName(),
- "op->getOperands(), op->getAttrDictionary(), op->getRegions()");
+ adaptor.getClassName(), "op->getOperands(), op->getAttrDictionary(), "
+ "op.getProperties(), op->getRegions()");
}
// Add verification function.
@@ -3159,10 +3685,12 @@ void OpOperandAdaptorEmitter::addVerification() {
MethodParameter("::mlir::Location", "loc"));
ERROR_IF_PRUNED(method, "verify", op);
auto &body = method->body();
+ bool useProperties = emitHelper.hasProperties();
FmtContext verifyCtx;
populateSubstitutions(emitHelper, verifyCtx);
- genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
+ genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter,
+ useProperties);
body << " return ::mlir::success();";
}
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index ed84bcc049a65..e0472926078d4 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -132,6 +132,14 @@ class AttrDictDirective
bool withKeyword;
};
+/// This class represents the `prop-dict` directive. This directive represents
+/// the properties of the operation, expressed as a dictionary.
+class PropDictDirective
+ : public DirectiveElementBase<DirectiveElement::PropDict> {
+public:
+ explicit PropDictDirective() = default;
+};
+
/// This class represents the `functional-type` directive. This directive takes
/// two arguments and formats them, respectively, as the inputs and results of a
/// FunctionType.
@@ -294,8 +302,9 @@ struct OperationFormat {
};
OperationFormat(const Operator &op)
-
- {
+ : useProperties(op.getDialect().usePropertiesForAttributes() &&
+ !op.getAttributes().empty()),
+ opCppClassName(op.getCppClassName()) {
operandTypes.resize(op.getNumOperands(), TypeResolution());
resultTypes.resize(op.getNumResults(), TypeResolution());
@@ -351,6 +360,12 @@ struct OperationFormat {
/// A flag indicating if this operation has the SingleBlock trait.
bool hasSingleBlockTrait;
+ /// Indicates whether attributes are stored in properties.
+ bool useProperties;
+
+ /// The C++ class name of the operation, used to name its Properties type.
+ StringRef opCppClassName;
+
/// A map of buildable types to indices.
llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
@@ -389,8 +404,7 @@ static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
- if (parser.parseCustomAttributeWithFallback({0}Attr, {1}, "{0}",
- result.attributes)) {{
+ if (parser.parseCustomAttributeWithFallback({0}Attr, {1})) {{
return ::mlir::failure();
}
)";
@@ -400,30 +414,29 @@ const char *const attrParserCode = R"(
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const genericAttrParserCode = R"(
- if (parser.parseAttribute({0}Attr, {1}, "{0}", result.attributes))
+ if (parser.parseAttribute({0}Attr, {1}))
return ::mlir::failure();
)";
const char *const optionalAttrParserCode = R"(
- {
- ::mlir::OptionalParseResult parseResult =
- parser.parseOptionalAttribute({0}Attr, {1}, "{0}", result.attributes);
- if (parseResult.has_value() && failed(*parseResult))
- return ::mlir::failure();
- }
+ ::mlir::OptionalParseResult parseResult{0}Attr =
+ parser.parseOptionalAttribute({0}Attr, {1});
+ if (parseResult{0}Attr.has_value() && failed(*parseResult{0}Attr))
+ return ::mlir::failure();
+ if (parseResult{0}Attr.has_value() && succeeded(*parseResult{0}Attr))
)";
/// The code snippet used to generate a parser call for a symbol name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
- if (parser.parseSymbolName({0}Attr, "{0}", result.attributes))
+ if (parser.parseSymbolName({0}Attr))
return ::mlir::failure();
)";
const char *const optionalSymbolNameAttrParserCode = R"(
// Parsing an optional symbol name doesn't fail, so no need to check the
// result.
- (void)parser.parseOptionalSymbolName({0}Attr, "{0}", result.attributes);
+ (void)parser.parseOptionalSymbolName({0}Attr);
)";
/// The code snippet used to generate a parser call for an enum attribute.
@@ -434,6 +447,7 @@ const char *const optionalSymbolNameAttrParserCode = R"(
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
+/// {6}: The statement that stores the parsed attribute on the result.
const char *const enumAttrParserCode = R"(
{
::llvm::StringRef attrStr;
@@ -460,7 +474,7 @@ const char *const enumAttrParserCode = R"(
<< "{0} attribute specification: \"" << attrStr << '"';;
{0}Attr = {3};
- result.addAttribute("{0}", {0}Attr);
+ {6}
}
}
)";
@@ -572,6 +586,7 @@ const char *const inferReturnTypesParserCode = R"(
if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
result.location, result.operands,
result.attributes.getDictionary(parser.getContext()),
+ result.getRawProperties(),
result.regions, inferredReturnTypes)))
return ::mlir::failure();
result.addTypes(inferredReturnTypes);
@@ -930,7 +945,9 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
}
/// Generate the parser for a custom directive.
-static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
+static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
+ bool useProperties,
+ StringRef opCppClassName) {
body << " {\n";
// Preprocess the directive variables.
@@ -1003,9 +1020,15 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
const NamedAttribute *var = attr->getVar();
if (var->attr.isOptional() || var->attr.hasDefaultValue())
body << llvm::formatv(" if ({0}Attr)\n ", var->name);
+ if (useProperties) {
+ body << formatv(
+ " result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
+ var->name, opCppClassName);
+ } else {
+ body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
+ var->name);
+ }
- body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
- var->name);
} else if (auto *operand = dyn_cast<OperandVariable>(param)) {
const NamedTypeConstraint *var = operand->getVar();
if (var->isOptional()) {
@@ -1041,7 +1064,8 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
/// Generate the parser for a enum attribute.
static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
- FmtContext &attrTypeCtx, bool parseAsOptional) {
+ FmtContext &attrTypeCtx, bool parseAsOptional,
+ bool useProperties, StringRef opCppClassName) {
Attribute baseAttr = var->attr.getBaseAttr();
const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
@@ -1076,46 +1100,68 @@ static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
});
errorMessageOS << "]\");";
}
+ std::string attrAssignment;
+ if (useProperties) {
+ attrAssignment =
+ formatv(" "
+ "result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;",
+ var->name, opCppClassName);
+ } else {
+ attrAssignment =
+ formatv("result.addAttribute(\"{0}\", {0}Attr);", var->name);
+ }
body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
enumAttr.getStringToSymbolFnName(), attrBuilderStr,
- validCaseKeywordsStr, errorMessage);
+ validCaseKeywordsStr, errorMessage, attrAssignment);
}
// Generate the parser for an attribute.
static void genAttrParser(AttributeVariable *attr, MethodBody &body,
- FmtContext &attrTypeCtx, bool parseAsOptional) {
+ FmtContext &attrTypeCtx, bool parseAsOptional,
+ bool useProperties, StringRef opCppClassName) {
const NamedAttribute *var = attr->getVar();
// Check to see if we can parse this as an enum attribute.
if (canFormatEnumAttr(var))
- return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional);
+ return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional,
+ useProperties, opCppClassName);
// Check to see if we should parse this as a symbol name attribute.
if (shouldFormatSymbolNameAttr(var)) {
body << formatv(parseAsOptional ? optionalSymbolNameAttrParserCode
: symbolNameAttrParserCode,
var->name);
- return;
- }
-
- // If this attribute has a buildable type, use that when parsing the
- // attribute.
- std::string attrTypeStr;
- if (std::optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
- llvm::raw_string_ostream os(attrTypeStr);
- os << tgfmt(*typeBuilder, &attrTypeCtx);
} else {
- attrTypeStr = "::mlir::Type{}";
+ // Neither enum nor symbol name: emit one of the generic attribute parser snippets.
+ // If this attribute has a buildable type, use that when parsing the
+ // attribute.
+ std::string attrTypeStr;
+ if (std::optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
+ llvm::raw_string_ostream os(attrTypeStr);
+ os << tgfmt(*typeBuilder, &attrTypeCtx);
+ } else {
+ attrTypeStr = "::mlir::Type{}";
+ }
+ if (parseAsOptional) {
+ body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
+ } else {
+ if (attr->shouldBeQualified() ||
+ var->attr.getStorageType() == "::mlir::Attribute")
+ body << formatv(genericAttrParserCode, var->name, attrTypeStr);
+ else
+ body << formatv(attrParserCode, var->name, attrTypeStr);
+ }
}
- if (parseAsOptional) {
- body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
+ if (useProperties) { // The snippets above only fill the local <name>Attr; store it now.
+ body << formatv( // Also becomes the body of optionalAttrParserCode's trailing `if`.
+ " if ({0}Attr) result.getOrAddProperties<{1}::Properties>().{0} = "
+ "{0}Attr;\n",
+ var->name, opCppClassName);
} else {
- if (attr->shouldBeQualified() ||
- var->attr.getStorageType() == "::mlir::Attribute")
- body << formatv(genericAttrParserCode, var->name, attrTypeStr);
- else
- body << formatv(attrParserCode, var->name, attrTypeStr);
+ body << formatv(
+ " if ({0}Attr) result.attributes.append(\"{0}\", {0}Attr);\n",
+ var->name);
}
}
@@ -1170,8 +1216,15 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
if (!thenGroup == optional->isInverted()) {
// Add the anchor unit attribute to the operation state.
- body << " result.addAttribute(\"" << anchorAttr->getVar()->name
- << "\", parser.getBuilder().getUnitAttr());\n";
+ if (useProperties) {
+ body << formatv(
+ " result.getOrAddProperties<{1}::Properties>().{0} = "
+ "parser.getBuilder().getUnitAttr();",
+ anchorAttr->getVar()->name, opCppClassName);
+ } else {
+ body << " result.addAttribute(\"" << anchorAttr->getVar()->name
+ << "\", parser.getBuilder().getUnitAttr());\n";
+ }
}
}
@@ -1190,7 +1243,8 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
// parsing of the rest of the elements.
FormatElement *firstElement = thenElements.front();
if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
- genAttrParser(attrVar, body, attrTypeCtx, /*parseAsOptional=*/true);
+ genAttrParser(attrVar, body, attrTypeCtx, /*parseAsOptional=*/true,
+ useProperties, opCppClassName);
body << " if (" << attrVar->getVar()->name << "Attr) {\n";
} else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
body << " if (::mlir::succeeded(parser.parseOptional";
@@ -1248,8 +1302,15 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
body << formatv(oilistParserCode, lelementName);
if (AttributeVariable *unitAttrElem =
oilist->getUnitAttrParsingElement(pelement)) {
- body << " result.addAttribute(\"" << unitAttrElem->getVar()->name
- << "\", UnitAttr::get(parser.getContext()));\n";
+ if (useProperties) {
+ body << formatv(
+ " result.getOrAddProperties<{1}::Properties>().{0} = "
+ "parser.getBuilder().getUnitAttr();",
+ unitAttrElem->getVar()->name, opCppClassName);
+ } else {
+ body << " result.addAttribute(\"" << unitAttrElem->getVar()->name
+ << "\", UnitAttr::get(parser.getContext()));\n";
+ }
} else {
for (FormatElement *el : pelement)
genElementParser(el, body, attrTypeCtx);
@@ -1275,7 +1336,8 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
} else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
bool parseAsOptional =
(genCtx == GenContext::Normal && attr->getVar()->attr.isOptional());
- genAttrParser(attr, body, attrTypeCtx, parseAsOptional);
+ genAttrParser(attr, body, attrTypeCtx, parseAsOptional, useProperties,
+ opCppClassName);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
@@ -1311,13 +1373,27 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
/// Directives.
} else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
- body << " if (parser.parseOptionalAttrDict"
- << (attrDict->isWithKeyword() ? "WithKeyword" : "")
- << "(result.attributes))\n"
+ body.indent() << "{\n";
+ body.indent() << "auto loc = parser.getCurrentLocation();(void)loc;\n"
+ << "if (parser.parseOptionalAttrDict"
+ << (attrDict->isWithKeyword() ? "WithKeyword" : "")
+ << "(result.attributes))\n"
+ << " return ::mlir::failure();\n";
+ if (useProperties) {
+ body << "if (failed(verifyInherentAttrs(result.name, result.attributes, "
+ "[&]() {\n"
+ << " return parser.emitError(loc) << \"'\" << "
+ "result.name.getStringRef() << \"' op \";\n"
+ << " })))\n"
+ << " return ::mlir::failure();\n";
+ }
+ body.unindent() << "}\n";
+ body.unindent();
+ } else if (auto *attrDict = dyn_cast<PropDictDirective>(element)) {
+ body << " if (parseProperties(parser, result))\n"
<< " return ::mlir::failure();\n";
} else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
- genCustomDirectiveParser(customDir, body);
-
+ genCustomDirectiveParser(customDir, body, useProperties, opCppClassName);
} else if (isa<OperandsDirective>(element)) {
body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
<< " if (parser.parseOperandList(allOperands))\n"
@@ -1571,8 +1647,16 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
MethodBody &body) {
if (!allOperands) {
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
- body << " result.addAttribute(\"operand_segment_sizes\", "
- << "parser.getBuilder().getDenseI32ArrayAttr({";
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << formatv(" "
+ "result.getOrAddProperties<{0}::Properties>().operand_"
+ "segment_sizes = "
+ "(parser.getBuilder().getDenseI32ArrayAttr({{",
+ op.getCppClassName());
+ } else {
+ body << " result.addAttribute(\"operand_segment_sizes\", "
+ << "parser.getBuilder().getDenseI32ArrayAttr({";
+ }
auto interleaveFn = [&](const NamedTypeConstraint &operand) {
// If the operand is variadic emit the parsed size.
if (operand.isVariableLength())
@@ -1586,18 +1670,36 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
for (const NamedTypeConstraint &operand : op.getOperands()) {
if (!operand.isVariadicOfVariadic())
continue;
- body << llvm::formatv(
- " result.addAttribute(\"{0}\", "
- "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));\n",
- operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
- operand.name);
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << llvm::formatv(
+ " result.getOrAddProperties<{0}::Properties>().{1} = "
+ "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
+ op.getCppClassName(),
+ operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
+ operand.name);
+ } else {
+ body << llvm::formatv(
+ " result.addAttribute(\"{0}\", "
+ "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
+ "\n",
+ operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
+ operand.name);
+ }
}
}
if (!allResultTypes &&
op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
- body << " result.addAttribute(\"result_segment_sizes\", "
- << "parser.getBuilder().getDenseI32ArrayAttr({";
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << formatv(
+ " "
+ "result.getOrAddProperties<{0}::Properties>().result_segment_sizes = "
+ "(parser.getBuilder().getDenseI32ArrayAttr({{",
+ op.getCppClassName());
+ } else {
+ body << " result.addAttribute(\"result_segment_sizes\", "
+ << "parser.getBuilder().getDenseI32ArrayAttr({";
+ }
auto interleaveFn = [&](const NamedTypeConstraint &result) {
// If the result is variadic emit the parsed size.
if (result.isVariableLength())
@@ -1641,6 +1743,14 @@ const char *enumAttrBeginPrinterCode = R"(
auto caseValueStr = {1}(caseValue);
)";
+/// Emit the C++ statements that print the 'prop-dict' directive: a separating
+/// space, then a call to `printProperties` on the op's properties.
+/// `fmt` and `op` are not used; they keep the leading parameters aligned with
+/// genAttrDictPrinter.
+static void genPropDictPrinter(OperationFormat &fmt, Operator &op,
+                               MethodBody &body) {
+  const char *separatorStmt = " _odsPrinter << \" \";\n";
+  const char *printStmt =
+      " printProperties(this->getContext(), _odsPrinter, getProperties());\n";
+  body << separatorStmt << printStmt;
+}
+
/// Generate the printer for the 'attr-dict' directive.
static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
MethodBody &body, bool withKeyword) {
@@ -1898,7 +2008,7 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
})
.Case<AttributeVariable>([&](AttributeVariable *element) {
Attribute attr = element->getVar()->attr;
- body << "(*this)->getAttr(\"" << element->getVar()->name << "\")";
+ body << op.getGetterName(element->getVar()->name) << "Attr()";
if (attr.isOptional())
return; // done
if (attr.hasDefaultValue()) {
@@ -1906,7 +2016,8 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
// default value.
FmtContext fctx;
fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
- body << " != "
+ body << " && " << op.getGetterName(element->getVar()->name)
+ << "Attr() != "
<< tgfmt(attr.getConstBuilderTemplate(), &fctx,
attr.getDefaultValue());
return;
@@ -2063,6 +2174,13 @@ void OperationFormat::genElementPrinter(FormatElement *element,
return;
}
+ // Emit the property dictionary.
+ if (auto *propDict = dyn_cast<PropDictDirective>(element)) {
+ genPropDictPrinter(*this, op, body);
+ lastWasPunctuation = false;
+ return;
+ }
+
// Optionally insert a space before the next element. The AttrDict printer
// already adds a space as necessary.
if (shouldEmitSpace || !lastWasPunctuation)
@@ -2300,6 +2418,7 @@ class OpFormatParser : public FormatParser {
ConstArgument findSeenArg(StringRef name);
/// Parse the various different directives.
+ FailureOr<FormatElement *> parsePropDictDirective(SMLoc loc, Context context);
FailureOr<FormatElement *> parseAttrDictDirective(SMLoc loc, Context context,
bool withKeyword);
FailureOr<FormatElement *> parseFunctionalTypeDirective(SMLoc loc,
@@ -2329,6 +2448,7 @@ class OpFormatParser : public FormatParser {
// The following are various bits of format state used for verification
// during parsing.
bool hasAttrDict = false;
+ bool hasPropDict = false;
bool hasAllRegions = false, hasAllSuccessors = false;
bool canInferResultTypes = false;
llvm::SmallBitVector seenOperandTypes, seenResultTypes;
@@ -2873,6 +2993,8 @@ FailureOr<FormatElement *>
OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
Context ctx) {
switch (kind) {
+ case FormatToken::kw_prop_dict:
+ return parsePropDictDirective(loc, ctx);
case FormatToken::kw_attr_dict:
return parseAttrDictDirective(loc, ctx,
/*withKeyword=*/false);
@@ -2925,6 +3047,23 @@ OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context,
return create<AttrDictDirective>(withKeyword);
}
+/// Parse the 'prop-dict' directive. It cannot appear inside a type directive,
+/// cannot be referenced through `ref`, and may appear at most once.
+/// Returns the new PropDictDirective element, or failure after emitting an
+/// error at `loc`.
+FailureOr<FormatElement *>
+OpFormatParser::parsePropDictDirective(SMLoc loc, Context context) {
+  if (context == TypeDirectiveContext)
+    return emitError(loc, "'prop-dict' directive can only be used as a "
+                          "top-level directive");
+
+  // A `ref` of 'prop-dict' is not supported. Report it as a regular format
+  // error rather than aborting the whole process with report_fatal_error:
+  // this is reachable from user-written assembly formats.
+  if (context == RefDirectiveContext)
+    return emitError(loc, "'ref' of 'prop-dict' is not supported");
+
+  // Otherwise, this is a top-level context: only one 'prop-dict' is allowed.
+  if (hasPropDict)
+    return emitError(loc, "'prop-dict' directive has already been seen");
+  hasPropDict = true;
+
+  return create<PropDictDirective>();
+}
+
LogicalResult OpFormatParser::verifyCustomDirectiveArguments(
SMLoc loc, ArrayRef<FormatElement *> arguments) {
for (FormatElement *argument : arguments) {
diff --git a/mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp b/mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp
index e9cac3949ddbd..48c62ad20a04a 100644
--- a/mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp
+++ b/mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp
@@ -25,7 +25,7 @@ static Operation *createOp(MLIRContext *context, Location loc,
context->allowUnregisteredDialects();
return Operation::create(loc, OperationName(operationName, context),
std::nullopt, std::nullopt, std::nullopt,
- std::nullopt, numRegions);
+ OpaqueProperties(nullptr), std::nullopt, numRegions);
}
namespace {
diff --git a/mlir/unittests/IR/AdaptorTest.cpp b/mlir/unittests/IR/AdaptorTest.cpp
index a3efb34889f41..fcd19addd8e3a 100644
--- a/mlir/unittests/IR/AdaptorTest.cpp
+++ b/mlir/unittests/IR/AdaptorTest.cpp
@@ -38,10 +38,10 @@ TEST(Adaptor, GenericAdaptorsOperandAccess) {
// Using optional instead of plain int here to differentiate absence of
// value from the value 0.
SmallVector<std::optional<int>> v = {0, 4};
- OIListSimple::GenericAdaptor<ArrayRef<std::optional<int>>> d(
- v, builder.getDictionaryAttr({builder.getNamedAttr(
- "operand_segment_sizes",
- builder.getDenseI32ArrayAttr({1, 0, 1}))}));
+ OIListSimple::Properties prop;
+ prop.operand_segment_sizes = builder.getDenseI32ArrayAttr({1, 0, 1});
+ OIListSimple::GenericAdaptor<ArrayRef<std::optional<int>>> d(v, {}, prop,
+ {});
EXPECT_EQ(d.getArg0(), 0);
EXPECT_EQ(d.getArg1(), std::nullopt);
EXPECT_EQ(d.getArg2(), 4);
@@ -51,9 +51,10 @@ TEST(Adaptor, GenericAdaptorsOperandAccess) {
FormatVariadicOfVariadicOperand::FoldAdaptor e({});
{
SmallVector<int> v = {0, 1, 2, 3, 4};
- FormatVariadicOfVariadicOperand::GenericAdaptor<ArrayRef<int>> f(
- v, builder.getDictionaryAttr({builder.getNamedAttr(
- "operand_segments", builder.getDenseI32ArrayAttr({3, 2, 0}))}));
+ FormatVariadicOfVariadicOperand::Properties prop;
+ prop.operand_segments = builder.getDenseI32ArrayAttr({3, 2, 0});
+ FormatVariadicOfVariadicOperand::GenericAdaptor<ArrayRef<int>> f(v, {},
+ prop, {});
SmallVector<ArrayRef<int>> operand = f.getOperand();
ASSERT_EQ(operand.size(), (std::size_t)3);
EXPECT_THAT(operand[0], ElementsAre(0, 1, 2));
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 7d49283c59c3b..8a74a59096289 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_unittest(MLIRIRTests
PatternMatchTest.cpp
ShapedTypeTest.cpp
TypeTest.cpp
+ OpPropertiesTest.cpp
DEPENDS
MLIRTestInterfaceIncGen
diff --git a/mlir/unittests/IR/OpPropertiesTest.cpp b/mlir/unittests/IR/OpPropertiesTest.cpp
new file mode 100644
index 0000000000000..eda84a0e0b075
--- /dev/null
+++ b/mlir/unittests/IR/OpPropertiesTest.cpp
@@ -0,0 +1,358 @@
+//===- OpPropertiesTest.cpp - Test all properties-related APIs ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Parser/Parser.h"
+#include "gtest/gtest.h"
+#include <optional>
+
+using namespace mlir;
+
+namespace {
+/// Plain data structure used as the "properties" storage of a test
+/// operation. The in-class initializers below are the values an operation
+/// starts with when it is created without explicit properties.
+struct TestProperties {
+  int a{-1};
+  float b{-1.0f};
+  std::vector<int64_t> array{-33};
+  /// Pointer to an immutable label. Sharing a const object between operations
+  /// behaves like a value member: the string is released once the last
+  /// operation referring to it is destroyed. No pooling/uniquing is done here;
+  /// that is left to the client. May be null (the default).
+  std::shared_ptr<const std::string> label;
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestProperties)
+};
+
+/// Populate `prop` from `attr`, which must be a DictionaryAttr with the keys
+/// `a` (IntegerAttr), `b` (f32 FloatAttr), `array` (DenseI64ArrayAttr) and
+/// `label` (StringAttr). All keys are validated before anything is written,
+/// so `prop` is left untouched on failure. When `diagnostic` is non-null, a
+/// description of the first problem found is streamed into it. This is used
+/// for example when parsing the generic operation form.
+static LogicalResult
+setPropertiesFromAttribute(TestProperties &prop, Attribute attr,
+                           InFlightDiagnostic *diagnostic) {
+  // Stream `message` into the optional diagnostic and signal failure.
+  auto reject = [diagnostic](const auto &message) -> LogicalResult {
+    if (diagnostic)
+      *diagnostic << message;
+    return failure();
+  };
+
+  auto dict = dyn_cast<DictionaryAttr>(attr);
+  if (!dict)
+    return reject("expected DictionaryAttr to set TestProperties");
+
+  IntegerAttr intAttr = dict.getAs<IntegerAttr>("a");
+  if (!intAttr)
+    return reject("expected IntegerAttr for key `a`");
+
+  // Only single-precision floats are accepted for `b`.
+  FloatAttr floatAttr = dict.getAs<FloatAttr>("b");
+  bool isF32 = floatAttr && &floatAttr.getValue().getSemantics() ==
+                                &llvm::APFloatBase::IEEEsingle();
+  if (!isF32)
+    return reject("expected FloatAttr for key `b`");
+
+  DenseI64ArrayAttr elementsAttr = dict.getAs<DenseI64ArrayAttr>("array");
+  if (!elementsAttr)
+    return reject("expected DenseI64ArrayAttr for key `array`");
+
+  mlir::StringAttr labelAttr = dict.getAs<mlir::StringAttr>("label");
+  if (!labelAttr)
+    return reject("expected StringAttr for key `label`");
+
+  // Every key was validated: commit the values.
+  prop.a = intAttr.getValue().getSExtValue();
+  prop.b = floatAttr.getValue().convertToFloat();
+  ArrayRef<int64_t> elements = elementsAttr.asArrayRef();
+  prop.array.assign(elements.begin(), elements.end());
+  prop.label = std::make_shared<std::string>(labelAttr.getValue());
+  return success();
+}
+
+/// Convert a TestProperties struct into a DictionaryAttr with one entry per
+/// member. This is used for example when printing the generic operation
+/// form. A null `label` is rendered as the string "<nullptr>".
+static Attribute getPropertiesAsAttribute(MLIRContext *ctx,
+                                          const TestProperties &prop) {
+  Builder builder{ctx};
+  std::string labelText = prop.label ? *prop.label : std::string("<nullptr>");
+  NamedAttribute entries[] = {
+      builder.getNamedAttr("a", builder.getI32IntegerAttr(prop.a)),
+      builder.getNamedAttr("b", builder.getF32FloatAttr(prop.b)),
+      builder.getNamedAttr("array", builder.getDenseI64ArrayAttr(prop.array)),
+      builder.getNamedAttr("label", builder.getStringAttr(labelText)),
+  };
+  return builder.getDictionaryAttr(entries);
+}
+
+/// Hash every member of TestProperties. This is what makes the hash of an
+/// operation carrying these properties (e.g. via
+/// OperationEquivalence::computeHash) change when a member changes.
+/// A null `label` hashes like an empty string instead of being dereferenced.
+inline llvm::hash_code computeHash(const TestProperties &prop) {
+  // We hash `b` which is a float using its underlying array of char:
+  unsigned char const *p = reinterpret_cast<unsigned char const *>(&prop.b);
+  ArrayRef<unsigned char> bBytes{p, sizeof(prop.b)};
+  // `label` is a default-null shared_ptr (e.g. for operations created without
+  // explicit properties), so it must not be dereferenced unconditionally.
+  StringRef label = prop.label ? StringRef(*prop.label) : StringRef();
+  return llvm::hash_combine(
+      prop.a, llvm::hash_combine_range(bBytes.begin(), bBytes.end()),
+      llvm::hash_combine_range(prop.array.begin(), prop.array.end()), label);
+}
+
+/// A custom operation for the purpose of showcasing how to use "properties".
+class OpWithProperties : public Op<OpWithProperties> {
+public:
+  // --- Operation boilerplate ----------------------------------------------
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithProperties)
+  using Op::Op;
+  static StringRef getOperationName() {
+    return "test_op_properties.op_with_properties";
+  }
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  // --- Properties ---------------------------------------------------------
+  // This alias is what enables "properties" for the operation: each instance
+  // stores a TestProperties object.
+  using Properties = TestProperties;
+
+  // Inherent-attribute hooks. None of TestProperties' members is exposed as
+  // an inherent attribute, so these are all no-ops.
+  static void populateInherentAttrs(const Properties &prop,
+                                    NamedAttrList &attrs) {}
+  static std::optional<mlir::Attribute> getInherentAttr(const Properties &prop,
+                                                        StringRef name) {
+    return std::nullopt;
+  }
+  static void setInherentAttr(Properties &prop, StringRef name,
+                              mlir::Attribute value) {}
+  static LogicalResult
+  verifyInherentAttrs(OperationName opName, NamedAttrList &attrs,
+                      function_ref<InFlightDiagnostic()> getDiag) {
+    return success();
+  }
+};
+
+/// Minimal dialect whose only purpose is to register OpWithProperties under
+/// the "test_op_properties" namespace.
+class TestOpPropertiesDialect : public Dialect {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOpPropertiesDialect)
+
+  static constexpr StringLiteral getDialectNamespace() {
+    constexpr StringLiteral dialectNamespace("test_op_properties");
+    return dialectNamespace;
+  }
+
+  explicit TestOpPropertiesDialect(MLIRContext *context)
+      : Dialect(getDialectNamespace(), context,
+                TypeID::get<TestOpPropertiesDialect>()) {
+    addOperations<OpWithProperties>();
+  }
+};
+
+/// Generic-form source of an OpWithProperties operation that sets every
+/// TestProperties member explicitly; the properties are written inside
+/// `<{...}>`. Several tests parse this as their starting point.
+constexpr StringLiteral mlirSrc = R"mlir(
+ "test_op_properties.op_with_properties"()
+ <{a = -42 : i32,
+ b = -4.200000e+01 : f32,
+ array = array<i64: 40, 41>,
+ label = "bar foo"}> : () -> ()
+)mlir";
+
+TEST(OpPropertiesTest, Properties) {
+  MLIRContext context;
+  context.getOrLoadDialect<TestOpPropertiesDialect>();
+  ParserConfig config(&context);
+  // Parse the operation with some properties.
+  OwningOpRef<Operation *> op = parseSourceString(mlirSrc, config);
+  ASSERT_TRUE(op.get() != nullptr);
+  auto opWithProp = dyn_cast<OpWithProperties>(op.get());
+  ASSERT_TRUE(opWithProp);
+
+  // Print the operation into a fresh string.
+  auto printOp = [&]() {
+    std::string buffer;
+    llvm::raw_string_ostream os(buffer);
+    opWithProp.print(os);
+    return os.str();
+  };
+  // Print the operation once and expect each fragment to appear in it.
+  auto expectContains = [&](ArrayRef<StringRef> fragments) {
+    std::string printed = printOp();
+    for (StringRef fragment : fragments)
+      EXPECT_TRUE(StringRef(printed).contains(fragment));
+  };
+
+  // The parsed properties round-trip through the generic printer.
+  ASSERT_STREQ("\"test_op_properties.op_with_properties\"() "
+               "<{a = -42 : i32, "
+               "array = array<i64: 40, 41>, "
+               "b = -4.200000e+01 : f32, "
+               "label = \"bar foo\"}> : () -> ()\n",
+               printOp().c_str());
+
+  // Get a mutable reference to the properties for this operation and modify
+  // it in place one member at a time; the printed form tracks every change.
+  TestProperties &prop = opWithProp.getProperties();
+  prop.a = 42;
+  expectContains({"a = 42", "b = -4.200000e+01", "array = array<i64: 40, 41>",
+                  "label = \"bar foo\""});
+  prop.b = 42.;
+  expectContains({"a = 42", "b = 4.200000e+01", "array = array<i64: 40, 41>",
+                  "label = \"bar foo\""});
+  prop.array.push_back(42);
+  expectContains({"a = 42", "b = 4.200000e+01",
+                  "array = array<i64: 40, 41, 42>", "label = \"bar foo\""});
+  prop.label = std::make_shared<std::string>("foo bar");
+  expectContains({"a = 42", "b = 4.200000e+01",
+                  "array = array<i64: 40, 41, 42>", "label = \"foo bar\""});
+}
+
+// Test diagnostic emission when using invalid dictionary: first through the
+// parser, then through OperationState::setProperties directly.
+TEST(OpPropertiesTest, FailedProperties) {
+ MLIRContext context;
+ context.getOrLoadDialect<TestOpPropertiesDialect>();
+ // Accumulate the text of every emitted diagnostic into diagnosticStr.
+ std::string diagnosticStr;
+ context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
+ diagnosticStr += diag.str();
+ return success();
+ });
+
+ // Parser configuration shared by both parses below.
+ ParserConfig config(&context);
+
+ // Parse an operation with invalid (incomplete) properties. `a` is present
+ // and valid, so validation stops at the missing `b` key.
+ OwningOpRef<Operation *> owningOp =
+ parseSourceString("\"test_op_properties.op_with_properties\"() "
+ "<{a = -42 : i32}> : () -> ()\n",
+ config);
+ ASSERT_EQ(owningOp.get(), nullptr);
+ EXPECT_STREQ(
+ "invalid properties {a = -42 : i32} for op "
+ "test_op_properties.op_with_properties: expected FloatAttr for key `b`",
+ diagnosticStr.c_str());
+ diagnosticStr.clear();
+
+ // Now start from a valid operation.
+ owningOp = parseSourceString(mlirSrc, config);
+ Operation *op = owningOp.get();
+ ASSERT_TRUE(op != nullptr);
+ Location loc = op->getLoc();
+ auto opWithProp = dyn_cast<OpWithProperties>(op);
+ ASSERT_TRUE(opWithProp);
+
+ // Build a properties dictionary whose `a` entry has the wrong type
+ // (StringAttr instead of IntegerAttr).
+ OperationState state(loc, op->getName());
+ Builder b{&context};
+ NamedAttrList attrs;
+ attrs.push_back(b.getNamedAttr("a", b.getStringAttr("foo")));
+ state.propertiesAttr = attrs.getDictionary(&context);
+ // `diag` is scoped so that it has been reported (when it is destroyed)
+ // before diagnosticStr is checked below.
+ {
+ auto diag = op->emitError("setting properties failed: ");
+ auto result = state.setProperties(op, &diag);
+ EXPECT_TRUE(result.failed());
+ }
+ EXPECT_STREQ("setting properties failed: expected IntegerAttr for key `a`",
+ diagnosticStr.c_str());
+}
+
+TEST(OpPropertiesTest, DefaultValues) {
+  MLIRContext context;
+  context.getOrLoadDialect<TestOpPropertiesDialect>();
+  // No properties are provided: the new operation must carry the in-class
+  // defaults of TestProperties.
+  OperationState state(UnknownLoc::get(&context),
+                       "test_op_properties.op_with_properties");
+  // OwningOpRef erases the operation when it goes out of scope.
+  OwningOpRef<Operation *> op(Operation::create(state));
+  ASSERT_TRUE(op.get() != nullptr);
+
+  std::string printed;
+  llvm::raw_string_ostream os(printed);
+  op.get()->print(os);
+  StringRef text = os.str();
+  EXPECT_TRUE(text.contains("a = -1"));
+  EXPECT_TRUE(text.contains("b = -1"));
+  EXPECT_TRUE(text.contains("array = array<i64: -33>"));
+}
+
+TEST(OpPropertiesTest, Cloning) {
+  MLIRContext context;
+  context.getOrLoadDialect<TestOpPropertiesDialect>();
+  ParserConfig config(&context);
+  // Parse the operation with some properties.
+  OwningOpRef<Operation *> op = parseSourceString(mlirSrc, config);
+  ASSERT_TRUE(op.get() != nullptr);
+  auto opWithProp = dyn_cast<OpWithProperties>(op.get());
+  ASSERT_TRUE(opWithProp);
+  Operation *clone = opWithProp->clone();
+
+  // Print an operation into a string.
+  auto render = [](Operation *operation) {
+    std::string text;
+    llvm::raw_string_ostream os(text);
+    operation->print(os);
+    return os.str();
+  };
+
+  // The clone must print exactly like the original, properties included.
+  std::string opStr = render(op.get());
+  std::string cloneStr = render(clone);
+  clone->erase();
+  EXPECT_STREQ(opStr.c_str(), cloneStr.c_str());
+}
+
+// Check that OperationEquivalence::computeHash reflects the properties:
+// changing any member changes the hash, and restoring the parsed value
+// restores the original hash.
+TEST(OpPropertiesTest, Equivalence) {
+ MLIRContext context;
+ context.getOrLoadDialect<TestOpPropertiesDialect>();
+ ParserConfig config(&context);
+ // Parse the operation with some properties.
+ OwningOpRef<Operation *> op = parseSourceString(mlirSrc, config);
+ ASSERT_TRUE(op.get() != nullptr);
+ auto opWithProp = dyn_cast<OpWithProperties>(op.get());
+ ASSERT_TRUE(opWithProp);
+ // Hash of the operation with the properties parsed from mlirSrc.
+ llvm::hash_code reference = OperationEquivalence::computeHash(opWithProp);
+ TestProperties &prop = opWithProp.getProperties();
+ // -42 is the value of `a` in mlirSrc.
+ prop.a = 42;
+ EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp));
+ prop.a = -42;
+ EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp));
+ // -42. is the value of `b` in mlirSrc (-4.200000e+01).
+ prop.b = 42.;
+ EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp));
+ prop.b = -42.;
+ EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp));
+ // Appending then removing an element restores the parsed `array`.
+ prop.array.push_back(42);
+ EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp));
+ prop.array.pop_back();
+ EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp));
+}
+
+TEST(OpPropertiesTest, getOrAddProperties) {
+  MLIRContext context;
+  context.getOrLoadDialect<TestOpPropertiesDialect>();
+  OperationState state(UnknownLoc::get(&context),
+                       "test_op_properties.op_with_properties");
+  // Test `getOrAddProperties` API on OperationState: values written through
+  // the returned reference must show up on the created operation.
+  TestProperties &props = state.getOrAddProperties<TestProperties>();
+  props.a = 1;
+  props.b = 2;
+  props.array = {3, 4, 5};
+
+  // OwningOpRef erases the operation when it goes out of scope.
+  OwningOpRef<Operation *> op(Operation::create(state));
+  ASSERT_TRUE(op.get() != nullptr);
+
+  std::string printed;
+  llvm::raw_string_ostream os(printed);
+  op.get()->print(os);
+  StringRef text = os.str();
+  EXPECT_TRUE(text.contains("a = 1"));
+  EXPECT_TRUE(text.contains("b = 2"));
+  EXPECT_TRUE(text.contains("array = array<i64: 3, 4, 5>"));
+}
+
+} // namespace
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index 03cdb4b3776f8..81340d3bfdf70 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -22,9 +22,9 @@ static Operation *createOp(MLIRContext *context,
ArrayRef<Type> resultTypes = std::nullopt,
unsigned int numRegions = 0) {
context->allowUnregisteredDialects();
- return Operation::create(UnknownLoc::get(context),
- OperationName("foo.bar", context), resultTypes,
- operands, std::nullopt, std::nullopt, numRegions);
+ return Operation::create(
+ UnknownLoc::get(context), OperationName("foo.bar", context), resultTypes,
+ operands, std::nullopt, nullptr, std::nullopt, numRegions);
}
namespace {
diff --git a/mlir/unittests/Transforms/DialectConversion.cpp b/mlir/unittests/Transforms/DialectConversion.cpp
index f4a60af82d126..10d7fb041278d 100644
--- a/mlir/unittests/Transforms/DialectConversion.cpp
+++ b/mlir/unittests/Transforms/DialectConversion.cpp
@@ -13,9 +13,9 @@ using namespace mlir;
+/// Create a bare operation named "foo.bar" in an unregistered dialect.
static Operation *createOp(MLIRContext *context) {
context->allowUnregisteredDialects();
- return Operation::create(UnknownLoc::get(context),
- OperationName("foo.bar", context), std::nullopt,
- std::nullopt, std::nullopt, std::nullopt, 0);
+ // Unknown location, no properties (nullptr), zero regions; every other
+ // optional argument is left empty (std::nullopt).
+ return Operation::create(
+ UnknownLoc::get(context), OperationName("foo.bar", context), std::nullopt,
+ std::nullopt, std::nullopt, /*properties=*/nullptr, std::nullopt, 0);
}
namespace {
More information about the Mlir-commits
mailing list