[Mlir-commits] [mlir] aba1acc - [mlir][ODS] Add support for optional operands and results with a new Optional directive.
River Riddle
llvmlistbot at llvm.org
Fri Apr 10 14:18:02 PDT 2020
Author: River Riddle
Date: 2020-04-10T14:12:06-07:00
New Revision: aba1acc89c653b2cc08cccfb754ff16994a05332
URL: https://github.com/llvm/llvm-project/commit/aba1acc89c653b2cc08cccfb754ff16994a05332
DIFF: https://github.com/llvm/llvm-project/commit/aba1acc89c653b2cc08cccfb754ff16994a05332.diff
LOG: [mlir][ODS] Add support for optional operands and results with a new Optional directive.
Summary: This revision adds support for specifying operands or results as "optional". This is a special case of variadic where the number of elements is either 0 or 1. Operands and results of this kind will have accessors generated using Value instead of the range types, making it more natural to interface with.
Differential Revision: https://reviews.llvm.org/D77863
Added:
Modified:
mlir/docs/OpDefinitions.md
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/TableGen/Argument.h
mlir/include/mlir/TableGen/Operator.h
mlir/include/mlir/TableGen/Type.h
mlir/lib/Parser/Parser.cpp
mlir/lib/TableGen/Argument.cpp
mlir/lib/TableGen/Operator.cpp
mlir/lib/TableGen/Pattern.cpp
mlir/lib/TableGen/Type.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-decl.td
mlir/test/mlir-tblgen/op-format-spec.td
mlir/test/mlir-tblgen/op-format.mlir
mlir/test/mlir-tblgen/predicate.td
mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/OpFormatGen.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 0187ff740bf7..11d1533d7002 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -221,11 +221,28 @@ To declare a variadic operand, wrap the `TypeConstraint` for the operand with
Normally operations have no variadic operands or just one variadic operand. For
the latter case, it is easy to deduce which dynamic operands are for the static
-variadic operand definition. But if an operation has more than one variadic
-operands, it would be impossible to attribute dynamic operands to the
-corresponding static variadic operand definitions without further information
-from the operation. Therefore, the `SameVariadicOperandSize` trait is needed to
-indicate that all variadic operands have the same number of dynamic values.
+variadic operand definition. However, if an operation has more than one
+variable-length operand (either optional or variadic), it would be impossible
+to attribute dynamic operands to the corresponding static variadic operand
+definitions without further information from the operation. Therefore, either
+the `SameVariadicOperandSize` trait (all variable-length operands have the same
+number of dynamic values) or the `AttrSizedOperandSegments` trait (an attribute
+records the number of dynamic values for each operand) is needed.
+
+#### Optional operands
+
+To declare an optional operand, wrap the `TypeConstraint` for the operand with
+`Optional<...>`.
+
+Normally operations have no optional operands or just one optional operand. For
+the latter case, it is easy to deduce which dynamic operands are for the static
+operand definition. However, if an operation has more than one variable-length
+operand (either optional or variadic), it would be impossible to attribute
+dynamic operands to the corresponding static operand definitions without
+further information from the operation. Therefore, either the
+`SameVariadicOperandSize` trait (all variable-length operands have the same
+number of dynamic values) or the `AttrSizedOperandSegments` trait (an attribute
+records the number of dynamic values for each operand) is needed.
#### Optional attributes
@@ -693,7 +710,7 @@ information. An optional group is defined by wrapping a set of elements within
the group.
- Any attribute variable may be used, but only optional attributes can be
marked as the anchor.
- - Only variadic, i.e. optional, operand arguments can be used.
+ - Only variadic or optional operand arguments can be used.
- The operands to a type directive must be defined within the optional
group.
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 58c9643f3efb..56970f7a5aa7 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -297,13 +297,17 @@ class DialectType<Dialect d, Pred condition, string descr = ""> :
}
// A variadic type constraint. It expands to zero or more of the base type. This
-// class is used for supporting variadic operands/results. An op can declare no
-// more than one variadic operand/result, and that operand/result must be the
-// last one in the operand/result list.
+// class is used for supporting variadic operands/results.
class Variadic<Type type> : TypeConstraint<type.predicate, type.description> {
Type baseType = type;
}
+// An optional type constraint. It expands to either zero or one of the base
+// type. This class is used for supporting optional operands/results.
+class Optional<Type type> : TypeConstraint<type.predicate, type.description> {
+ Type baseType = type;
+}
+
// A type that can be constructed using MLIR::Builder.
// Note that this does not "inherit" from Type because it would require
// duplicating Type subclasses for buildable and non-buildable cases to avoid
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 0a2602544112..cad162cbb3f8 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -621,6 +621,9 @@ class OpAsmParser {
/// Parse a type.
virtual ParseResult parseType(Type &result) = 0;
+ /// Parse an optional type.
+ virtual OptionalParseResult parseOptionalType(Type &result) = 0;
+
/// Parse a type of a specific type.
template <typename TypeT>
ParseResult parseType(TypeT &result) {
diff --git a/mlir/include/mlir/TableGen/Argument.h b/mlir/include/mlir/TableGen/Argument.h
index 660e1bbc4bae..0eb4d8ce4198 100644
--- a/mlir/include/mlir/TableGen/Argument.h
+++ b/mlir/include/mlir/TableGen/Argument.h
@@ -43,8 +43,14 @@ struct NamedAttribute {
struct NamedTypeConstraint {
// Returns true if this operand/result has constraint to be satisfied.
bool hasPredicate() const;
+ // Returns true if this is an optional type constraint. This is a special case
+ // of variadic for 0 or 1 type.
+ bool isOptional() const;
// Returns true if this operand/result is variadic.
bool isVariadic() const;
+ // Returns true if this is a variable length type constraint. This is either
+ // variadic or optional.
+ bool isVariableLength() const { return isOptional() || isVariadic(); }
llvm::StringRef name;
TypeConstraint constraint;
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index 2748894fe601..e65bc55a84f5 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -88,7 +88,7 @@ class Operator {
using value_iterator = NamedTypeConstraint *;
using value_range = llvm::iterator_range<value_iterator>;
- // Returns true if this op has variadic operands or results.
+ // Returns true if this op has variable length operands or results.
bool isVariadic() const;
// Returns true if default builders should not be generated.
@@ -115,8 +115,8 @@ class Operator {
// Returns the `index`-th result's decorators.
var_decorator_range getResultDecorators(int index) const;
- // Returns the number of variadic results in this operation.
- unsigned getNumVariadicResults() const;
+ // Returns the number of variable length results in this operation.
+ unsigned getNumVariableLengthResults() const;
// Op attribute iterators.
using attribute_iterator = const NamedAttribute *;
@@ -142,7 +142,7 @@ class Operator {
}
// Returns the number of variadic operands in this operation.
- unsigned getNumVariadicOperands() const;
+ unsigned getNumVariableLengthOperands() const;
// Returns the total number of arguments.
int getNumArgs() const { return arguments.size(); }
diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h
index b6fb1edb09ef..2653b90196f7 100644
--- a/mlir/include/mlir/TableGen/Type.h
+++ b/mlir/include/mlir/TableGen/Type.h
@@ -34,9 +34,16 @@ class TypeConstraint : public Constraint {
static bool classof(const Constraint *c) { return c->getKind() == CK_Type; }
+ // Returns true if this is an optional type constraint.
+ bool isOptional() const;
+
// Returns true if this is a variadic type constraint.
bool isVariadic() const;
+ // Returns true if this is a variable length type constraint. This is either
+ // variadic or optional.
+ bool isVariableLength() const { return isOptional() || isVariadic(); }
+
// Returns the builder call for this constraint if this is a buildable type,
// returns None otherwise.
Optional<StringRef> getBuilderCall() const;
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index e68867b1e44e..e339588d9661 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -227,6 +227,9 @@ class Parser {
ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);
+ /// Optionally parse a type.
+ OptionalParseResult parseOptionalType(Type &type);
+
/// Parse an arbitrary type.
Type parseType();
@@ -899,6 +902,31 @@ ParseResult Parser::parseToken(Token::Kind expectedToken,
// Type Parsing
//===----------------------------------------------------------------------===//
+/// Optionally parse a type.
+OptionalParseResult Parser::parseOptionalType(Type &type) {
+  // There are many different starting tokens for a type, check them here.
+ switch (getToken().getKind()) {
+ case Token::l_paren:
+ case Token::kw_memref:
+ case Token::kw_tensor:
+ case Token::kw_complex:
+ case Token::kw_tuple:
+ case Token::kw_vector:
+ case Token::inttype:
+ case Token::kw_bf16:
+ case Token::kw_f16:
+ case Token::kw_f32:
+ case Token::kw_f64:
+ case Token::kw_index:
+ case Token::kw_none:
+ case Token::exclamation_identifier:
+ return failure(!(type = parseType()));
+
+ default:
+ return llvm::None;
+ }
+}
+
/// Parse an arbitrary type.
///
/// type ::= function-type
@@ -4509,6 +4537,11 @@ class CustomOpAsmParser : public OpAsmParser {
return failure(!(result = parser.parseType()));
}
+ /// Parse an optional type.
+ OptionalParseResult parseOptionalType(Type &result) override {
+ return parser.parseOptionalType(result);
+ }
+
/// Parse an arrow followed by a type list.
ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) override {
if (parseArrow() || parser.parseFunctionResultTypes(result))
diff --git a/mlir/lib/TableGen/Argument.cpp b/mlir/lib/TableGen/Argument.cpp
index eb68b79e1857..1fea24d3bad0 100644
--- a/mlir/lib/TableGen/Argument.cpp
+++ b/mlir/lib/TableGen/Argument.cpp
@@ -15,6 +15,10 @@ bool tblgen::NamedTypeConstraint::hasPredicate() const {
return !constraint.getPredicate().isNull();
}
+bool tblgen::NamedTypeConstraint::isOptional() const {
+ return constraint.isOptional();
+}
+
bool tblgen::NamedTypeConstraint::isVariadic() const {
return constraint.isVariadic();
}
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 46e26af40bdf..a6bed62948b7 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -81,10 +81,6 @@ StringRef tblgen::Operator::getExtraClassDeclaration() const {
const llvm::Record &tblgen::Operator::getDef() const { return def; }
-bool tblgen::Operator::isVariadic() const {
- return getNumVariadicOperands() != 0 || getNumVariadicResults() != 0;
-}
-
bool tblgen::Operator::skipDefaultBuilders() const {
return def.getValueAsBit("skipDefaultBuilders");
}
@@ -119,16 +115,16 @@ auto tblgen::Operator::getResultDecorators(int index) const
return *result->getValueAsListInit("decorators");
}
-unsigned tblgen::Operator::getNumVariadicResults() const {
- return std::count_if(
- results.begin(), results.end(),
- [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
+unsigned tblgen::Operator::getNumVariableLengthResults() const {
+ return llvm::count_if(results, [](const NamedTypeConstraint &c) {
+ return c.constraint.isVariableLength();
+ });
}
-unsigned tblgen::Operator::getNumVariadicOperands() const {
- return std::count_if(
- operands.begin(), operands.end(),
- [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
+unsigned tblgen::Operator::getNumVariableLengthOperands() const {
+ return llvm::count_if(operands, [](const NamedTypeConstraint &c) {
+ return c.constraint.isVariableLength();
+ });
}
tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const {
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index d832ea809247..b04c8e215679 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -255,7 +255,7 @@ std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
// If this operand is variadic, then return a range. Otherwise, return the
// value itself.
- if (operand->isVariadic()) {
+ if (operand->isVariableLength()) {
auto repl = formatv(fmt, name);
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
return std::string(repl);
diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
index 4105c48292fd..6cf2a5f4a175 100644
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -26,6 +26,10 @@ TypeConstraint::TypeConstraint(const llvm::Record *record)
TypeConstraint::TypeConstraint(const llvm::DefInit *init)
: TypeConstraint(init->getDef()) {}
+bool TypeConstraint::isOptional() const {
+ return def->isSubClassOf("Optional");
+}
+
bool TypeConstraint::isVariadic() const {
return def->isSubClassOf("Variadic");
}
@@ -34,7 +38,7 @@ bool TypeConstraint::isVariadic() const {
// returns None otherwise.
Optional<StringRef> TypeConstraint::getBuilderCall() const {
const llvm::Record *baseType = def;
- if (isVariadic())
+ if (isVariableLength())
baseType = baseType->getValueAsDef("baseType");
// Check to see if this type constraint has a builder call.
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 86ebfe4108a2..6f1ef4a50f67 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1179,39 +1179,41 @@ def FormatBuildableTypeOp : TEST_Op<"format_buildable_type_op"> {
}
// Test various mixings of result type formatting.
-class FormatResultBase<string name, string fmt> : TEST_Op<name> {
+class FormatResultBase<string suffix, string fmt>
+ : TEST_Op<"format_result_" # suffix # "_op"> {
let results = (outs I64:$buildable_res, AnyMemRef:$result);
let assemblyFormat = fmt;
}
-def FormatResultAOp : FormatResultBase<"format_result_a_op", [{
+def FormatResultAOp : FormatResultBase<"a", [{
type($result) attr-dict
}]>;
-def FormatResultBOp : FormatResultBase<"format_result_b_op", [{
+def FormatResultBOp : FormatResultBase<"b", [{
type(results) attr-dict
}]>;
-def FormatResultCOp : FormatResultBase<"format_result_c_op", [{
+def FormatResultCOp : FormatResultBase<"c", [{
functional-type($buildable_res, $result) attr-dict
}]>;
// Test various mixings of operand type formatting.
-class FormatOperandBase<string name, string fmt> : TEST_Op<name> {
+class FormatOperandBase<string suffix, string fmt>
+ : TEST_Op<"format_operand_" # suffix # "_op"> {
let arguments = (ins I64:$buildable, AnyMemRef:$operand);
let assemblyFormat = fmt;
}
-def FormatOperandAOp : FormatOperandBase<"format_operand_a_op", [{
+def FormatOperandAOp : FormatOperandBase<"a", [{
operands `:` type(operands) attr-dict
}]>;
-def FormatOperandBOp : FormatOperandBase<"format_operand_b_op", [{
+def FormatOperandBOp : FormatOperandBase<"b", [{
operands `:` type($operand) attr-dict
}]>;
-def FormatOperandCOp : FormatOperandBase<"format_operand_c_op", [{
+def FormatOperandCOp : FormatOperandBase<"c", [{
$buildable `,` $operand `:` type(operands) attr-dict
}]>;
-def FormatOperandDOp : FormatOperandBase<"format_operand_d_op", [{
+def FormatOperandDOp : FormatOperandBase<"d", [{
$buildable `,` $operand `:` type($operand) attr-dict
}]>;
-def FormatOperandEOp : FormatOperandBase<"format_operand_e_op", [{
+def FormatOperandEOp : FormatOperandBase<"e", [{
$buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict
}]>;
@@ -1220,6 +1222,25 @@ def FormatSuccessorAOp : TEST_Op<"format_successor_a_op", [Terminator]> {
let assemblyFormat = "$targets attr-dict";
}
+// Test various mixings of optional operand and result type formatting.
+class FormatOptionalOperandResultOpBase<string suffix, string fmt>
+ : TEST_Op<"format_optional_operand_result_" # suffix # "_op",
+ [AttrSizedOperandSegments]> {
+ let arguments = (ins Optional<I64>:$optional, Variadic<I64>:$variadic);
+ let results = (outs Optional<I64>:$optional_res);
+ let assemblyFormat = fmt;
+}
+
+def FormatOptionalOperandResultAOp : FormatOptionalOperandResultOpBase<"a", [{
+ `(` $optional `:` type($optional) `)` `:` type($optional_res)
+ (`[` $variadic^ `]`)? attr-dict
+}]>;
+
+def FormatOptionalOperandResultBOp : FormatOptionalOperandResultOpBase<"b", [{
+ (`(` $optional^ `:` type($optional) `)`)? `:` type($optional_res)
+ (`[` $variadic^ `]`)? attr-dict
+}]>;
+
//===----------------------------------------------------------------------===//
// Test SideEffects
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index 4ccbd04ef221..eb736b7e6f0b 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -112,6 +112,16 @@ def NS_DOp : NS_Op<"op_with_two_operands", []> {
// CHECK-LABEL: NS::DOp declarations
// CHECK: OpTrait::NOperands<2>::Impl
+def NS_EOp : NS_Op<"op_with_optionals", []> {
+ let arguments = (ins Optional<I32>:$a);
+ let results = (outs Optional<F32>:$b);
+}
+
+// CHECK-LABEL: NS::EOp declarations
+// CHECK: Value a();
+// CHECK: Value b();
+// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, /*optional*/Type b, /*optional*/Value a)
+
// Check that default builders can be suppressed.
// ---
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 482175ba5ea6..613f3d1d4829 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -222,7 +222,7 @@ def OptionalInvalidF : TestFormat_Op<"optional_invalid_f", [{
def OptionalInvalidG : TestFormat_Op<"optional_invalid_g", [{
($attr^) attr-dict
}]>, Arguments<(ins I64Attr:$attr)>;
-// CHECK: error: only variadic operands can be used within an optional group
+// CHECK: error: only variable length operands can be used within an optional group
def OptionalInvalidH : TestFormat_Op<"optional_invalid_h", [{
($arg^) attr-dict
}]>, Arguments<(ins I64:$arg)>;
@@ -327,6 +327,17 @@ def ZCoverageInvalidF : TestFormat_Op<"variable_invalid_f", [{
}]> {
let successors = (successor AnySuccessor:$successor);
}
+// CHECK: error: type of operand #0, named 'operand', is not buildable and a buildable type cannot be inferred
+// CHECK: note: suggest adding a type constraint to the operation or adding a 'type($operand)' directive to the custom assembly format
+def ZCoverageInvalidG : TestFormat_Op<"variable_invalid_g", [{
+ operands attr-dict
+}]>, Arguments<(ins Optional<I64>:$operand)>;
+// CHECK: error: type of result #0, named 'result', is not buildable and a buildable type cannot be inferred
+// CHECK: note: suggest adding a type constraint to the operation or adding a 'type($result)' directive to the custom assembly format
+def ZCoverageInvalidH : TestFormat_Op<"variable_invalid_h", [{
+ attr-dict
+}]>, Results<(outs Optional<I64>:$result)>;
+
// CHECK-NOT: error
def ZCoverageValidA : TestFormat_Op<"variable_valid_a", [{
$operand type($operand) type($result) attr-dict
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 5e0e484ae3e2..8d55768aced7 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -18,6 +18,10 @@ test.format_attr_dict_w_keyword attributes {attr = 10 : i64}
// CHECK: test.format_buildable_type_op %[[I64]]
%ignored = test.format_buildable_type_op %i64
+//===----------------------------------------------------------------------===//
+// Format results
+//===----------------------------------------------------------------------===//
+
// CHECK: test.format_result_a_op memref<1xf64>
%ignored_a:2 = test.format_result_a_op memref<1xf64>
@@ -27,6 +31,10 @@ test.format_attr_dict_w_keyword attributes {attr = 10 : i64}
// CHECK: test.format_result_c_op (i64) -> memref<1xf64>
%ignored_c:2 = test.format_result_c_op (i64) -> memref<1xf64>
+//===----------------------------------------------------------------------===//
+// Format operands
+//===----------------------------------------------------------------------===//
+
// CHECK: test.format_operand_a_op %[[I64]], %[[MEMREF]] : i64, memref<1xf64>
test.format_operand_a_op %i64, %memref : i64, memref<1xf64>
@@ -42,6 +50,10 @@ test.format_operand_d_op %i64, %memref : memref<1xf64>
// CHECK: test.format_operand_e_op %[[I64]], %[[MEMREF]] : i64, memref<1xf64>
test.format_operand_e_op %i64, %memref : i64, memref<1xf64>
+//===----------------------------------------------------------------------===//
+// Format successors
+//===----------------------------------------------------------------------===//
+
"foo.successor_test_region"() ( {
^bb0:
// CHECK: test.format_successor_a_op ^bb1 {attr}
@@ -57,3 +69,28 @@ test.format_operand_e_op %i64, %memref : i64, memref<1xf64>
}) { arg_names = ["i", "j", "k"] } : () -> ()
+//===----------------------------------------------------------------------===//
+// Format optional operands and results
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_optional_operand_result_a_op(%[[I64]] : i64) : i64
+test.format_optional_operand_result_a_op(%i64 : i64) : i64
+
+// CHECK: test.format_optional_operand_result_a_op( : ) : i64
+test.format_optional_operand_result_a_op( : ) : i64
+
+// CHECK: test.format_optional_operand_result_a_op(%[[I64]] : i64) :
+// CHECK-NOT: i64
+test.format_optional_operand_result_a_op(%i64 : i64) :
+
+// CHECK: test.format_optional_operand_result_a_op(%[[I64]] : i64) : [%[[I64]], %[[I64]]]
+test.format_optional_operand_result_a_op(%i64 : i64) : [%i64, %i64]
+
+// CHECK: test.format_optional_operand_result_b_op(%[[I64]] : i64) : i64
+test.format_optional_operand_result_b_op(%i64 : i64) : i64
+
+// CHECK: test.format_optional_operand_result_b_op : i64
+test.format_optional_operand_result_b_op( : ) : i64
+
+// CHECK: test.format_optional_operand_result_b_op : i64
+test.format_optional_operand_result_b_op : i64
diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index c679bd9e9704..829445f7a148 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -16,7 +16,8 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
}
// CHECK-LABEL: OpA::verify
-// CHECK: for (Value v : getODSOperands(0)) {
+// CHECK: auto valueGroup0 = getODSOperands(0);
+// CHECK: for (Value v : valueGroup0) {
// CHECK: if (!((v.getType().isInteger(32) || v.getType().isF32())))
def OpB : NS_Op<"op_for_And_PredOpTrait", [
@@ -90,5 +91,6 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> {
}
// CHECK-LABEL: OpK::verify
-// CHECK: for (Value v : getODSOperands(0)) {
+// CHECK: auto valueGroup0 = getODSOperands(0);
+// CHECK: for (Value v : valueGroup0) {
// CHECK: if (!(((v.getType().isa<TensorType>())) && (((v.getType().cast<ShapedType>().getElementType().isF32())) || ((v.getType().cast<ShapedType>().getElementType().isSignlessInteger(32))))))
diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
index c46e5462a1cd..4049c42a6296 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -75,7 +75,7 @@ static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) {
if (numOperands == 0)
return false;
const auto &operand = op.getOperand(numOperands - 1);
- return operand.isVariadic() && operand.name == name;
+ return operand.isVariableLength() && operand.name == name;
}
// Check if `result` is a known name of a result of `op`.
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 1d2ee2f0efe9..d3d153578774 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -452,7 +452,7 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
StringRef rangeSizeCall,
StringRef getOperandCallPattern) {
const int numOperands = op.getNumOperands();
- const int numVariadicOperands = op.getNumVariadicOperands();
+ const int numVariadicOperands = op.getNumVariableLengthOperands();
const int numNormalOperands = numOperands - numVariadicOperands;
const auto *sameVariadicSize =
@@ -493,9 +493,9 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
// calculation at run-time.
llvm::SmallVector<StringRef, 4> isVariadic;
isVariadic.reserve(numOperands);
- for (int i = 0; i < numOperands; ++i) {
- isVariadic.push_back(llvm::toStringRef(op.getOperand(i).isVariadic()));
- }
+ for (int i = 0; i < numOperands; ++i)
+ isVariadic.push_back(op.getOperand(i).isVariableLength() ? "true"
+ : "false");
std::string isVariadicList = llvm::join(isVariadic, ", ");
m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
@@ -511,11 +511,15 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
if (operand.name.empty())
continue;
- if (operand.isVariadic()) {
+ if (operand.isOptional()) {
+ auto &m = opClass.newMethod("Value", operand.name);
+ m.body() << " auto operands = getODSOperands(" << i << ");\n"
+ << " return operands.empty() ? Value() : *operands.begin();";
+ } else if (operand.isVariadic()) {
auto &m = opClass.newMethod(rangeType, operand.name);
m.body() << " return getODSOperands(" << i << ");";
} else {
- auto &m = opClass.newMethod("Value ", operand.name);
+ auto &m = opClass.newMethod("Value", operand.name);
m.body() << " return *getODSOperands(" << i << ").begin();";
}
}
@@ -534,7 +538,7 @@ void OpEmitter::genNamedOperandGetters() {
void OpEmitter::genNamedResultGetters() {
const int numResults = op.getNumResults();
- const int numVariadicResults = op.getNumVariadicResults();
+ const int numVariadicResults = op.getNumVariableLengthResults();
const int numNormalResults = numResults - numVariadicResults;
// If we have more than one variadic results, we need more complicated logic
@@ -573,9 +577,9 @@ void OpEmitter::genNamedResultGetters() {
} else {
llvm::SmallVector<StringRef, 4> isVariadic;
isVariadic.reserve(numResults);
- for (int i = 0; i < numResults; ++i) {
- isVariadic.push_back(llvm::toStringRef(op.getResult(i).isVariadic()));
- }
+ for (int i = 0; i < numResults; ++i)
+ isVariadic.push_back(op.getResult(i).isVariableLength() ? "true"
+ : "false");
std::string isVariadicList = llvm::join(isVariadic, ", ");
m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
@@ -589,11 +593,15 @@ void OpEmitter::genNamedResultGetters() {
if (result.name.empty())
continue;
- if (result.isVariadic()) {
+ if (result.isOptional()) {
+ auto &m = opClass.newMethod("Value", result.name);
+ m.body() << " auto results = getODSResults(" << i << ");\n"
+ << " return results.empty() ? Value() : *results.begin();";
+ } else if (result.isVariadic()) {
auto &m = opClass.newMethod("Operation::result_range", result.name);
m.body() << " return getODSResults(" << i << ");";
} else {
- auto &m = opClass.newMethod("Value ", result.name);
+ auto &m = opClass.newMethod("Value", result.name);
m.body() << " return *getODSResults(" << i << ").begin();";
}
}
@@ -706,6 +714,8 @@ void OpEmitter::genSeparateArgParamBuilder() {
return;
case TypeParamKind::Separate:
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
+ if (op.getResult(i).isOptional())
+ body << " if (" << resultNames[i] << ")\n ";
body << " " << builderOpState << ".addTypes(" << resultNames[i]
<< ");\n";
}
@@ -713,12 +723,12 @@ void OpEmitter::genSeparateArgParamBuilder() {
case TypeParamKind::Collective:
body << " "
<< "assert(resultTypes.size() "
- << (op.getNumVariadicResults() == 0 ? "==" : ">=") << " "
- << (op.getNumResults() - op.getNumVariadicResults())
+ << (op.getNumVariableLengthResults() == 0 ? "==" : ">=") << " "
+ << (op.getNumResults() - op.getNumVariableLengthResults())
<< "u && \"mismatched number of results\");\n";
body << " " << builderOpState << ".addTypes(resultTypes);\n";
return;
- };
+ }
llvm_unreachable("unhandled TypeParamKind");
};
@@ -731,7 +741,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
// Emit separate arg build with collective type, unless there is only one
// variadic result, in which case the above would have already generated
// the same build method.
- if (!(op.getNumResults() == 1 && op.getResult(0).isVariadic()))
+ if (!(op.getNumResults() == 1 && op.getResult(0).isVariableLength()))
emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
}
}
@@ -739,7 +749,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
// If this op has a variadic result, we cannot generate this builder because
// we don't know how many results to create.
- if (op.getNumVariadicResults() != 0)
+ if (op.getNumVariableLengthResults() != 0)
return;
int numResults = op.getNumResults();
@@ -887,7 +897,7 @@ void OpEmitter::genBuilder() {
// 3. one having a stand-alone parameter for each operand and attribute,
// use the first operand or attribute's type as all result types
// to facilitate different call patterns.
- if (op.getNumVariadicResults() == 0) {
+ if (op.getNumVariableLengthResults() == 0) {
if (op.getTrait("OpTrait::SameOperandsAndResultType")) {
genUseOperandAsResultTypeSeparateParamBuilder();
genUseOperandAsResultTypeCollectiveParamBuilder();
@@ -899,11 +909,11 @@ void OpEmitter::genBuilder() {
void OpEmitter::genCollectiveParamBuilder() {
int numResults = op.getNumResults();
- int numVariadicResults = op.getNumVariadicResults();
+ int numVariadicResults = op.getNumVariableLengthResults();
int numNonVariadicResults = numResults - numVariadicResults;
int numOperands = op.getNumOperands();
- int numVariadicOperands = op.getNumVariadicOperands();
+ int numVariadicOperands = op.getNumVariableLengthOperands();
int numNonVariadicOperands = numOperands - numVariadicOperands;
// Signature
std::string params = std::string("Builder *, OperationState &") +
@@ -972,7 +982,12 @@ void OpEmitter::buildParamList(std::string ¶mList,
if (resultName.empty())
resultName = std::string(formatv("resultType{0}", i));
- paramList.append(result.isVariadic() ? ", ArrayRef<Type> " : ", Type ");
+ if (result.isOptional())
+ paramList.append(", /*optional*/Type ");
+ else if (result.isVariadic())
+ paramList.append(", ArrayRef<Type> ");
+ else
+ paramList.append(", Type ");
paramList.append(resultName);
resultTypeNames.emplace_back(std::move(resultName));
@@ -1018,7 +1033,12 @@ void OpEmitter::buildParamList(std::string ¶mList,
auto argument = op.getArg(i);
if (argument.is<tblgen::NamedTypeConstraint *>()) {
const auto &operand = op.getOperand(numOperands);
- paramList.append(operand.isVariadic() ? ", ValueRange " : ", Value ");
+ if (operand.isOptional())
+ paramList.append(", /*optional*/Value ");
+ else if (operand.isVariadic())
+ paramList.append(", ValueRange ");
+ else
+ paramList.append(", Value ");
paramList.append(getArgumentName(op, numOperands));
++numOperands;
} else {
@@ -1076,8 +1096,10 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
bool isRawValueAttr) {
// Push all operands to the result.
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
- body << " " << builderOpState << ".addOperands(" << getArgumentName(op, i)
- << ");\n";
+ std::string argName = getArgumentName(op, i);
+ if (op.getOperand(i).isOptional())
+ body << " if (" << argName << ")\n ";
+ body << " " << builderOpState << ".addOperands(" << argName << ");\n";
}
// If the operation has the operand segment size attribute, add it here.
@@ -1086,7 +1108,9 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
<< ".addAttribute(\"operand_segment_sizes\", "
"odsBuilder->getI32VectorAttr({";
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
- if (op.getOperand(i).isVariadic())
+ if (op.getOperand(i).isOptional())
+ body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
+ else if (op.getOperand(i).isVariadic())
body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
else
body << "1";
@@ -1160,7 +1184,7 @@ void OpEmitter::genCanonicalizerDecls() {
void OpEmitter::genFolderDecls() {
bool hasSingleResult =
- op.getNumResults() == 1 && op.getNumVariadicResults() == 0;
+ op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
if (def.getValueAsBit("hasFolder")) {
if (hasSingleResult) {
@@ -1434,17 +1458,33 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
body << " unsigned index = 0; (void)index;\n";
for (auto staticValue : llvm::enumerate(values)) {
- if (!staticValue.value().hasPredicate())
+ bool hasPredicate = staticValue.value().hasPredicate();
+ bool isOptional = staticValue.value().isOptional();
+ if (!hasPredicate && !isOptional)
continue;
-
- // Emit a loop to check all the dynamic values in the pack.
- body << formatv(" for (Value v : getODS{0}{1}s({2})) {{\n",
+ body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n",
// Capitalize the first letter to match the function name
valueKind.substr(0, 1).upper(), valueKind.substr(1),
staticValue.index());
- auto constraint = staticValue.value().constraint;
+ // If the constraint is optional check that the value group has at most 1
+ // value.
+ if (isOptional) {
+ body << formatv(" if (valueGroup{0}.size() > 1)\n"
+ " return emitOpError(\"{1} group starting at #\") "
+ "<< index << \" requires 0 or 1 element, but found \" << "
+ "valueGroup{0}.size();\n",
+ staticValue.index(), valueKind);
+ }
+
+ // Otherwise, if there is no predicate there is nothing left to do.
+ if (!hasPredicate)
+ continue;
+ // Emit a loop to check all the dynamic values in the pack.
+ body << " for (Value v : valueGroup" << staticValue.index() << ") {\n";
+
+ auto constraint = staticValue.value().constraint;
body << " (void)v;\n"
<< " if (!("
<< tgfmt(constraint.getConditionTemplate(),
@@ -1569,7 +1609,7 @@ void OpEmitter::genTraits() {
// Add result size trait.
int numResults = op.getNumResults();
- int numVariadicResults = op.getNumVariadicResults();
+ int numVariadicResults = op.getNumVariableLengthResults();
addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
// Add successor size trait.
@@ -1579,7 +1619,7 @@ void OpEmitter::genTraits() {
// Add variadic size trait and normal op traits.
int numOperands = op.getNumOperands();
- int numVariadicOperands = op.getNumVariadicOperands();
+ int numVariadicOperands = op.getNumVariableLengthOperands();
// Add operand size trait.
if (numVariadicOperands != 0) {
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 12bffc761583..c54ab243735a 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -395,6 +395,17 @@ const char *const variadicOperandParserCode = R"(
if (parser.parseOperandList({0}Operands))
return failure();
)";
+const char *const optionalOperandParserCode = R"(
+ {
+ OpAsmParser::OperandType operand;
+ OptionalParseResult parseResult = parser.parseOptionalOperand(operand);
+ if (parseResult.hasValue()) {
+ if (failed(*parseResult))
+ return failure();
+ {0}Operands.push_back(operand);
+ }
+ }
+)";
const char *const operandParserCode = R"(
if (parser.parseOperand({0}RawOperands[0]))
return failure();
@@ -407,6 +418,17 @@ const char *const variadicTypeParserCode = R"(
if (parser.parseTypeList({0}Types))
return failure();
)";
+const char *const optionalTypeParserCode = R"(
+ {
+ Type optionalType;
+ OptionalParseResult parseResult = parser.parseOptionalType(optionalType);
+ if (parseResult.hasValue()) {
+ if (failed(*parseResult))
+ return failure();
+ {0}Types.push_back(optionalType);
+ }
+ }
+)";
const char *const typeParserCode = R"(
if (parser.parseType({0}RawTypes[0]))
return failure();
@@ -456,18 +478,40 @@ const char *successorParserCode = R"(
return failure();
)";
+namespace {
+/// The type of length for a given parse argument.
+enum class ArgumentLengthKind {
+ /// The argument is variadic, and may contain 0->N elements.
+ Variadic,
+ /// The argument is optional, and may contain 0 or 1 elements.
+ Optional,
+ /// The argument is a single element, i.e. always represents 1 element.
+ Single
+};
+} // end anonymous namespace
+
+/// Get the length kind for the given constraint.
+static ArgumentLengthKind
+getArgumentLengthKind(const NamedTypeConstraint *var) {
+ if (var->isOptional())
+ return ArgumentLengthKind::Optional;
+ if (var->isVariadic())
+ return ArgumentLengthKind::Variadic;
+ return ArgumentLengthKind::Single;
+}
+
/// Get the name used for the type list for the given type directive operand.
-/// 'isVariadic' is set to true if the operand has variadic types.
-static StringRef getTypeListName(Element *arg, bool &isVariadic) {
+/// 'lengthKind' is set to the corresponding kind for the given argument.
+static StringRef getTypeListName(Element *arg, ArgumentLengthKind &lengthKind) {
if (auto *operand = dyn_cast<OperandVariable>(arg)) {
- isVariadic = operand->getVar()->isVariadic();
+ lengthKind = getArgumentLengthKind(operand->getVar());
return operand->getVar()->name;
}
if (auto *result = dyn_cast<ResultVariable>(arg)) {
- isVariadic = result->getVar()->isVariadic();
+ lengthKind = getArgumentLengthKind(result->getVar());
return result->getVar()->name;
}
- isVariadic = true;
+ lengthKind = ArgumentLengthKind::Variadic;
if (isa<OperandsDirective>(arg))
return "allOperand";
if (isa<ResultsDirective>(arg))
@@ -502,7 +546,7 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
genElementParserStorage(&childElement, body);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
StringRef name = operand->getVar()->name;
- if (operand->getVar()->isVariadic()) {
+ if (operand->getVar()->isVariableLength()) {
body << " SmallVector<OpAsmParser::OperandType, 4> " << name
<< "Operands;\n";
} else {
@@ -515,15 +559,15 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
" (void){0}OperandsLoc;\n",
name);
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
- bool variadic = false;
- StringRef name = getTypeListName(dir->getOperand(), variadic);
- if (variadic)
+ ArgumentLengthKind lengthKind;
+ StringRef name = getTypeListName(dir->getOperand(), lengthKind);
+ if (lengthKind != ArgumentLengthKind::Single)
body << " SmallVector<Type, 1> " << name << "Types;\n";
else
body << llvm::formatv(" Type {0}RawTypes[1];\n", name)
<< llvm::formatv(" ArrayRef<Type> {0}Types({0}RawTypes);\n", name);
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
- bool ignored = false;
+ ArgumentLengthKind ignored;
body << " ArrayRef<Type> " << getTypeListName(dir->getInputs(), ignored)
<< "Types;\n";
body << " ArrayRef<Type> " << getTypeListName(dir->getResults(), ignored)
@@ -592,9 +636,14 @@ static void genElementParser(Element *element, OpMethodBody &body,
body << formatv(attrParserCode, var->attr.getStorageType(), var->name,
attrTypeStr);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
- bool isVariadic = operand->getVar()->isVariadic();
- body << formatv(isVariadic ? variadicOperandParserCode : operandParserCode,
- operand->getVar()->name);
+ ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
+ StringRef name = operand->getVar()->name;
+ if (lengthKind == ArgumentLengthKind::Variadic)
+ body << llvm::formatv(variadicOperandParserCode, name);
+ else if (lengthKind == ArgumentLengthKind::Optional)
+ body << llvm::formatv(optionalOperandParserCode, name);
+ else
+ body << formatv(operandParserCode, name);
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
bool isVariadic = successor->getVar()->isVariadic();
body << formatv(isVariadic ? successorListParserCode : successorParserCode,
@@ -614,12 +663,16 @@ static void genElementParser(Element *element, OpMethodBody &body,
} else if (isa<SuccessorsDirective>(element)) {
body << llvm::formatv(successorListParserCode, "full");
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
- bool isVariadic = false;
- StringRef listName = getTypeListName(dir->getOperand(), isVariadic);
- body << formatv(isVariadic ? variadicTypeParserCode : typeParserCode,
- listName);
+ ArgumentLengthKind lengthKind;
+ StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+ if (lengthKind == ArgumentLengthKind::Variadic)
+ body << llvm::formatv(variadicTypeParserCode, listName);
+ else if (lengthKind == ArgumentLengthKind::Optional)
+ body << llvm::formatv(optionalTypeParserCode, listName);
+ else
+ body << formatv(typeParserCode, listName);
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
- bool ignored = false;
+ ArgumentLengthKind ignored;
body << formatv(functionalTypeParserCode,
getTypeListName(dir->getInputs(), ignored),
getTypeListName(dir->getResults(), ignored));
@@ -817,7 +870,7 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
<< "builder.getI32VectorAttr({";
auto interleaveFn = [&](const NamedTypeConstraint &operand) {
// If the operand is variadic emit the parsed size.
- if (operand.isVariadic())
+ if (operand.isVariableLength())
body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
else
body << "1";
@@ -885,6 +938,10 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
if (var->isVariadic())
return body << var->name << "().getTypes()";
+ if (var->isOptional())
+ return body << llvm::formatv(
+ "({0}() ? ArrayRef<Type>({0}().getType()) : ArrayRef<Type>())",
+ var->name);
return body << "ArrayRef<Type>(" << var->name << "().getType())";
}
@@ -900,11 +957,16 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
// Emit the check for the presence of the anchor element.
Element *anchor = optional->getAnchor();
- if (AttributeVariable *attrVar = dyn_cast<AttributeVariable>(anchor))
- body << " if (getAttr(\"" << attrVar->getVar()->name << "\")) {\n";
- else
- body << " if (!" << cast<OperandVariable>(anchor)->getVar()->name
- << "().empty()) {\n";
+ if (auto *operand = dyn_cast<OperandVariable>(anchor)) {
+ const NamedTypeConstraint *var = operand->getVar();
+ if (var->isOptional())
+ body << " if (" << var->name << "()) {\n";
+ else if (var->isVariadic())
+ body << " if (!" << var->name << "().empty()) {\n";
+ } else {
+ body << " if (getAttr(\""
+ << cast<AttributeVariable>(anchor)->getVar()->name << "\")) {\n";
+ }
// Emit each of the elements.
for (Element &childElement : optional->getElements())
@@ -945,7 +1007,12 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
else
body << " p.printAttribute(" << var->name << "Attr());\n";
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
- body << " p << " << operand->getVar()->name << "();\n";
+ if (operand->getVar()->isOptional()) {
+ body << " if (Value value = " << operand->getVar()->name << "())\n"
+ << " p << value;\n";
+ } else {
+ body << " p << " << operand->getVar()->name << "();\n";
+ }
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
const NamedSuccessor *var = successor->getVar();
if (var->isVariadic())
@@ -1521,14 +1588,12 @@ LogicalResult FormatParser::verifyOperands(
// Similarly to results, allow a custom builder for resolving the type if
// we aren't using the 'operands' directive.
Optional<StringRef> builder = operand.constraint.getBuilderCall();
- if (!builder || (hasAllOperands && operand.isVariadic())) {
+ if (!builder || (hasAllOperands && operand.isVariableLength())) {
return emitErrorAndNote(
loc,
"type of operand #" + Twine(i) + ", named '" + operand.name +
- "', is not buildable and a buildable " +
- "type cannot be inferred",
- "suggest adding a type constraint "
- "to the operation or adding a "
+ "', is not buildable and a buildable type cannot be inferred",
+ "suggest adding a type constraint to the operation or adding a "
"'type($" +
operand.name + ")' directive to the " + "custom assembly format");
}
@@ -1559,18 +1624,16 @@ LogicalResult FormatParser::verifyResults(
continue;
}
- // If the result is not variadic, allow for the case where the type has a
- // builder that we can use.
+ // If the result is not variable length, allow for the case where the type
+ // has a builder that we can use.
NamedTypeConstraint &result = op.getResult(i);
Optional<StringRef> builder = result.constraint.getBuilderCall();
- if (!builder || result.constraint.isVariadic()) {
+ if (!builder || result.isVariableLength()) {
return emitErrorAndNote(
loc,
"type of result #" + Twine(i) + ", named '" + result.name +
- "', is not buildable and a buildable " +
- "type cannot be inferred",
- "suggest adding a type constraint "
- "to the operation or adding a "
+ "', is not buildable and a buildable type cannot be inferred",
+ "suggest adding a type constraint to the operation or adding a "
"'type($" +
result.name + ")' directive to the " + "custom assembly format");
}
@@ -1842,9 +1905,9 @@ LogicalResult FormatParser::parseOptionalChildElement(
// Only optional-like(i.e. variadic) operands can be within an optional
// group.
.Case<OperandVariable>([&](OperandVariable *ele) {
- if (!ele->getVar()->isVariadic())
- return emitError(childLoc, "only variadic operands can be used within"
- " an optional group");
+ if (!ele->getVar()->isVariableLength())
+ return emitError(childLoc, "only variable length operands can be "
+ "used within an optional group");
seenVariables.insert(ele->getVar());
return success();
})
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index c61a99e53b9d..cf705ed02cd7 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -243,7 +243,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
// Handle nested DAG construct first
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
- if (operand->isVariadic()) {
+ if (operand->isVariableLength()) {
auto error = formatv("use nested DAG construct to match op {0}'s "
"variadic operand #{1} unsupported now",
op.getOperationName(), i);
@@ -296,7 +296,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
// of op definition.
Constraint constraint = matcher.getAsConstraint();
if (operand->constraint != constraint) {
- if (operand->isVariadic()) {
+ if (operand->isVariableLength()) {
auto error = formatv(
"further constrain op {0}'s variadic operand #{1} unsupported now",
op.getOperationName(), argIndex);
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index a96f2bca8a35..d68d9eeba634 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -807,11 +807,11 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
if (auto valueArg = argument.dyn_cast<NamedTypeConstraint *>()) {
- if (valueArg->isVariadic()) {
+ if (valueArg->isVariableLength()) {
if (i != e - 1) {
- PrintFatalError(loc,
- "SPIR-V ops can have Variadic<..> argument only if "
- "it's the last argument");
+ PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or "
+ "Optional<...> arguments only if "
+ "it's the last argument");
}
os << tabs
<< formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words);
@@ -829,7 +829,7 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
words, wordIndex);
os << tabs << " }\n";
os << tabs << formatv(" {0}.push_back(arg);\n", operands);
- if (!valueArg->isVariadic()) {
+ if (!valueArg->isVariableLength()) {
os << tabs << formatv(" {0}++;\n", wordIndex);
}
operandNum++;
More information about the Mlir-commits
mailing list