[Mlir-commits] [mlir] 4e103a1 - [mlir] Add support for VariadicOfVariadic operands
River Riddle
llvmlistbot at llvm.org
Mon Aug 23 13:42:15 PDT 2021
Author: River Riddle
Date: 2021-08-23T20:32:31Z
New Revision: 4e103a12d9d6d03ad4147a0c9a8f5742538eefec
URL: https://github.com/llvm/llvm-project/commit/4e103a12d9d6d03ad4147a0c9a8f5742538eefec
DIFF: https://github.com/llvm/llvm-project/commit/4e103a12d9d6d03ad4147a0c9a8f5742538eefec.diff
LOG: [mlir] Add support for VariadicOfVariadic operands
This revision adds native ODS support for VariadicOfVariadic operand
groups. An example is the SwitchOp, which has a variadic number of case
statements, each of which has its own variadic range of operands.
Builtin ODS support provides proper accessors for the nested operand
ranges, builder support, and declarative format support.
VariadicOfVariadic operands are supported by providing a segment
attribute that stores the sizes of the operand groups, similar to the
AttrSizedOperandSegments trait (but with a user-defined attribute name).
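To make the segment attribute concrete, here is a minimal standalone C++ sketch of how a flat operand list can be split back into its groups. It does not use MLIR: operands are modeled as plain strings, and the names `splitBySegments` and `verifySegmentSizes` are hypothetical. It only mirrors the idea behind `OperandRange::split` and, roughly, the checks `verifyValueSizeAttr` applies (no negative sizes, and the sizes must add up to the operand count).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins: an "operand" is just its SSA name, and a group is
// a copy of the operands belonging to one sub-range.
using Operand = std::string;
using OperandGroup = std::vector<Operand>;

// Sizes are valid if none is negative and they add up to the number of flat
// operands (roughly the checks applied to the segment attribute).
bool verifySegmentSizes(const std::vector<int32_t> &sizes,
                        std::size_t operandCount) {
  std::size_t total = 0;
  for (int32_t size : sizes) {
    if (size < 0)
      return false;
    total += static_cast<std::size_t>(size);
  }
  return total == operandCount;
}

// Splits the flat operand list into one contiguous group per segment size,
// in the spirit of OperandRange::split. Returns std::nullopt on bad sizes.
std::optional<std::vector<OperandGroup>>
splitBySegments(const std::vector<Operand> &operands,
                const std::vector<int32_t> &sizes) {
  if (!verifySegmentSizes(sizes, operands.size()))
    return std::nullopt;
  std::vector<OperandGroup> groups;
  groups.reserve(sizes.size());
  std::size_t begin = 0;
  for (int32_t size : sizes) {
    auto first = operands.begin() + static_cast<std::ptrdiff_t>(begin);
    groups.emplace_back(first, first + size);
    begin += static_cast<std::size_t>(size);
  }
  // Verification guarantees the groups cover every operand exactly once.
  assert(begin == operands.size());
  return groups;
}
```

For the SwitchOp example, cases with operands `(%a, %b)`, `()` and `(%c)` would keep their case operands as the flat list `%a, %b, %c` with segment sizes `[2, 0, 1]`; an empty case is simply a zero-sized segment.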
`build` methods for VariadicOfVariadic operands expect inputs of the
form `ArrayRef<ValueRange>`. Accessors for these operand groups
return a new `OperandRangeRange` type, which represents a
contiguous range of `OperandRange`s. In the declarative assembly
format, VariadicOfVariadic operands and types are by default
formatted as a comma-delimited list of value lists:
`(<value>, <value>), (), (<value>)`.
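The reverse direction can be sketched the same way. The standalone C++ below (again with strings standing in for values, and hypothetical names `flattenGroups` and `printGroups`) shows what a builder taking nested groups, analogous to `ArrayRef<ValueRange>`, conceptually has to produce (a flat operand list plus the segment sizes), and how the groups look in the default comma-delimited format. It illustrates the data layout only; it is not the code that mlir-tblgen generates.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-ins for values and for one nested operand group.
using Operand = std::string;
using OperandGroup = std::vector<Operand>;

// What a builder receiving nested groups has to record: the operands in one
// flat list, plus one size per group for the segment attribute.
struct FlattenedOperands {
  std::vector<Operand> operands;
  std::vector<int32_t> segmentSizes;
};

FlattenedOperands flattenGroups(const std::vector<OperandGroup> &groups) {
  FlattenedOperands result;
  for (const OperandGroup &group : groups) {
    result.operands.insert(result.operands.end(), group.begin(), group.end());
    result.segmentSizes.push_back(static_cast<int32_t>(group.size()));
  }
  // One segment per group, even when a group is empty.
  assert(result.segmentSizes.size() == groups.size());
  return result;
}

// Prints the groups as a comma-delimited list of parenthesized value lists,
// matching the shape "(<value>, <value>), (), (<value>)".
std::string printGroups(const std::vector<OperandGroup> &groups) {
  std::string out;
  for (std::size_t i = 0; i < groups.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += '(';
    for (std::size_t j = 0; j < groups[i].size(); ++j) {
      if (j != 0)
        out += ", ";
      out += groups[i][j];
    }
    out += ')';
  }
  return out;
}
```

With the groups `{%a, %b}`, `{}` and `{%c}`, this yields the operands `%a, %b, %c`, the segment sizes `[2, 0, 1]`, and the printed form `(%a, %b), (), (%c)`. Recording sizes rather than start offsets means an empty group is represented explicitly and no group depends on the total operand count to find its end.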
Differential Revision: https://reviews.llvm.org/D107774
Added:
Modified:
mlir/docs/OpDefinitions.md
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/TypeRange.h
mlir/include/mlir/TableGen/Argument.h
mlir/include/mlir/TableGen/Type.h
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/OperationSupport.cpp
mlir/lib/TableGen/Argument.cpp
mlir/lib/TableGen/Operator.cpp
mlir/lib/TableGen/Type.cpp
mlir/test/IR/traits.mlir
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-format.mlir
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 0e85d32916c64..52c66b94e5e0b 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -229,6 +229,17 @@ the `SameVariadicOperandSize` or `AttrSizedOperandSegments` trait is needed to
indicate that all variable length operands have the same number of dynamic
values.
+#### VariadicOfVariadic operands
+
+To declare a variadic operand that has a variadic number of sub-ranges, wrap the
+`TypeConstraint` for the operand with `VariadicOfVariadic<...,
+"<segment-attribute-name>">`.
+
+The second field of the `VariadicOfVariadic` is the name of an `I32ElementsAttr`
+argument that contains the sizes of the variadic sub-ranges. This attribute will
+be used when determining the size of sub-ranges, or when updating the size of
+sub-ranges.
+
#### Optional operands
To declare an optional operand, wrap the `TypeConstraint` for the operand with
@@ -717,6 +728,8 @@ declarative parameter to `parse` method argument is detailed below:
- Single: `OpAsmParser::OperandType &`
- Optional: `Optional<OpAsmParser::OperandType> &`
- Variadic: `SmallVectorImpl<OpAsmParser::OperandType> &`
+ - VariadicOfVariadic:
+ `SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &`
* Ref Directives
- A reference directive is passed to the parser using the same mapping as
the input operand. For example, a single region would be passed as a
@@ -731,6 +744,7 @@ declarative parameter to `parse` method argument is detailed below:
- Single: `Type &`
- Optional: `Type &`
- Variadic: `SmallVectorImpl<Type> &`
+ - VariadicOfVariadic: `SmallVectorImpl<SmallVector<Type>> &`
* `attr-dict` Directive: `NamedAttrList &`
When a variable is optional, the value should only be specified if the variable
@@ -749,6 +763,7 @@ declarative parameter to `print` method argument is detailed below:
- Single: `Value`
- Optional: `Value`
- Variadic: `OperandRange`
+ - VariadicOfVariadic: `OperandRangeRange`
* Ref Directives
- A reference directive is passed to the printer using the same mapping as
the input operand. For example, a single region would be passed as a
@@ -763,6 +778,7 @@ declarative parameter to `print` method argument is detailed below:
- Single: `Type`
- Optional: `Type`
- Variadic: `TypeRange`
+ - VariadicOfVariadic: `TypeRangeRange`
* `attr-dict` Directive: `DictionaryAttr`
When a variable is optional, the provided value may be null.
@@ -923,7 +939,7 @@ be defined.
When this boolean field is set to `true`, it indicates that the op implements a
`canonicalize` method for simple "matchAndRewrite" style canonicalization
-patterns. If `hasCanonicalizer` is 0, then an implementation of
+patterns. If `hasCanonicalizer` is 0, then an implementation of
`::getCanonicalizationPatterns()` is implemented to call this function.
### `hasFolder`
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 6d09c50c92a8c..2b05ae1858146 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -701,23 +701,25 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
[AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
NoSideEffect]> {
- let arguments = (ins I32:$value,
- Variadic<AnyType>:$defaultOperands,
- Variadic<AnyType>:$caseOperands,
- OptionalAttr<ElementsAttr>:$case_values,
- OptionalAttr<ElementsAttr>:$case_operand_offsets,
- OptionalAttr<ElementsAttr>:$branch_weights);
+ let arguments = (ins
+ I32:$value,
+ Variadic<AnyType>:$defaultOperands,
+ VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
+ OptionalAttr<ElementsAttr>:$case_values,
+ ElementsAttr:$case_operand_segments,
+ OptionalAttr<ElementsAttr>:$branch_weights
+ );
let successors = (successor
- AnySuccessor:$defaultDestination,
- VariadicSuccessor<AnySuccessor>:$caseDestinations);
+ AnySuccessor:$defaultDestination,
+ VariadicSuccessor<AnySuccessor>:$caseDestinations
+ );
let verifier = [{ return ::verify(*this); }];
let assemblyFormat = [{
$value `,`
$defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
`[` `\n` custom<SwitchOpCases>($case_values, $caseDestinations,
- $caseOperands, type($caseOperands),
- $case_operand_offsets) `]`
+ $caseOperands, type($caseOperands)) `]`
attr-dict
}];
@@ -734,11 +736,15 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
let extraClassDeclaration = [{
/// Return the operands for the case destination block at the given index.
- OperandRange getCaseOperands(unsigned index);
+ OperandRange getCaseOperands(unsigned index) {
+ return caseOperands()[index];
+ }
/// Return a mutable range of operands for the case destination block at the
/// given index.
- MutableOperandRange getCaseOperandsMutable(unsigned index);
+ MutableOperandRange getCaseOperandsMutable(unsigned index) {
+ return caseOperandsMutable()[index];
+ }
}];
}
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 7f54dc5da2f88..b715070706e37 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1812,14 +1812,17 @@ def SwitchOp : Std_Op<"switch",
```
}];
- let arguments = (ins AnyInteger:$flag,
- Variadic<AnyType>:$defaultOperands,
- Variadic<AnyType>:$caseOperands,
- OptionalAttr<AnyIntElementsAttr>:$case_values,
- OptionalAttr<I32ElementsAttr>:$case_operand_offsets);
+ let arguments = (ins
+ AnyInteger:$flag,
+ Variadic<AnyType>:$defaultOperands,
+ VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
+ OptionalAttr<AnyIntElementsAttr>:$case_values,
+ I32ElementsAttr:$case_operand_segments
+ );
let successors = (successor
- AnySuccessor:$defaultDestination,
- VariadicSuccessor<AnySuccessor>:$caseDestinations);
+ AnySuccessor:$defaultDestination,
+ VariadicSuccessor<AnySuccessor>:$caseDestinations
+ );
let builders = [
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
@@ -1849,19 +1852,22 @@ def SwitchOp : Std_Op<"switch",
$case_values,
$caseDestinations,
$caseOperands,
- type($caseOperands),
- $case_operand_offsets)
+ type($caseOperands))
`]`
attr-dict
}];
let extraClassDeclaration = [{
/// Return the operands for the case destination block at the given index.
- OperandRange getCaseOperands(unsigned index);
+ OperandRange getCaseOperands(unsigned index) {
+ return caseOperands()[index];
+ }
/// Return a mutable range of operands for the case destination block at the
/// given index.
- MutableOperandRange getCaseOperandsMutable(unsigned index);
+ MutableOperandRange getCaseOperandsMutable(unsigned index) {
+ return caseOperandsMutable()[index];
+ }
}];
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index c24c05b877cf3..59088aead5eba 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -324,6 +324,16 @@ class Variadic<Type type> : TypeConstraint<type.predicate, type.summary> {
Type baseType = type;
}
+// A nested variadic type constraint. It expands to zero or more variadic ranges
+// of the base type. This class is used for supporting variadic operands and
+// results. `variadicSegmentAttrName` should correspond to the name of an
+// I32ElementsAttr argument that provides the sizes of the inner variadic
+// operand groups.
+class VariadicOfVariadic<Type type, string variadicSegmentAttrName>
+ : Variadic<type> {
+ string segmentAttrName = variadicSegmentAttrName;
+}
+
// An optional type constraint. It expands to either zero or one of the base
// type. This class is used for supporting optional operands/results.
class Optional<Type type> : TypeConstraint<type.predicate, type.summary> {
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index e7a4794bd5f49..5972568b18eae 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -267,6 +267,9 @@ LogicalResult verifyZeroSuccessor(Operation *op);
LogicalResult verifyOneSuccessor(Operation *op);
LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
+LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
+ StringRef valueGroupName,
+ size_t expectedCount);
LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyNoRegionArguments(Operation *op);
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 111e2b8d0bc0c..0af719c92d911 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -36,12 +36,14 @@ namespace mlir {
class Dialect;
class DictionaryAttr;
class ElementsAttr;
+class MutableOperandRangeRange;
class Operation;
struct OperationState;
class OpAsmParser;
class OpAsmParserResult;
class OpAsmPrinter;
class OperandRange;
+class OperandRangeRange;
class OpFoldResult;
class ParseResult;
class Pattern;
@@ -727,6 +729,10 @@ class OperandRange final : public llvm::detail::indexed_accessor_range_base<
/// must not be empty.
unsigned getBeginOperandIndex() const;
+ /// Split this range into a set of contiguous subranges using the given
+ /// elements attribute, which contains the sizes of the sub ranges.
+ OperandRangeRange split(ElementsAttr segmentSizes) const;
+
private:
/// See `llvm::detail::indexed_accessor_range_base` for details.
static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) {
@@ -741,6 +747,42 @@ class OperandRange final : public llvm::detail::indexed_accessor_range_base<
friend RangeBaseT;
};
+//===----------------------------------------------------------------------===//
+// OperandRangeRange
+
+/// This class represents a contiguous range of operand ranges, e.g. from a
+/// VariadicOfVariadic operand group.
+class OperandRangeRange final
+ : public llvm::indexed_accessor_range<
+ OperandRangeRange, std::pair<OpOperand *, Attribute>, OperandRange,
+ OperandRange, OperandRange> {
+ using OwnerT = std::pair<OpOperand *, Attribute>;
+ using RangeBaseT =
+ llvm::indexed_accessor_range<OperandRangeRange, OwnerT, OperandRange,
+ OperandRange, OperandRange>;
+
+public:
+ using RangeBaseT::RangeBaseT;
+
+ /// Returns the range of types of the values within this range.
+ TypeRangeRange getTypes() const { return TypeRangeRange(*this); }
+ auto getType() const { return getTypes(); }
+
+ /// Construct a range given a parent set of operands, and an I32 elements
+ /// attribute containing the sizes of the sub ranges.
+ OperandRangeRange(OperandRange operands, Attribute operandSegments);
+
+ /// Flatten all of the sub ranges into a single contiguous operand range.
+ OperandRange join() const;
+
+private:
+ /// See `llvm::indexed_accessor_range` for details.
+ static OperandRange dereference(const OwnerT &object, ptrdiff_t index);
+
+ /// Allow access to `dereference_iterator`.
+ friend RangeBaseT;
+};
+
//===----------------------------------------------------------------------===//
// MutableOperandRange
@@ -761,8 +803,9 @@ class MutableOperandRange {
MutableOperandRange(Operation *owner);
/// Slice this range into a sub range, with the additional operand segment.
- MutableOperandRange slice(unsigned subStart, unsigned subLen,
- Optional<OperandSegment> segment = llvm::None);
+ MutableOperandRange
+ slice(unsigned subStart, unsigned subLen,
+ Optional<OperandSegment> segment = llvm::None) const;
/// Append the given values to the range.
void append(ValueRange values);
@@ -782,12 +825,19 @@ class MutableOperandRange {
/// Returns the current size of the range.
unsigned size() const { return length; }
+ /// Returns if the current range is empty.
+ bool empty() const { return size() == 0; }
+
/// Allow implicit conversion to an OperandRange.
operator OperandRange() const;
/// Returns the owning operation.
Operation *getOwner() const { return owner; }
+ /// Split this range into a set of contiguous subranges using the given
+ /// elements attribute, which contains the sizes of the sub ranges.
+ MutableOperandRangeRange split(NamedAttribute segmentSizes) const;
+
private:
/// Update the length of this range to the one provided.
void updateLength(unsigned newLength);
@@ -801,7 +851,46 @@ class MutableOperandRange {
/// Optional set of operand segments that should be updated when mutating the
/// length of this range.
- SmallVector<std::pair<unsigned, NamedAttribute>, 1> operandSegments;
+ SmallVector<OperandSegment, 1> operandSegments;
+};
+
+//===----------------------------------------------------------------------===//
+// MutableOperandRangeRange
+
+/// This class represents a contiguous range of mutable operand ranges, e.g.
+/// from a VariadicOfVariadic operand group.
+class MutableOperandRangeRange final
+ : public llvm::indexed_accessor_range<
+ MutableOperandRangeRange,
+ std::pair<MutableOperandRange, NamedAttribute>, MutableOperandRange,
+ MutableOperandRange, MutableOperandRange> {
+ using OwnerT = std::pair<MutableOperandRange, NamedAttribute>;
+ using RangeBaseT =
+ llvm::indexed_accessor_range<MutableOperandRangeRange, OwnerT,
+ MutableOperandRange, MutableOperandRange,
+ MutableOperandRange>;
+
+public:
+ using RangeBaseT::RangeBaseT;
+
+ /// Construct a range given a parent set of operands, and an I32 tensor
+ /// elements attribute containing the sizes of the sub ranges.
+ MutableOperandRangeRange(const MutableOperandRange &operands,
+ NamedAttribute operandSegmentAttr);
+
+ /// Flatten all of the sub ranges into a single contiguous mutable operand
+ /// range.
+ MutableOperandRange join() const;
+
+ /// Allow implicit conversion to an OperandRangeRange.
+ operator OperandRangeRange() const;
+
+private:
+ /// See `llvm::indexed_accessor_range` for details.
+ static MutableOperandRange dereference(const OwnerT &object, ptrdiff_t index);
+
+ /// Allow access to `dereference_iterator`.
+ friend RangeBaseT;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h
index 4fb40e127f9fa..952a7d71211d1 100644
--- a/mlir/include/mlir/IR/TypeRange.h
+++ b/mlir/include/mlir/IR/TypeRange.h
@@ -16,6 +16,7 @@
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/Sequence.h"
namespace mlir {
class OperandRange;
@@ -88,6 +89,35 @@ inline raw_ostream &operator<<(raw_ostream &os, const TypeRange &types) {
return os;
}
+//===----------------------------------------------------------------------===//
+// TypeRangeRange
+
+using TypeRangeRangeIterator =
+ llvm::mapped_iterator<llvm::iota_range<unsigned>::iterator,
+ std::function<TypeRange(unsigned)>>;
+
+/// This class provides an abstraction for a range of TypeRange. This is useful
+/// when accessing the types of a range of ranges, such as when using
+/// OperandRangeRange.
+class TypeRangeRange : public llvm::iterator_range<TypeRangeRangeIterator> {
+public:
+ template <typename RangeT>
+ TypeRangeRange(const RangeT &range)
+ : TypeRangeRange(llvm::seq<unsigned>(0, range.size()), range) {}
+
+private:
+ template <typename RangeT>
+ TypeRangeRange(llvm::iota_range<unsigned> sizeRange, const RangeT &range)
+ : llvm::iterator_range<TypeRangeRangeIterator>(
+ {sizeRange.begin(), getRangeFn(range)},
+ {sizeRange.end(), nullptr}) {}
+
+ template <typename RangeT>
+ static std::function<TypeRange(unsigned)> getRangeFn(const RangeT &range) {
+ return [=](unsigned index) -> TypeRange { return TypeRange(range[index]); };
+ }
+};
+
//===----------------------------------------------------------------------===//
// ValueTypeRange
diff --git a/mlir/include/mlir/TableGen/Argument.h b/mlir/include/mlir/TableGen/Argument.h
index 0eb4d8ce41983..1d89a76c2924c 100644
--- a/mlir/include/mlir/TableGen/Argument.h
+++ b/mlir/include/mlir/TableGen/Argument.h
@@ -48,6 +48,8 @@ struct NamedTypeConstraint {
bool isOptional() const;
// Returns true if this operand/result is variadic.
bool isVariadic() const;
+ // Returns true if this operand/result is a variadic of a variadic constraint.
+ bool isVariadicOfVariadic() const;
// Returns true if this is a variable length type constraint. This is either
// variadic or optional.
bool isVariableLength() const { return isOptional() || isVariadic(); }
diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h
index 6af6d05076a2f..c996adabdcff1 100644
--- a/mlir/include/mlir/TableGen/Type.h
+++ b/mlir/include/mlir/TableGen/Type.h
@@ -40,6 +40,13 @@ class TypeConstraint : public Constraint {
// Returns true if this is a variadic type constraint.
bool isVariadic() const;
+ // Returns true if this is a nested variadic type constraint.
+ bool isVariadicOfVariadic() const;
+
+ // Return the segment size attribute used if this is a variadic of variadic
+ // constraint. Asserts isVariadicOfVariadic() is true.
+ StringRef getVariadicOfVariadicSegmentSizeAttr() const;
+
// Returns true if this is a variable length type constraint. This is either
// variadic or optional.
bool isVariableLength() const { return isOptional() || isVariadic(); }
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 54178cc1bec08..8626efdfe4a97 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -520,7 +520,7 @@ class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
/*defaultOperands=*/ValueRange(),
/*caseValues=*/caseValues,
/*caseDestinations=*/caseDest,
- /*caseOperands=*/ArrayRef<ValueRange>(),
+ /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
/*branchWeights=*/ArrayRef<int32_t>());
return success();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index bb71eff459ae5..ca66e7f08af97 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -32,6 +32,7 @@
#include "llvm/Support/SourceMgr.h"
#include <iostream>
+#include <numeric>
using namespace mlir;
using namespace mlir::LLVM;
@@ -235,41 +236,27 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
ArrayRef<ValueRange> caseOperands,
ArrayRef<int32_t> branchWeights) {
- SmallVector<Value> flattenedCaseOperands;
- SmallVector<int32_t> caseOperandOffsets;
- int32_t offset = 0;
- for (ValueRange operands : caseOperands) {
- flattenedCaseOperands.append(operands.begin(), operands.end());
- caseOperandOffsets.push_back(offset);
- offset += operands.size();
- }
ElementsAttr caseValuesAttr;
if (!caseValues.empty())
caseValuesAttr = builder.getI32VectorAttr(caseValues);
- ElementsAttr caseOperandOffsetsAttr;
- if (!caseOperandOffsets.empty())
- caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
ElementsAttr weightsAttr;
if (!branchWeights.empty())
weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights));
- build(builder, result, value, defaultOperands, flattenedCaseOperands,
- caseValuesAttr, caseOperandOffsetsAttr, weightsAttr, defaultDestination,
- caseDestinations);
+ build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr,
+ weightsAttr, defaultDestination, caseDestinations);
}
/// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )?
-static ParseResult
-parseSwitchOpCases(OpAsmParser &parser, ElementsAttr &caseValues,
- SmallVectorImpl<Block *> &caseDestinations,
- SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
- SmallVectorImpl<Type> &caseOperandTypes,
- ElementsAttr &caseOperandOffsets) {
+static ParseResult parseSwitchOpCases(
+ OpAsmParser &parser, ElementsAttr &caseValues,
+ SmallVectorImpl<Block *> &caseDestinations,
+ SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
+ SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
SmallVector<int32_t> values;
- SmallVector<int32_t> offsets;
- int32_t value, offset = 0;
+ int32_t value = 0;
do {
OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
if (values.empty() && !integerParseResult.hasValue())
@@ -281,32 +268,28 @@ parseSwitchOpCases(OpAsmParser &parser, ElementsAttr &caseValues,
Block *destination;
SmallVector<OpAsmParser::OperandType> operands;
+ SmallVector<Type> operandTypes;
if (parser.parseColon() || parser.parseSuccessor(destination))
return failure();
if (!parser.parseOptionalLParen()) {
if (parser.parseRegionArgumentList(operands) ||
- parser.parseColonTypeList(caseOperandTypes) || parser.parseRParen())
+ parser.parseColonTypeList(operandTypes) || parser.parseRParen())
return failure();
}
caseDestinations.push_back(destination);
- caseOperands.append(operands.begin(), operands.end());
- offsets.push_back(offset);
- offset += operands.size();
+ caseOperands.emplace_back(operands);
+ caseOperandTypes.emplace_back(operandTypes);
} while (!parser.parseOptionalComma());
- Builder &builder = parser.getBuilder();
- caseValues = builder.getI32VectorAttr(values);
- caseOperandOffsets = builder.getI32VectorAttr(offsets);
-
+ caseValues = parser.getBuilder().getI32VectorAttr(values);
return success();
}
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op,
ElementsAttr caseValues,
SuccessorRange caseDestinations,
- OperandRange caseOperands,
- TypeRange caseOperandTypes,
- ElementsAttr caseOperandOffsets) {
+ OperandRangeRange caseOperands,
+ TypeRangeRange caseOperandTypes) {
if (!caseValues)
return;
@@ -317,7 +300,7 @@ static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op,
p << " ";
p << std::get<0>(i).getLimitedValue();
p << ": ";
- p.printSuccessorAndUseList(std::get<1>(i), op.getCaseOperands(index++));
+ p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
},
[&] {
p << ',';
@@ -341,28 +324,6 @@ static LogicalResult verify(SwitchOp op) {
return success();
}
-OperandRange SwitchOp::getCaseOperands(unsigned index) {
- return getCaseOperandsMutable(index);
-}
-
-MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
- MutableOperandRange caseOperands = caseOperandsMutable();
- if (!case_operand_offsets()) {
- assert(caseOperands.size() == 0 &&
- "non-empty case operands must have offsets");
- return caseOperands;
- }
-
- ElementsAttr offsets = case_operand_offsets().getValue();
- assert(index < offsets.size() && "invalid case operand offset index");
-
- int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
- int64_t end = index + 1 == offsets.size()
- ? caseOperands.size()
- : offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
- return caseOperandsMutable().slice(begin, end - begin);
-}
-
Optional<MutableOperandRange>
SwitchOp::getMutableSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 0c4b1441dd4cd..a84d8a513a556 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -28,6 +28,7 @@
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
+#include <numeric>
#include "mlir/Dialect/StandardOps/IR/OpsDialect.cpp.inc"
@@ -2130,21 +2131,8 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
DenseIntElementsAttr caseValues,
BlockRange caseDestinations,
ArrayRef<ValueRange> caseOperands) {
- SmallVector<Value> flattenedCaseOperands;
- SmallVector<int32_t> caseOperandOffsets;
- int32_t offset = 0;
- for (ValueRange operands : caseOperands) {
- flattenedCaseOperands.append(operands.begin(), operands.end());
- caseOperandOffsets.push_back(offset);
- offset += operands.size();
- }
- DenseIntElementsAttr caseOperandOffsetsAttr;
- if (!caseOperandOffsets.empty())
- caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
-
- build(builder, result, value, defaultOperands, flattenedCaseOperands,
- caseValues, caseOperandOffsetsAttr, defaultDestination,
- caseDestinations);
+ build(builder, result, value, defaultOperands, caseOperands, caseValues,
+ defaultDestination, caseDestinations);
}
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
@@ -2163,16 +2151,14 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
-static ParseResult
-parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
- Block *&defaultDestination,
- SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
- SmallVectorImpl<Type> &defaultOperandTypes,
- DenseIntElementsAttr &caseValues,
- SmallVectorImpl<Block *> &caseDestinations,
- SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
- SmallVectorImpl<Type> &caseOperandTypes,
- DenseIntElementsAttr &caseOperandOffsets) {
+static ParseResult parseSwitchOpCases(
+ OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
+ SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
+ SmallVectorImpl<Type> &defaultOperandTypes,
+ DenseIntElementsAttr &caseValues,
+ SmallVectorImpl<Block *> &caseDestinations,
+ SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
+ SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
failed(parser.parseSuccessor(defaultDestination)))
return failure();
@@ -2184,9 +2170,7 @@ parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
}
SmallVector<APInt> values;
- SmallVector<int32_t> offsets;
unsigned bitWidth = flagType.getIntOrFloatBitWidth();
- int64_t offset = 0;
while (succeeded(parser.parseOptionalComma())) {
int64_t value = 0;
if (failed(parser.parseInteger(value)))
@@ -2195,30 +2179,26 @@ parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
Block *destination;
SmallVector<OpAsmParser::OperandType> operands;
+ SmallVector<Type> operandTypes;
if (failed(parser.parseColon()) ||
failed(parser.parseSuccessor(destination)))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
if (failed(parser.parseRegionArgumentList(operands)) ||
- failed(parser.parseColonTypeList(caseOperandTypes)) ||
+ failed(parser.parseColonTypeList(operandTypes)) ||
failed(parser.parseRParen()))
return failure();
}
caseDestinations.push_back(destination);
- caseOperands.append(operands.begin(), operands.end());
- offsets.push_back(offset);
- offset += operands.size();
+ caseOperands.emplace_back(operands);
+ caseOperandTypes.emplace_back(operandTypes);
}
- if (values.empty())
- return success();
-
- Builder &builder = parser.getBuilder();
- ShapedType caseValueType =
- VectorType::get(static_cast<int64_t>(values.size()), flagType);
- caseValues = DenseIntElementsAttr::get(caseValueType, values);
- caseOperandOffsets = builder.getI32VectorAttr(offsets);
-
+ if (!values.empty()) {
+ ShapedType caseValueType =
+ VectorType::get(static_cast<int64_t>(values.size()), flagType);
+ caseValues = DenseIntElementsAttr::get(caseValueType, values);
+ }
return success();
}
@@ -2226,8 +2206,7 @@ static void printSwitchOpCases(
OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
OperandRange defaultOperands, TypeRange defaultOperandTypes,
DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
- OperandRange caseOperands, TypeRange caseOperandTypes,
- ElementsAttr caseOperandOffsets) {
+ OperandRangeRange caseOperands, TypeRangeRange caseOperandTypes) {
p << " default: ";
p.printSuccessorAndUseList(defaultDestination, defaultOperands);
@@ -2240,7 +2219,7 @@ static void printSwitchOpCases(
p << " ";
p << caseValues.getValue<APInt>(i).getLimitedValue();
p << ": ";
- p.printSuccessorAndUseList(caseDestinations[i], op.getCaseOperands(i));
+ p.printSuccessorAndUseList(caseDestinations[i], caseOperands[i]);
}
p.printNewline();
}
@@ -2268,28 +2247,6 @@ static LogicalResult verify(SwitchOp op) {
return success();
}
-OperandRange SwitchOp::getCaseOperands(unsigned index) {
- return getCaseOperandsMutable(index);
-}
-
-MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
- MutableOperandRange caseOperands = caseOperandsMutable();
- if (!case_operand_offsets()) {
- assert(caseOperands.size() == 0 &&
- "non-empty case operands must have offsets");
- return caseOperands;
- }
-
- ElementsAttr offsets = case_operand_offsets().getValue();
- assert(index < offsets.size() && "invalid case operand offset index");
-
- int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
- int64_t end = index + 1 == offsets.size()
- ? caseOperands.size()
- : offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
- return caseOperandsMutable().slice(begin, end - begin);
-}
-
Optional<MutableOperandRange>
SwitchOp::getMutableSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 749a115f4491a..d5b8f1c2903e3 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -996,16 +996,19 @@ OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) {
return success();
}
-static LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
- bool isOperand) {
+LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op,
+ StringRef attrName,
+ StringRef valueGroupName,
+ size_t expectedCount) {
auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName);
if (!sizeAttr)
- return op->emitOpError("requires 1D vector attribute '") << attrName << "'";
+ return op->emitOpError("requires 1D i32 elements attribute '")
+ << attrName << "'";
- auto sizeAttrType = sizeAttr.getType().dyn_cast<VectorType>();
- if (!sizeAttrType || sizeAttrType.getRank() != 1 ||
+ auto sizeAttrType = sizeAttr.getType();
+ if (sizeAttrType.getRank() != 1 ||
!sizeAttrType.getElementType().isInteger(32))
- return op->emitOpError("requires 1D vector of i32 attribute '")
+ return op->emitOpError("requires 1D i32 elements attribute '")
<< attrName << "'";
if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) {
@@ -1018,25 +1021,22 @@ static LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
sizeAttr.begin(), sizeAttr.end(), 0,
[](unsigned all, APInt one) { return all + one.getZExtValue(); });
- if (isOperand && totalCount != op->getNumOperands())
- return op->emitOpError("operand count (")
- << op->getNumOperands() << ") does not match with the total size ("
- << totalCount << ") specified in attribute '" << attrName << "'";
- else if (!isOperand && totalCount != op->getNumResults())
- return op->emitOpError("result count (")
- << op->getNumResults() << ") does not match with the total size ("
- << totalCount << ") specified in attribute '" << attrName << "'";
+ if (totalCount != expectedCount)
+ return op->emitOpError()
+ << valueGroupName << " count (" << expectedCount
+ << ") does not match with the total size (" << totalCount
+ << ") specified in attribute '" << attrName << "'";
return success();
}
LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op,
StringRef attrName) {
- return verifyValueSizeAttr(op, attrName, /*isOperand=*/true);
+ return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands());
}
LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op,
StringRef attrName) {
- return verifyValueSizeAttr(op, attrName, /*isOperand=*/false);
+ return verifyValueSizeAttr(op, attrName, "result", op->getNumResults());
}
LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {
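After this refactoring, `verifyValueSizeAttr` boils down to one rule: the segment sizes must be non-negative and must sum to the expected number of values (`getNumOperands()`, `getNumResults()`, or, for a VariadicOfVariadic operand, the size of that operand's group). The sketch below models that rule outside of MLIR; `checkSegmentSizes` is a hypothetical name, a `std::vector<int32_t>` stands in for the 1-D i32 `DenseIntElementsAttr`, and because the body of the `any_of` check is elided from this hunk, the negative-size message is a placeholder rather than MLIR's actual wording.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <optional>
#include <string>
#include <vector>

// Standalone model of the segment-size check (NOT the MLIR API).
// Returns an error message on failure, std::nullopt on success.
std::optional<std::string>
checkSegmentSizes(const std::vector<int32_t> &sizes,
                  const std::string &attrName,
                  const std::string &valueGroupName, size_t expectedCount) {
  // Placeholder diagnostic: every segment must be non-negative.
  for (int32_t size : sizes)
    if (size < 0)
      return "'" + attrName + "' has a negative segment size";

  // The segments must exactly cover the expected number of values.
  uint64_t totalCount =
      std::accumulate(sizes.begin(), sizes.end(), uint64_t(0));
  if (totalCount != expectedCount)
    return valueGroupName + " count (" + std::to_string(expectedCount) +
           ") does not match with the total size (" +
           std::to_string(totalCount) + ") specified in attribute '" +
           attrName + "'";
  return std::nullopt;
}
```

For example, sizes `{2, 0, 1}` pass against an expected count of 3 and fail against 4 with "operand count (4) does not match with the total size (3) ...", which is the shape of the unified diagnostic above.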
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index bb9a5603f8714..ca4debe21f8e1 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -12,9 +12,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/BitVector.h"
+#include <numeric>
using namespace mlir;
@@ -394,13 +396,38 @@ MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
OperandRange::OperandRange(Operation *op)
: OperandRange(op->getOpOperands().data(), op->getNumOperands()) {}
-/// Return the operand index of the first element of this range. The range
-/// must not be empty.
unsigned OperandRange::getBeginOperandIndex() const {
assert(!empty() && "range must not be empty");
return base->getOperandNumber();
}
+OperandRangeRange OperandRange::split(ElementsAttr segmentSizes) const {
+ return OperandRangeRange(*this, segmentSizes);
+}
+
+//===----------------------------------------------------------------------===//
+// OperandRangeRange
+
+OperandRangeRange::OperandRangeRange(OperandRange operands,
+ Attribute operandSegments)
+ : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
+ operandSegments.cast<DenseElementsAttr>().size()) {}
+
+OperandRange OperandRangeRange::join() const {
+ const OwnerT &owner = getBase();
+ auto sizeData = owner.second.cast<DenseElementsAttr>().getValues<uint32_t>();
+ return OperandRange(owner.first,
+ std::accumulate(sizeData.begin(), sizeData.end(), 0));
+}
+
+OperandRange OperandRangeRange::dereference(const OwnerT &object,
+ ptrdiff_t index) {
+ auto sizeData = object.second.cast<DenseElementsAttr>().getValues<uint32_t>();
+ uint32_t startIndex =
+ std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
+ return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
+}
+
//===----------------------------------------------------------------------===//
// MutableOperandRange
@@ -419,7 +446,7 @@ MutableOperandRange::MutableOperandRange(Operation *owner)
/// Slice this range into a sub range, with the additional operand segment.
MutableOperandRange
MutableOperandRange::slice(unsigned subStart, unsigned subLen,
- Optional<OperandSegment> segment) {
+ Optional<OperandSegment> segment) const {
assert((subStart + subLen) <= length && "invalid sub-range");
MutableOperandRange subSlice(owner, start + subStart, subLen,
operandSegments);
@@ -475,6 +502,11 @@ MutableOperandRange::operator OperandRange() const {
return owner->getOperands().slice(start, length);
}
+MutableOperandRangeRange
+MutableOperandRange::split(NamedAttribute segmentSizes) const {
+ return MutableOperandRangeRange(*this, segmentSizes);
+}
+
/// Update the length of this range to the one provided.
void MutableOperandRange::updateLength(unsigned newLength) {
int32_t diff = int32_t(newLength) - int32_t(length);
@@ -490,6 +522,35 @@ void MutableOperandRange::updateLength(unsigned newLength) {
}
}
+//===----------------------------------------------------------------------===//
+// MutableOperandRangeRange
+
+MutableOperandRangeRange::MutableOperandRangeRange(
+ const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
+ : MutableOperandRangeRange(
+ OwnerT(operands, operandSegmentAttr), 0,
+ operandSegmentAttr.second.cast<DenseElementsAttr>().size()) {}
+
+MutableOperandRange MutableOperandRangeRange::join() const {
+ return getBase().first;
+}
+
+MutableOperandRangeRange::operator OperandRangeRange() const {
+ return OperandRangeRange(getBase().first,
+ getBase().second.second.cast<DenseElementsAttr>());
+}
+
+MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
+ ptrdiff_t index) {
+ auto sizeData =
+ object.second.second.cast<DenseElementsAttr>().getValues<uint32_t>();
+ uint32_t startIndex =
+ std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
+ return object.first.slice(
+ startIndex, *(sizeData.begin() + index),
+ MutableOperandRange::OperandSegment(index, object.second));
+}
+
//===----------------------------------------------------------------------===//
// ValueRange
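`OperandRangeRange::dereference` finds group `index` by summing the sizes of all earlier groups to get its start offset, and `join` sums every size to get one range over all groups. The following is a rough, MLIR-free sketch of that arithmetic; `RangeView`, `groupAt`, and `joinAll` are hypothetical names, and the `(pointer, length)` view only approximates what `OperandRange` holds.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical stand-in for OperandRange: a non-owning (pointer, length) view.
template <typename T> struct RangeView {
  const T *data;
  size_t size;
  const T &operator[](size_t i) const { return data[i]; }
};

// Model of dereference: group `index` starts after the sum of all preceding
// group sizes and has length sizes[index].
template <typename T>
RangeView<T> groupAt(const std::vector<T> &flat,
                     const std::vector<uint32_t> &sizes, size_t index) {
  assert(index < sizes.size() && "invalid group index");
  uint32_t start = std::accumulate(sizes.begin(), sizes.begin() + index, 0u);
  return {flat.data() + start, sizes[index]};
}

// Model of join: a single range spanning every group.
template <typename T>
RangeView<T> joinAll(const std::vector<T> &flat,
                     const std::vector<uint32_t> &sizes) {
  uint32_t total = std::accumulate(sizes.begin(), sizes.end(), 0u);
  return {flat.data(), total};
}
```

Because each lookup re-sums the prefix, accessing group `i` costs O(i), so visiting every group this way is quadratic in the number of groups. That is presumably acceptable given the small segment counts typical of ops like `switch`, and it avoids storing a separate offsets table.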
diff --git a/mlir/lib/TableGen/Argument.cpp b/mlir/lib/TableGen/Argument.cpp
index b724f2175bed2..c847760f8a467 100644
--- a/mlir/lib/TableGen/Argument.cpp
+++ b/mlir/lib/TableGen/Argument.cpp
@@ -12,6 +12,10 @@
using namespace mlir;
using namespace mlir::tblgen;
+//===----------------------------------------------------------------------===//
+// NamedTypeConstraint
+//===----------------------------------------------------------------------===//
+
bool NamedTypeConstraint::hasPredicate() const {
return !constraint.getPredicate().isNull();
}
@@ -19,3 +23,7 @@ bool NamedTypeConstraint::hasPredicate() const {
bool NamedTypeConstraint::isOptional() const { return constraint.isOptional(); }
bool NamedTypeConstraint::isVariadic() const { return constraint.isVariadic(); }
+
+bool NamedTypeConstraint::isVariadicOfVariadic() const {
+ return constraint.isVariadicOfVariadic();
+}
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index ea9513d4e6647..03e5170e52688 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -458,6 +458,13 @@ void Operator::populateOpStructure() {
results.push_back({name, TypeConstraint(resultDef)});
if (!name.empty())
argumentsAndResultsIndex[name] = resultIndex(i);
+
+ // We currently only support VariadicOfVariadic operands.
+ if (results.back().constraint.isVariadicOfVariadic()) {
+ PrintFatalError(
+ def.getLoc(),
+ "'VariadicOfVariadic' results are currently not supported");
+ }
}
// Handle successors
@@ -577,8 +584,7 @@ bool Operator::hasAssemblyFormat() const {
StringRef Operator::getAssemblyFormat() const {
return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat"))
- .Case<llvm::StringInit>(
- [&](auto *init) { return init->getValue(); });
+ .Case<llvm::StringInit>([&](auto *init) { return init->getValue(); });
}
void Operator::print(llvm::raw_ostream &os) const {
diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
index fd5a0f7058979..6691bb88c7680 100644
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -36,6 +36,15 @@ bool TypeConstraint::isVariadic() const {
return def->isSubClassOf("Variadic");
}
+bool TypeConstraint::isVariadicOfVariadic() const {
+ return def->isSubClassOf("VariadicOfVariadic");
+}
+
+StringRef TypeConstraint::getVariadicOfVariadicSegmentSizeAttr() const {
+ assert(isVariadicOfVariadic());
+ return def->getValueAsString("segmentAttrName");
+}
+
// Returns the builder call for this constraint if this is a buildable type,
// returns None otherwise.
Optional<StringRef> TypeConstraint::getBuilderCall() const {
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 3e1bb3bee0b39..2b5e314e5555f 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -375,28 +375,28 @@ func private @foo()
// -----
func @failedMissingOperandSizeAttr(%arg: i32) {
- // expected-error @+1 {{requires 1D vector attribute 'operand_segment_sizes'}}
+ // expected-error @+1 {{requires 1D i32 elements attribute 'operand_segment_sizes'}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) : (i32, i32, i32, i32) -> ()
}
// -----
func @failedOperandSizeAttrWrongType(%arg: i32) {
- // expected-error @+1 {{requires 1D vector of i32 attribute 'operand_segment_sizes'}}
- "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[1, 1, 1, 1]>: tensor<4xi32>} : (i32, i32, i32, i32) -> ()
+ // expected-error @+1 {{requires 1D i32 elements attribute 'operand_segment_sizes'}}
+ "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = 10} : (i32, i32, i32, i32) -> ()
}
// -----
func @failedOperandSizeAttrWrongRank(%arg: i32) {
- // expected-error @+1 {{requires 1D vector of i32 attribute 'operand_segment_sizes'}}
+ // expected-error @+1 {{requires 1D i32 elements attribute 'operand_segment_sizes'}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[[1, 1], [1, 1]]>: vector<2x2xi32>} : (i32, i32, i32, i32) -> ()
}
// -----
func @failedOperandSizeAttrWrongElementType(%arg: i32) {
- // expected-error @+1 {{requires 1D vector of i32 attribute 'operand_segment_sizes'}}
+ // expected-error @+1 {{requires 1D i32 elements attribute 'operand_segment_sizes'}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[1, 1, 1, 1]>: vector<4xi64>} : (i32, i32, i32, i32) -> ()
}
@@ -432,28 +432,28 @@ func @succeededOperandSizeAttr(%arg: i32) {
// -----
func @failedMissingResultSizeAttr() {
- // expected-error @+1 {{requires 1D vector attribute 'result_segment_sizes'}}
+ // expected-error @+1 {{requires 1D i32 elements attribute 'result_segment_sizes'}}
%0:4 = "test.attr_sized_results"() : () -> (i32, i32, i32, i32)
}
// -----
func @failedResultSizeAttrWrongType() {
- // expected-error @+1 {{requires 1D vector of i32 attribute 'result_segment_sizes'}}
- %0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[1, 1, 1, 1]>: tensor<4xi32>} : () -> (i32, i32, i32, i32)
+ // expected-error @+1 {{requires 1D i32 elements attribute 'result_segment_sizes'}}
+ %0:4 = "test.attr_sized_results"() {result_segment_sizes = 10} : () -> (i32, i32, i32, i32)
}
// -----
func @failedResultSizeAttrWrongRank() {
- // expected-error @+1 {{requires 1D vector of i32 attribute 'result_segment_sizes'}}
+ // expected-error @+1 {{requires 1D i32 elements attribute 'result_segment_sizes'}}
%0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[[1, 1], [1, 1]]>: vector<2x2xi32>} : () -> (i32, i32, i32, i32)
}
// -----
func @failedResultSizeAttrWrongElementType() {
- // expected-error @+1 {{requires 1D vector of i32 attribute 'result_segment_sizes'}}
+ // expected-error @+1 {{requires 1D i32 elements attribute 'result_segment_sizes'}}
%0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[1, 1, 1, 1]>: vector<4xi64>} : () -> (i32, i32, i32, i32)
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index add66b421f1f2..fbbc766839c40 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1661,6 +1661,14 @@ def FormatVariadicOperand : TEST_Op<"format_variadic_operand"> {
let arguments = (ins Variadic<I64>:$operand);
let assemblyFormat = [{ $operand `:` type($operand) attr-dict}];
}
+def FormatVariadicOfVariadicOperand
+ : TEST_Op<"format_variadic_of_variadic_operand"> {
+ let arguments = (ins
+ VariadicOfVariadic<I64, "operand_segments">:$operand,
+ I32ElementsAttr:$operand_segments
+ );
+ let assemblyFormat = [{ $operand `:` type($operand) attr-dict}];
+}
def FormatMultipleVariadicOperands :
TEST_Op<"format_multiple_variadic_operands", [AttrSizedOperandSegments]> {
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 759e7f5e3abdd..ccaedba466597 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -151,6 +151,9 @@ test.format_operand_e_op %i64, %memref : i64, memref<1xf64>
// CHECK: test.format_variadic_operand %[[I64]], %[[I64]], %[[I64]] : i64, i64, i64
test.format_variadic_operand %i64, %i64, %i64 : i64, i64, i64
+// CHECK: test.format_variadic_of_variadic_operand (%[[I64]], %[[I64]]), (), (%[[I64]]) : (i64, i64), (), (i64)
+test.format_variadic_of_variadic_operand (%i64, %i64), (), (%i64) : (i64, i64), (), (i64)
+
// CHECK: test.format_multiple_variadic_operands (%[[I64]], %[[I64]], %[[I64]]), (%[[I64]], %[[I32]] : i64, i32)
test.format_multiple_variadic_operands (%i64, %i64, %i64), (%i64, %i32 : i64, i32)
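The new test checks the default VariadicOfVariadic assembly form `(%a, %b), (), (%c)`: a comma-delimited list of parenthesized value lists. The sketch below is a hypothetical, text-only parser and printer for that shape, structured like the generated `do { parseOptionalLParen ... } while (parseOptionalComma())` loop. It records each group's size the same way, as the growth of the flat operand list. `GroupParser`, `ParsedGroups`, and `printGroups` are illustrative names, and names are simplified to any run of characters other than spaces, parentheses, and commas.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Flattened operands plus per-group sizes (what the segment attribute holds).
struct ParsedGroups {
  std::vector<std::string> operands;
  std::vector<int32_t> groupSizes;
};

class GroupParser {
public:
  explicit GroupParser(std::string text) : s(std::move(text)) {}

  std::optional<ParsedGroups> parse() {
    ParsedGroups result;
    int32_t curSize = 0;
    do {
      if (!consume('('))
        break; // No further group.
      if (!parseNameList(result.operands) || !consume(')'))
        return std::nullopt;
      // Group size = how much the flat list grew while parsing this group.
      result.groupSizes.push_back(int32_t(result.operands.size()) - curSize);
      curSize = int32_t(result.operands.size());
    } while (consume(','));
    skipSpaces();
    if (pos != s.size())
      return std::nullopt; // Unconsumed trailing input.
    return result;
  }

private:
  void skipSpaces() {
    while (pos < s.size() && s[pos] == ' ')
      ++pos;
  }
  bool consume(char c) {
    skipSpaces();
    if (pos < s.size() && s[pos] == c) {
      ++pos;
      return true;
    }
    return false;
  }
  std::string parseName() {
    skipSpaces();
    size_t start = pos;
    while (pos < s.size() && s[pos] != ' ' && s[pos] != '(' &&
           s[pos] != ')' && s[pos] != ',')
      ++pos;
    return s.substr(start, pos - start);
  }
  // Zero or more comma-separated names; an empty group is allowed.
  bool parseNameList(std::vector<std::string> &out) {
    std::string name = parseName();
    if (name.empty())
      return true;
    out.push_back(name);
    while (consume(',')) {
      name = parseName();
      if (name.empty())
        return false;
      out.push_back(name);
    }
    return true;
  }

  std::string s;
  size_t pos = 0;
};

// Print groups back in the same `(a, b), (), (c)` form.
std::string printGroups(const ParsedGroups &g) {
  std::string out;
  size_t idx = 0;
  for (size_t gi = 0; gi < g.groupSizes.size(); ++gi) {
    if (gi)
      out += ", ";
    out += "(";
    for (int32_t j = 0; j < g.groupSizes[gi]; ++j) {
      if (j)
        out += ", ";
      out += g.operands[idx++];
    }
    out += ")";
  }
  return out;
}
```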
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index d2b921fb5dd49..220f8438aebe3 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -24,6 +24,7 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@@ -89,6 +90,23 @@ const char *attrSizedSegmentValueRangeCalcCode = R"(
unsigned size = *(sizeAttrValues.begin() + index);
return {start, size};
)";
+// The logic to calculate the actual value range for a declared operand
+// of an op with variadic of variadic operands within the OpAdaptor.
+//
+// {0}: The name of the segment attribute.
+// {1}: The index of the main operand.
+const char *variadicOfVariadicAdaptorCalcCode = R"(
+ auto tblgenTmpOperands = getODSOperands({1});
+ auto sizeAttrValues = {0}().getValues<uint32_t>();
+ auto sizeAttrIt = sizeAttrValues.begin();
+
+ ::llvm::SmallVector<::mlir::ValueRange> tblgenTmpOperandGroups;
+ for (int i = 0, e = ::llvm::size(sizeAttrValues); i < e; ++i, ++sizeAttrIt) {{
+ tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(*sizeAttrIt));
+ tblgenTmpOperands = tblgenTmpOperands.drop_front(*sizeAttrIt);
+ }
+ return tblgenTmpOperandGroups;
+)";
// The logic to build a range of either operand or result values.
//
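The adaptor snippet `variadicOfVariadicAdaptorCalcCode` does not use prefix sums. It splits eagerly in one pass, peeling `take_front(size)` off the remaining operands and then `drop_front(size)` for each segment. A minimal sketch of that loop follows; `IntView` is a hypothetical stand-in that models only the two `ValueRange` operations the generated code uses, and `splitIntoGroups` is an illustrative name.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical non-owning view with ArrayRef-style take/drop helpers.
struct IntView {
  const int *data = nullptr;
  size_t size = 0;
  IntView takeFront(size_t n) const {
    assert(n <= size);
    return {data, n};
  }
  IntView dropFront(size_t n) const {
    assert(n <= size);
    return {data + n, size - n};
  }
};

// One pass over the segment sizes: each group is the next `size` elements.
std::vector<IntView> splitIntoGroups(IntView all,
                                     const std::vector<uint32_t> &sizes) {
  std::vector<IntView> groups;
  groups.reserve(sizes.size());
  for (uint32_t size : sizes) {
    groups.push_back(all.takeFront(size));
    all = all.dropFront(size);
  }
  return groups;
}
```

This costs O(number of groups) in total, but like the generated adaptor accessor it materializes a new vector of ranges on every call.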
@@ -256,16 +274,20 @@ class OpEmitter {
// Builds the parameter list for build() method of this op. This method writes
// to `paramList` the comma-separated parameter list and updates
// `resultTypeNames` with the names for parameters for specifying result
- // types. The given `typeParamKind` and `attrParamKind` controls how result
- // types and attributes are placed in the parameter list.
+ // types. `inferredAttributes` is populated with any attributes that are
+ // elided from the build list. The given `typeParamKind` and `attrParamKind`
+ // controls how result types and attributes are placed in the parameter list.
void buildParamList(llvm::SmallVectorImpl<OpMethodParameter> &paramList,
+ llvm::StringSet<> &inferredAttributes,
SmallVectorImpl<std::string> &resultTypeNames,
TypeParamKind typeParamKind,
AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
// Adds op arguments and regions into operation state for build() methods.
- void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
- bool isRawValueAttr = false);
+ void
+ genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
+ llvm::StringSet<> &inferredAttributes,
+ bool isRawValueAttr = false);
// Generates canonicalizer declaration for the operation.
void genCanonicalizerDecls();
@@ -783,7 +805,7 @@ generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
// of ops, in particular for one-operand ops that may not have the
// `getOperand(unsigned)` method.
static void generateNamedOperandGetters(const Operator &op, Class &opClass,
- StringRef sizeAttrInit,
+ bool isAdaptor, StringRef sizeAttrInit,
StringRef rangeType,
StringRef rangeBeginCall,
StringRef rangeSizeCall,
@@ -838,6 +860,20 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
m->body()
<< " auto operands = getODSOperands(" << i << ");\n"
<< " return operands.empty() ? ::mlir::Value() : *operands.begin();";
+ } else if (operand.isVariadicOfVariadic()) {
+ StringRef segmentAttr =
+ operand.constraint.getVariadicOfVariadicSegmentSizeAttr();
+ if (isAdaptor) {
+ m = opClass.addMethodAndPrune("::llvm::SmallVector<::mlir::ValueRange>",
+ operand.name);
+ m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
+ segmentAttr, i);
+ continue;
+ }
+
+ m = opClass.addMethodAndPrune("::mlir::OperandRangeRange", operand.name);
+ m->body() << " return getODSOperands(" << i << ").split(" << segmentAttr
+ << "Attr());";
} else if (operand.isVariadic()) {
m = opClass.addMethodAndPrune(rangeType, operand.name);
m->body() << " return getODSOperands(" << i << ");";
@@ -860,6 +896,7 @@ void OpEmitter::genNamedOperandGetters() {
generateNamedOperandGetters(
op, opClass,
+ /*isAdaptor=*/false,
/*sizeAttrInit=*/attrSizeInitCode,
/*rangeType=*/"::mlir::Operation::operand_range",
/*rangeBeginCall=*/"getOperation()->operand_begin()",
@@ -874,17 +911,32 @@ void OpEmitter::genNamedOperandSetters() {
const auto &operand = op.getOperand(i);
if (operand.name.empty())
continue;
- auto *m = opClass.addMethodAndPrune("::mlir::MutableOperandRange",
+ auto *m = opClass.addMethodAndPrune(operand.isVariadicOfVariadic()
+ ? "::mlir::MutableOperandRangeRange"
+ : "::mlir::MutableOperandRange",
(operand.name + "Mutable").str());
auto &body = m->body();
body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
- << " return ::mlir::MutableOperandRange(getOperation(), "
+ << " auto mutableRange = ::mlir::MutableOperandRange(getOperation(), "
"range.first, range.second";
if (attrSizedOperands)
body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
<< "u, *getOperation()->getAttrDictionary().getNamed("
"operand_segment_sizesAttrName()))";
body << ");\n";
+
+ // If this operand is a nested variadic, we split the range into a
+ // MutableOperandRangeRange that provides a range over all of the
+ // sub-ranges.
+ if (operand.isVariadicOfVariadic()) {
+ body << " return "
+ "mutableRange.split(*(*this)->getAttrDictionary().getNamed("
+ << operand.constraint.getVariadicOfVariadicSegmentSizeAttr()
+ << "AttrName()));\n";
+ } else {
+ // Otherwise, we use the full range directly.
+ body << " return mutableRange;\n";
+ }
}
}
@@ -1038,7 +1090,9 @@ void OpEmitter::genSeparateArgParamBuilder() {
bool inferType) {
llvm::SmallVector<OpMethodParameter, 4> paramList;
llvm::SmallVector<std::string, 4> resultNames;
- buildParamList(paramList, resultNames, paramKind, attrType);
+ llvm::StringSet<> inferredAttributes;
+ buildParamList(paramList, inferredAttributes, resultNames, paramKind,
+ attrType);
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
@@ -1046,8 +1100,9 @@ void OpEmitter::genSeparateArgParamBuilder() {
if (!m)
return;
auto &body = m->body();
- genCodeForAddingArgAndRegionForBuilder(
- body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue);
+ genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
+ /*isRawValueAttr=*/attrType ==
+ AttrParamKind::UnwrappedValue);
// Push all result types to the operation state
@@ -1215,7 +1270,9 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
llvm::SmallVector<OpMethodParameter, 4> paramList;
llvm::SmallVector<std::string, 4> resultNames;
- buildParamList(paramList, resultNames, TypeParamKind::None);
+ llvm::StringSet<> inferredAttributes;
+ buildParamList(paramList, inferredAttributes, resultNames,
+ TypeParamKind::None);
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
@@ -1223,7 +1280,7 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
if (!m)
return;
auto &body = m->body();
- genCodeForAddingArgAndRegionForBuilder(body);
+ genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes);
auto numResults = op.getNumResults();
if (numResults == 0)
@@ -1415,6 +1472,7 @@ void OpEmitter::genCollectiveParamBuilder() {
}
void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
+ llvm::StringSet<> &inferredAttributes,
SmallVectorImpl<std::string> &resultTypeNames,
TypeParamKind typeParamKind,
AttrParamKind attrParamKind) {
@@ -1453,10 +1511,6 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
}
// Add parameters for all arguments (operands and attributes).
-
- int numOperands = 0;
- int numAttrs = 0;
-
int defaultValuedAttrStartIndex = op.getNumArgs();
if (attrParamKind == AttrParamKind::UnwrappedValue) {
// Calculate the start index from which we can attach default values in the
@@ -1482,54 +1536,68 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
}
}
- for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
- auto argument = op.getArg(i);
- if (argument.is<tblgen::NamedTypeConstraint *>()) {
- const auto &operand = op.getOperand(numOperands);
- StringRef type =
- operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value";
- OpMethodParameter::Property properties = OpMethodParameter::PP_None;
- if (operand.isOptional())
- properties = OpMethodParameter::PP_Optional;
+ /// Collect any inferred attributes.
+ for (const NamedTypeConstraint &operand : op.getOperands()) {
+ if (operand.isVariadicOfVariadic()) {
+ inferredAttributes.insert(
+ operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
+ }
+ }
- paramList.emplace_back(type, getArgumentName(op, numOperands),
- properties);
- ++numOperands;
- } else {
- const auto &namedAttr = op.getAttribute(numAttrs);
- const auto &attr = namedAttr.attr;
+ for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
+ Argument arg = op.getArg(i);
+ if (const auto *operand = arg.dyn_cast<NamedTypeConstraint *>()) {
+ StringRef type;
+ if (operand->isVariadicOfVariadic())
+ type = "::llvm::ArrayRef<::mlir::ValueRange>";
+ else if (operand->isVariadic())
+ type = "::mlir::ValueRange";
+ else
+ type = "::mlir::Value";
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
- if (attr.isOptional())
+ if (operand->isOptional())
properties = OpMethodParameter::PP_Optional;
+ paramList.emplace_back(type, getArgumentName(op, numOperands++),
+ properties);
+ continue;
+ }
+ const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
+ const Attribute &attr = namedAttr.attr;
- StringRef type;
- switch (attrParamKind) {
- case AttrParamKind::WrappedAttr:
+ // inferred attributes don't need to be added to the param list.
+ if (inferredAttributes.contains(namedAttr.name))
+ continue;
+
+ OpMethodParameter::Property properties = OpMethodParameter::PP_None;
+ if (attr.isOptional())
+ properties = OpMethodParameter::PP_Optional;
+
+ StringRef type;
+ switch (attrParamKind) {
+ case AttrParamKind::WrappedAttr:
+ type = attr.getStorageType();
+ break;
+ case AttrParamKind::UnwrappedValue:
+ if (canUseUnwrappedRawValue(attr))
+ type = attr.getReturnType();
+ else
type = attr.getStorageType();
- break;
- case AttrParamKind::UnwrappedValue:
- if (canUseUnwrappedRawValue(attr))
- type = attr.getReturnType();
- else
- type = attr.getStorageType();
- break;
- }
+ break;
+ }
- std::string defaultValue;
- // Attach default value if requested and possible.
- if (attrParamKind == AttrParamKind::UnwrappedValue &&
- i >= defaultValuedAttrStartIndex) {
- bool isString = attr.getReturnType() == "::llvm::StringRef";
- if (isString)
- defaultValue.append("\"");
- defaultValue += attr.getDefaultValue();
- if (isString)
- defaultValue.append("\"");
- }
- paramList.emplace_back(type, namedAttr.name, defaultValue, properties);
- ++numAttrs;
+ // Attach default value if requested and possible.
+ std::string defaultValue;
+ if (attrParamKind == AttrParamKind::UnwrappedValue &&
+ i >= defaultValuedAttrStartIndex) {
+ bool isString = attr.getReturnType() == "::llvm::StringRef";
+ if (isString)
+ defaultValue.append("\"");
+ defaultValue += attr.getDefaultValue();
+ if (isString)
+ defaultValue.append("\"");
}
+ paramList.emplace_back(type, namedAttr.name, defaultValue, properties);
}
/// Insert parameters for each successor.
@@ -1546,12 +1614,31 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
llvm::formatv("{0}Count", region.name).str());
}
-void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
- bool isRawValueAttr) {
+void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
+ OpMethodBody &body, llvm::StringSet<> &inferredAttributes,
+ bool isRawValueAttr) {
// Push all operands to the result.
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
std::string argName = getArgumentName(op, i);
- if (op.getOperand(i).isOptional())
+ NamedTypeConstraint &operand = op.getOperand(i);
+ if (operand.constraint.isVariadicOfVariadic()) {
+ body << " for (::mlir::ValueRange range : " << argName << ")\n "
+ << builderOpState << ".addOperands(range);\n";
+
+ // Add the segment attribute.
+ body << " {\n"
+ << " SmallVector<int32_t> rangeSegments;\n"
+ << " for (::mlir::ValueRange range : " << argName << ")\n"
+ << " rangeSegments.push_back(range.size());\n"
+ << " " << builderOpState << ".addAttribute("
+ << operand.constraint.getVariadicOfVariadicSegmentSizeAttr()
+ << "AttrName(" << builderOpState << ".name), " << odsBuilder
+ << ".getI32TensorAttr(rangeSegments));"
+ << " }\n";
+ continue;
+ }
+
+ if (operand.isOptional())
body << " if (" << argName << ")\n ";
body << " " << builderOpState << ".addOperands(" << argName << ");\n";
}
@@ -1563,12 +1650,24 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
<< ".name), "
<< "odsBuilder.getI32VectorAttr({";
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
- if (op.getOperand(i).isOptional())
- body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
- else if (op.getOperand(i).isVariadic())
- body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
- else
+ const NamedTypeConstraint &operand = op.getOperand(i);
+ if (!operand.isVariableLength()) {
body << "1";
+ return;
+ }
+
+ std::string operandName = getArgumentName(op, i);
+ if (operand.isOptional()) {
+ body << "(" << operandName << " ? 1 : 0)";
+ } else if (operand.isVariadicOfVariadic()) {
+ body << llvm::formatv(
+ "static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, "
+ "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + "
+ "range.size(); }))",
+ operandName);
+ } else {
+ body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
+ }
});
body << "}));\n";
}
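For a VariadicOfVariadic operand, the generated `build()` takes an `ArrayRef<ValueRange>`. It appends every group to the flat operand list and records each group's size in the user-named segment attribute. When the op also uses `AttrSizedOperandSegments`, it contributes the sum of those sizes to `operand_segment_sizes`. The sketch below models that bookkeeping with plain vectors in place of `ValueRange` and `OperationState`. `BuiltOperands` and `buildGroups` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

struct BuiltOperands {
  std::vector<int> flatOperands;     // what addOperands() accumulates
  std::vector<int32_t> rangeSegments; // the nested segment-size attribute
  int32_t totalCount = 0;             // entry in operand_segment_sizes
};

BuiltOperands buildGroups(const std::vector<std::vector<int>> &groups) {
  BuiltOperands out;
  for (const std::vector<int> &range : groups) {
    out.flatOperands.insert(out.flatOperands.end(), range.begin(),
                            range.end());
    out.rangeSegments.push_back(static_cast<int32_t>(range.size()));
  }
  // Same fold as the generated std::accumulate over the groups.
  out.totalCount = std::accumulate(
      groups.begin(), groups.end(), int32_t(0),
      [](int32_t curSum, const std::vector<int> &range) {
        return curSum + static_cast<int32_t>(range.size());
      });
  return out;
}
```

Because the segment attribute is derived entirely from the groups, it appears in `inferredAttributes` and is left out of the builder's parameter list.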
@@ -1576,38 +1675,38 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
// Push all attributes to the result.
for (const auto &namedAttr : op.getAttributes()) {
auto &attr = namedAttr.attr;
- if (!attr.isDerivedAttr()) {
- bool emitNotNullCheck = attr.isOptional();
- if (emitNotNullCheck)
- body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
-
- if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
- // If this is a raw value, then we need to wrap it in an Attribute
- // instance.
- FmtContext fctx;
- fctx.withBuilder("odsBuilder");
-
- std::string builderTemplate =
- std::string(attr.getConstBuilderTemplate());
-
- // For StringAttr, its constant builder call will wrap the input in
- // quotes, which is correct for normal string literals, but incorrect
- // here given we use function arguments. So we need to strip the
- // wrapping quotes.
- if (StringRef(builderTemplate).contains("\"$0\""))
- builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
-
- std::string value =
- std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
- body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
- builderOpState, namedAttr.name, value);
- } else {
- body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {1});\n",
- builderOpState, namedAttr.name);
- }
- if (emitNotNullCheck)
- body << " }\n";
+ if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name))
+ continue;
+
+ bool emitNotNullCheck = attr.isOptional();
+ if (emitNotNullCheck)
+ body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
+
+ if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
+ // If this is a raw value, then we need to wrap it in an Attribute
+ // instance.
+ FmtContext fctx;
+ fctx.withBuilder("odsBuilder");
+
+ std::string builderTemplate = std::string(attr.getConstBuilderTemplate());
+
+ // For StringAttr, its constant builder call will wrap the input in
+ // quotes, which is correct for normal string literals, but incorrect
+ // here given we use function arguments. So we need to strip the
+ // wrapping quotes.
+ if (StringRef(builderTemplate).contains("\"$0\""))
+ builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
+
+ std::string value =
+ std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
+ body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
+ builderOpState, namedAttr.name, value);
+ } else {
+ body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {1});\n",
+ builderOpState, namedAttr.name);
}
+ if (emitNotNullCheck)
+ body << " }\n";
}
// Create the correct number of regions.
@@ -1960,9 +2059,12 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
body << " unsigned index = 0; (void)index;\n";
for (auto staticValue : llvm::enumerate(values)) {
- bool hasPredicate = staticValue.value().hasPredicate();
- bool isOptional = staticValue.value().isOptional();
- if (!hasPredicate && !isOptional)
+ const NamedTypeConstraint &value = staticValue.value();
+
+ bool hasPredicate = value.hasPredicate();
+ bool isOptional = value.isOptional();
+ bool isVariadicOfVariadic = value.isVariadicOfVariadic();
+ if (!hasPredicate && !isOptional && !isVariadicOfVariadic)
continue;
body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n",
// Capitalize the first letter to match the function name
@@ -1977,14 +2079,21 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
"<< index << \" requires 0 or 1 element, but found \" << "
"valueGroup{0}.size();\n",
staticValue.index(), valueKind);
+ } else if (isVariadicOfVariadic) {
+ body << formatv(
+ " if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr("
+ "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n"
+ " return ::mlir::failure();\n",
+ value.constraint.getVariadicOfVariadicSegmentSizeAttr(), value.name,
+ staticValue.index());
}
// Otherwise, if there is no predicate there is nothing left to do.
if (!hasPredicate)
continue;
// Emit a loop to check all the dynamic values in the pack.
- StringRef constraintFn = staticVerifierEmitter.getTypeConstraintFn(
- staticValue.value().constraint);
+ StringRef constraintFn =
+ staticVerifierEmitter.getTypeConstraintFn(value.constraint);
body << " for (::mlir::Value v : valueGroup" << staticValue.index()
<< ") {\n"
<< " if (::mlir::failed(" << constraintFn
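For VariadicOfVariadic operands, the generated verifier calls `::mlir::OpTrait::impl::verifyValueSizeAttr` with three arguments: the segment attribute name, the operand name, and the size of the flattened value group. The sketch below is a standalone version of that kind of consistency check, not the MLIR helper itself. It assumes the helper rejects negative segment sizes and any total that differs from the flattened count. The function name `checkSegmentSizes` and its messages are hypothetical.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for the check delegated to verifyValueSizeAttr: the
// segment attribute must describe exactly the operands in the flattened group.
// Returns an empty string on success, otherwise an illustrative error message.
std::string checkSegmentSizes(const std::string &attrName,
                              const std::string &groupName,
                              const std::vector<int32_t> &segmentSizes,
                              std::size_t flatCount) {
  int64_t total = 0;
  for (int32_t size : segmentSizes) {
    // A group may be empty, but its size can never be negative.
    if (size < 0)
      return "'" + attrName + "' must have non-negative values";
    total += size;
  }
  // The group sizes must add up to the number of flattened operands.
  if (total != static_cast<int64_t>(flatCount))
    return groupName + " count (" + std::to_string(flatCount) +
           ") does not match total size (" + std::to_string(total) +
           ") in '" + attrName + "'";
  return "";
}
```

For example, segment sizes `{2, 0, 1}` are consistent with three flattened operands, but the same sizes are rejected against four operands.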
@@ -2257,7 +2366,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
}
std::string sizeAttrInit =
formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
- generateNamedOperandGetters(op, adaptor, sizeAttrInit,
+ generateNamedOperandGetters(op, adaptor,
+ /*isAdaptor=*/true, sizeAttrInit,
/*rangeType=*/"::mlir::ValueRange",
/*rangeBeginCall=*/"odsOperands.begin()",
/*rangeSizeCall=*/"odsOperands.size()",
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 2c91708af5be5..675211aa223af 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -497,6 +497,7 @@ struct OperationFormat {
/// The set of attributes explicitly used within the format.
SmallVector<const NamedAttribute *, 8> usedAttributes;
+ llvm::StringSet<> inferredAttributes;
};
} // end anonymous namespace
@@ -616,10 +617,38 @@ const char *const operandParserCode = R"(
if (parser.parseOperand({0}RawOperands[0]))
return ::mlir::failure();
)";
+/// The code snippet used to generate a parser call for a VariadicOfVariadic
+/// operand.
+///
+/// {0}: The name of the operand.
+/// {1}: The name of the segment size attribute.
+const char *const variadicOfVariadicOperandParserCode = R"(
+ {
+ {0}OperandsLoc = parser.getCurrentLocation();
+ int32_t curSize = 0;
+ do {
+ if (parser.parseOptionalLParen())
+ break;
+ if (parser.parseOperandList({0}Operands) || parser.parseRParen())
+ return ::mlir::failure();
+ {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
+ curSize = {0}Operands.size();
+ } while (succeeded(parser.parseOptionalComma()));
+ }
+)";
/// The code snippet used to generate a parser call for a type list.
///
/// {0}: The name for the type list.
+const char *const variadicOfVariadicTypeParserCode = R"(
+ do {
+ if (parser.parseOptionalLParen())
+ break;
+ if (parser.parseOptionalRParen() &&
+ (parser.parseTypeList({0}Types) || parser.parseRParen()))
+ return ::mlir::failure();
+ } while (succeeded(parser.parseOptionalComma()));
+)";
const char *const variadicTypeParserCode = R"(
if (parser.parseTypeList({0}Types))
return ::mlir::failure();
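The `variadicOfVariadicOperandParserCode` loop accepts zero or more comma-separated, parenthesized groups. It keeps every operand in one flat list and records each group's size as the difference from the running total. Below is a dependency-free C++ sketch with the same loop shape, working on plain text. Three parts of it are assumptions made so the example runs on its own, not MLIR APIs: the `Cursor` class, the treatment of operands as `%name` tokens, and the final end-of-input check.

```cpp
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Mirrors the flat `{0}Operands` list and `{0}OperandGroupSizes` vector.
struct ParsedOperandGroups {
  std::vector<std::string> operands;
  std::vector<int32_t> groupSizes;
};

// Hypothetical minimal cursor over the input text.
class Cursor {
public:
  explicit Cursor(std::string text) : text_(std::move(text)) {}

  // Consumes `c` if it is the next non-space character.
  bool consumeOptional(char c) {
    skipSpace();
    if (pos_ < text_.size() && text_[pos_] == c) {
      ++pos_;
      return true;
    }
    return false;
  }

  // Parses a `%name` operand if one comes next.
  std::optional<std::string> parseOptionalOperand() {
    skipSpace();
    std::size_t start = pos_;
    if (pos_ >= text_.size() || text_[pos_] != '%')
      return std::nullopt;
    ++pos_;
    while (pos_ < text_.size() &&
           (std::isalnum(static_cast<unsigned char>(text_[pos_])) ||
            text_[pos_] == '_'))
      ++pos_;
    if (pos_ == start + 1) { // a bare '%' is not an operand
      pos_ = start;
      return std::nullopt;
    }
    return text_.substr(start, pos_ - start);
  }

  bool atEnd() {
    skipSpace();
    return pos_ == text_.size();
  }

private:
  void skipSpace() {
    while (pos_ < text_.size() &&
           std::isspace(static_cast<unsigned char>(text_[pos_])))
      ++pos_;
  }

  std::string text_;
  std::size_t pos_ = 0;
};

// Parses `(<value>, <value>), (), (<value>)`: stop when no '(' follows,
// otherwise parse a possibly empty list, expect ')', record the group size.
std::optional<ParsedOperandGroups> parseOperandGroups(const std::string &text) {
  Cursor cursor(text);
  ParsedOperandGroups result;
  int32_t curSize = 0;
  do {
    if (!cursor.consumeOptional('('))
      break;
    if (auto first = cursor.parseOptionalOperand()) {
      result.operands.push_back(*first);
      while (cursor.consumeOptional(',')) {
        auto next = cursor.parseOptionalOperand();
        if (!next)
          return std::nullopt;
        result.operands.push_back(*next);
      }
    }
    if (!cursor.consumeOptional(')'))
      return std::nullopt;
    int32_t total = static_cast<int32_t>(result.operands.size());
    result.groupSizes.push_back(total - curSize);
    curSize = total;
  } while (cursor.consumeOptional(','));
  // The generated parser moves on to the next format element; this
  // standalone sketch instead requires all input to be consumed.
  if (!cursor.atEnd())
    return std::nullopt;
  return result;
}
```

Parsing `(%a, %b), (), (%c)` gives the flat list `%a, %b, %c` with group sizes `2, 0, 1`. The type snippet `variadicOfVariadicTypeParserCode` uses the same loop shape, but it only appends to a flat `{0}Types` list and does not record group sizes.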
@@ -758,6 +787,9 @@ const char *successorParserCode = R"(
namespace {
/// The type of length for a given parse argument.
enum class ArgumentLengthKind {
+ /// The argument is a variadic of a variadic, and may contain 0->N range
+ /// elements.
+ VariadicOfVariadic,
/// The argument is variadic, and may contain 0->N elements.
Variadic,
/// The argument is optional, and may contain 0 or 1 elements.
@@ -772,6 +804,8 @@ static ArgumentLengthKind
getArgumentLengthKind(const NamedTypeConstraint *var) {
if (var->isOptional())
return ArgumentLengthKind::Optional;
+ if (var->isVariadicOfVariadic())
+ return ArgumentLengthKind::VariadicOfVariadic;
if (var->isVariadic())
return ArgumentLengthKind::Variadic;
return ArgumentLengthKind::Single;
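`getArgumentLengthKind` tests `isVariadicOfVariadic()` before `isVariadic()`. Below is a small model of that dispatch. It assumes, as the ordering suggests, that a VariadicOfVariadic constraint also returns true from `isVariadic()`. `FakeConstraint` and `classify` are hypothetical names.

```cpp
enum class ArgumentLengthKind { VariadicOfVariadic, Variadic, Optional, Single };

// Hypothetical stand-in for NamedTypeConstraint's three predicates. A
// VariadicOfVariadic operand is modeled as also being variadic.
struct FakeConstraint {
  bool optional = false;
  bool variadic = false;
  bool variadicOfVariadic = false;
};

// Same check order as getArgumentLengthKind. If the variadic test came
// first, a VariadicOfVariadic operand would be misclassified as Variadic.
ArgumentLengthKind classify(const FakeConstraint &var) {
  if (var.optional)
    return ArgumentLengthKind::Optional;
  if (var.variadicOfVariadic)
    return ArgumentLengthKind::VariadicOfVariadic;
  if (var.variadic)
    return ArgumentLengthKind::Variadic;
  return ArgumentLengthKind::Single;
}
```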
@@ -863,6 +897,10 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
if (operand->getVar()->isVariableLength()) {
body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
<< name << "Operands;\n";
+ if (operand->getVar()->isVariadicOfVariadic()) {
+ body << " llvm::SmallVector<int32_t> " << name
+ << "OperandGroupSizes;\n";
+ }
} else {
body << " ::mlir::OpAsmParser::OperandType " << name
<< "RawOperands[1];\n"
@@ -924,7 +962,9 @@ static void genCustomParameterParser(Element &param, OpMethodBody &body) {
} else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
StringRef name = operand->getVar()->name;
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
- if (lengthKind == ArgumentLengthKind::Variadic)
+ if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
+ body << llvm::formatv("{0}OperandGroups", name);
+ else if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv("{0}Operands", name);
else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv("{0}Operand", name);
@@ -951,7 +991,9 @@ static void genCustomParameterParser(Element &param, OpMethodBody &body) {
} else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
- if (lengthKind == ArgumentLengthKind::Variadic)
+ if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
+ body << llvm::formatv("{0}TypeGroups", listName);
+ else if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv("{0}Types", listName);
else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv("{0}Type", listName);
@@ -972,19 +1014,32 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
// * Set the location of operand variables.
for (Element &param : dir->getArguments()) {
if (auto *operand = dyn_cast<OperandVariable>(&param)) {
- body << " " << operand->getVar()->name
+ auto *var = operand->getVar();
+ body << " " << var->name
<< "OperandsLoc = parser.getCurrentLocation();\n";
- if (operand->getVar()->isOptional()) {
+ if (var->isOptional()) {
body << llvm::formatv(
" llvm::Optional<::mlir::OpAsmParser::OperandType> "
"{0}Operand;\n",
- operand->getVar()->name);
+ var->name);
+ } else if (var->isVariadicOfVariadic()) {
+ body << llvm::formatv(" "
+ "llvm::SmallVector<llvm::SmallVector<::mlir::"
+ "OpAsmParser::OperandType>> "
+ "{0}OperandGroups;\n",
+ var->name);
}
} else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
- if (lengthKind == ArgumentLengthKind::Optional)
+ if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
+ } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
+ body << llvm::formatv(
+ " llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
+ "{0}TypeGroups;\n",
+ listName);
+ }
} else if (auto *dir = dyn_cast<RefDirective>(&param)) {
Element *input = dir->getOperand();
if (auto *operand = dyn_cast<OperandVariable>(input)) {
@@ -1028,11 +1083,18 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
var->name);
} else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
const NamedTypeConstraint *var = operand->getVar();
- if (!var->isOptional())
- continue;
- body << llvm::formatv(" if ({0}Operand.hasValue())\n"
- " {0}Operands.push_back(*{0}Operand);\n",
- var->name);
+ if (var->isOptional()) {
+ body << llvm::formatv(" if ({0}Operand.hasValue())\n"
+ " {0}Operands.push_back(*{0}Operand);\n",
+ var->name);
+ } else if (var->isVariadicOfVariadic()) {
+ body << llvm::formatv(
+ " for (const auto &subRange : {0}OperandGroups) {{\n"
+ " {0}Operands.append(subRange.begin(), subRange.end());\n"
+ " {0}OperandGroupSizes.push_back(subRange.size());\n"
+ " }\n",
+ var->name, var->constraint.getVariadicOfVariadicSegmentSizeAttr());
+ }
} else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@@ -1040,6 +1102,11 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
body << llvm::formatv(" if ({0}Type)\n"
" {0}Types.push_back({0}Type);\n",
listName);
+ } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
+ body << llvm::formatv(
+ " for (const auto &subRange : {0}TypeGroups)\n"
+ " {0}Types.append(subRange.begin(), subRange.end());\n",
+ listName);
}
}
}
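When a VariadicOfVariadic operand is passed to a custom directive, the directive fills `{0}OperandGroups`. The generated code then flattens those groups into `{0}Operands` and appends each group's size to `{0}OperandGroupSizes`. Below is a sketch of that flattening with standard containers. Strings stand in for `OpAsmParser::OperandType`, and `flattenGroups` is a hypothetical name.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Flattens nested operand groups into one list and records each group's
// size, like the loop emitted after a custom directive call. (Type groups are
// flattened the same way, but without recording sizes.)
void flattenGroups(const std::vector<std::vector<std::string>> &groups,
                   std::vector<std::string> &operands,
                   std::vector<int32_t> &groupSizes) {
  for (const auto &subRange : groups) {
    operands.insert(operands.end(), subRange.begin(), subRange.end());
    groupSizes.push_back(static_cast<int32_t>(subRange.size()));
  }
}
```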
@@ -1229,7 +1296,11 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
StringRef name = operand->getVar()->name;
- if (lengthKind == ArgumentLengthKind::Variadic)
+ if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
+ body << llvm::formatv(
+ variadicOfVariadicOperandParserCode, name,
+ operand->getVar()->constraint.getVariadicOfVariadicSegmentSizeAttr());
+ else if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv(variadicOperandParserCode, name);
else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv(optionalOperandParserCode, name);
@@ -1281,7 +1352,9 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
- if (lengthKind == ArgumentLengthKind::Variadic)
+ if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
+ body << llvm::formatv(variadicOfVariadicTypeParserCode, listName);
+ else if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv(variadicTypeParserCode, listName);
else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv(optionalTypeParserCode, listName);
@@ -1501,19 +1574,29 @@ void OperationFormat::genParserSuccessorResolution(Operator &op,
void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
OpMethodBody &body) {
- if (!allOperands &&
- op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
- body << " result.addAttribute(\"operand_segment_sizes\", "
- << "parser.getBuilder().getI32VectorAttr({";
- auto interleaveFn = [&](const NamedTypeConstraint &operand) {
- // If the operand is variadic emit the parsed size.
- if (operand.isVariableLength())
- body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
- else
- body << "1";
- };
- llvm::interleaveComma(op.getOperands(), body, interleaveFn);
- body << "}));\n";
+ if (!allOperands) {
+ if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
+ body << " result.addAttribute(\"operand_segment_sizes\", "
+ << "parser.getBuilder().getI32VectorAttr({";
+ auto interleaveFn = [&](const NamedTypeConstraint &operand) {
+ // If the operand is variadic emit the parsed size.
+ if (operand.isVariableLength())
+ body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
+ else
+ body << "1";
+ };
+ llvm::interleaveComma(op.getOperands(), body, interleaveFn);
+ body << "}));\n";
+ }
+ for (const NamedTypeConstraint &operand : op.getOperands()) {
+ if (!operand.isVariadicOfVariadic())
+ continue;
+ body << llvm::formatv(
+ " result.addAttribute(\"{0}\", "
+ "parser.getBuilder().getI32TensorAttr({1}OperandGroupSizes));\n",
+ operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
+ operand.name);
+ }
}
if (!allResultTypes &&
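After parsing, the segment resolution emits up to two attributes. The first is `operand_segment_sizes`, emitted when the op has `AttrSizedOperandSegments`. A VariadicOfVariadic operand contributes its flattened count to it, because the operand is variable-length. The second is the user-named segment attribute, which receives the per-group sizes through `getI32TensorAttr`. Below is a standalone sketch of the first computation. `ParsedOperand`, `Kind`, and the operand names in the usage are illustrative, not taken from the commit.

```cpp
#include <cstdint>
#include <string>
#include <vector>

enum class Kind { Single, Variadic, VariadicOfVariadic };

// Hypothetical description of one parsed operand.
struct ParsedOperand {
  std::string name;
  Kind kind;
  std::vector<std::string> flat;   // {name}Operands after parsing
  std::vector<int32_t> groupSizes; // {name}OperandGroupSizes (VoV only)
};

// Each single operand contributes 1. Each variable-length operand,
// including a VariadicOfVariadic one, contributes its flattened count.
std::vector<int32_t> operandSegmentSizes(const std::vector<ParsedOperand> &ops) {
  std::vector<int32_t> sizes;
  for (const ParsedOperand &op : ops)
    sizes.push_back(op.kind == Kind::Single
                        ? 1
                        : static_cast<int32_t>(op.flat.size()));
  return sizes;
}
```

With a single operand, a one-element variadic operand, and a VariadicOfVariadic operand holding groups of sizes `2, 0, 1`, `operand_segment_sizes` comes out as `1, 1, 3`. The VariadicOfVariadic's own segment attribute carries `2, 0, 1`.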
@@ -1575,6 +1658,10 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
if (!fmt.allResultTypes &&
op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
body << "\"result_segment_sizes\", ";
+ if (!fmt.inferredAttributes.empty()) {
+ for (const auto &attr : fmt.inferredAttributes)
+ body << "\"" << attr.getKey() << "\", ";
+ }
llvm::interleaveComma(
fmt.usedAttributes, body,
[&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; });
@@ -1693,6 +1780,8 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
return body << "getOperation()->getResultTypes()";
auto *operand = dyn_cast<OperandVariable>(arg);
auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
+ if (var->isVariadicOfVariadic())
+ return body << llvm::formatv("{0}().join().getTypes()", var->name);
if (var->isVariadic())
return body << var->name << "().getTypes()";
if (var->isOptional())
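The type printer reaches the flattened types through `{0}().join().getTypes()`, joining the `OperandRangeRange` accessor back into one contiguous range. The sketch below shows how a flat list, the segment sizes, and the nested groups relate. It uses owning copies in place of MLIR's non-owning ranges. `splitIntoGroups` and `joinGroups` are hypothetical names. The split assumes sizes that would pass the verifier check: non-negative and summing to the flat size.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Group i covers sizes[i] consecutive elements of the flat list.
// Precondition (not checked here): sizes are non-negative and sum to
// flat.size().
template <typename T>
std::vector<std::vector<T>> splitIntoGroups(const std::vector<T> &flat,
                                            const std::vector<int32_t> &sizes) {
  std::vector<std::vector<T>> groups;
  std::size_t offset = 0;
  for (int32_t size : sizes) {
    auto first = flat.begin() + static_cast<std::ptrdiff_t>(offset);
    groups.emplace_back(first, first + size);
    offset += static_cast<std::size_t>(size);
  }
  return groups;
}

// The inverse, analogous to join(): concatenate the groups in order.
template <typename T>
std::vector<T> joinGroups(const std::vector<std::vector<T>> &groups) {
  std::vector<T> flat;
  for (const auto &group : groups)
    flat.insert(flat.end(), group.begin(), group.end());
  return flat;
}
```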
@@ -1896,7 +1985,12 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
else
body << " p.printAttribute(" << var->name << "Attr());\n";
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
- if (operand->getVar()->isOptional()) {
+ if (operand->getVar()->isVariadicOfVariadic()) {
+ body << " ::llvm::interleaveComma(" << operand->getVar()->name
+ << "(), p, [&](const auto &operands) { p << \"(\" << operands << "
+ "\")\"; });\n";
+
+ } else if (operand->getVar()->isOptional()) {
body << " if (::mlir::Value value = " << operand->getVar()->name
<< "())\n"
<< " p << value;\n";
@@ -1926,6 +2020,15 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
} else if (isa<SuccessorsDirective>(element)) {
body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), p);\n";
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
+ if (auto *operand = dyn_cast<OperandVariable>(dir->getOperand())) {
+ if (operand->getVar()->isVariadicOfVariadic()) {
+ body << llvm::formatv(" ::llvm::interleaveComma({0}().getTypes(), p, "
+ "[&](::mlir::TypeRange types) {{ p << \"(\" << "
+ "types << \")\"; });\n",
+ operand->getVar()->name);
+ return;
+ }
+ }
body << " p << ";
genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
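The element printer emits nested `interleaveComma` calls and wraps each operand group in parentheses. For the `type` directive it does the same with each type group. Below is a string-based sketch of the resulting text. `printGroups` is a hypothetical helper. The inner `", "` separator assumes that printing a range is comma-delimited, as the `(<value>, <value>), (), (<value>)` example in the commit message indicates.

```cpp
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Groups are separated by ", " and each group is wrapped in parentheses.
// Elements inside a group are also separated by ", ".
std::string printGroups(const std::vector<std::vector<std::string>> &groups) {
  std::ostringstream os;
  for (std::size_t i = 0; i < groups.size(); ++i) {
    if (i != 0)
      os << ", ";
    os << '(';
    for (std::size_t j = 0; j < groups[i].size(); ++j) {
      if (j != 0)
        os << ", ";
      os << groups[i][j];
    }
    os << ')';
  }
  return os.str();
}
```

An empty group prints as `()`, so the output can be parsed back into the same group sizes.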
@@ -2449,6 +2552,16 @@ LogicalResult FormatParser::verifyAttributes(llvm::SMLoc loc) {
while (!iteratorStack.empty())
if (failed(verifyAttributes(loc, iteratorStack)))
return ::mlir::failure();
+
+ // Check for VariadicOfVariadic variables. The segment attribute of those
+ // variables will be inferred.
+ for (const NamedTypeConstraint *var : seenOperands) {
+ if (var->constraint.isVariadicOfVariadic()) {
+ fmt.inferredAttributes.insert(
+ var->constraint.getVariadicOfVariadicSegmentSizeAttr());
+ }
+ }
+
return ::mlir::success();
}
/// Verify the attribute elements at the back of the given stack of iterators.
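`verifyAttributes` marks the segment attribute of every VariadicOfVariadic operand seen in the format as inferred. `genAttrDictPrinter` then adds those names to the elided list, because the parser rebuilds the attribute from the parenthesized groups. Below is a sketch of both steps, using `std::set` in place of `llvm::StringSet<>`. `OperandInfo`, `collectInferred`, `remainingAttrs`, and the attribute names in the test are hypothetical.

```cpp
#include <set>
#include <string>
#include <vector>

// Hypothetical description of an operand referenced by the format.
struct OperandInfo {
  std::string name;
  bool isVariadicOfVariadic;
  std::string segmentAttr; // empty unless isVariadicOfVariadic
};

// Collects the segment attribute names, like fmt.inferredAttributes.
std::set<std::string> collectInferred(const std::vector<OperandInfo> &seen) {
  std::set<std::string> inferred;
  for (const OperandInfo &op : seen)
    if (op.isVariadicOfVariadic)
      inferred.insert(op.segmentAttr);
  return inferred;
}

// Returns the attributes an attr-dict would still print: all attributes
// except the inferred ones and those the format already prints explicitly.
std::vector<std::string> remainingAttrs(const std::vector<std::string> &all,
                                        const std::set<std::string> &inferred,
                                        const std::set<std::string> &used) {
  std::vector<std::string> out;
  for (const std::string &name : all)
    if (!inferred.count(name) && !used.count(name))
      out.push_back(name);
  return out;
}
```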