[Mlir-commits] [mlir] bbfa7ef - [mlir] Add a new fold API using Generic Adaptors
Markus Böck
llvmlistbot at llvm.org
Wed Jan 11 05:34:18 PST 2023
Author: Markus Böck
Date: 2023-01-11T14:32:21+01:00
New Revision: bbfa7ef16dd9900b36abfa1a5f2faddb81afeb51
URL: https://github.com/llvm/llvm-project/commit/bbfa7ef16dd9900b36abfa1a5f2faddb81afeb51
DIFF: https://github.com/llvm/llvm-project/commit/bbfa7ef16dd9900b36abfa1a5f2faddb81afeb51.diff
LOG: [mlir] Add a new fold API using Generic Adaptors
This is part of the RFC for a better fold API: https://discourse.llvm.org/t/rfc-a-better-fold-api-using-more-generic-adaptors/67374
This patch implements the required foldHook changes and the TableGen machinery for generating `fold` method signatures that take the op's `FoldAdaptor`, selected by the dialect's `useFoldAPI` field. `useFoldAPI` takes one of two values, exposed as named constants that form a quasi-enum: `kEmitRawAttributesFolder` (the default, which emits the legacy `ArrayRef<Attribute>` signature) and `kEmitFoldAdaptorFolder`. The new `fold` method is generated only when `kEmitFoldAdaptorFolder` is used.
Since the new `FoldAdaptor` approach is strictly better than the old signature, this patch also updates the documentation and the Toy Chapter 7 example to encourage use of the new `fold` signature.
Also included are tests that exercise the new API, checking that the `FoldAdaptor` is constructed correctly and that TableGen generates the expected `fold` declarations.
Differential Revision: https://reviews.llvm.org/D140886
Added:
mlir/test/IR/test-fold-adaptor.mlir
mlir/test/mlir-tblgen/has-fold-invalid-values.td
Modified:
mlir/docs/Canonicalization.md
mlir/docs/DefiningDialects/_index.md
mlir/docs/Tutorials/Toy/Ch-7.md
mlir/examples/toy/Ch7/include/toy/Ops.td
mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
mlir/include/mlir/IR/DialectBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/TableGen/Dialect.h
mlir/include/mlir/TableGen/Operator.h
mlir/lib/TableGen/Dialect.cpp
mlir/lib/TableGen/Operator.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-decl-and-defs.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Canonicalization.md b/mlir/docs/Canonicalization.md
index d1aed547eff9d..d1cba572af212 100644
--- a/mlir/docs/Canonicalization.md
+++ b/mlir/docs/Canonicalization.md
@@ -156,7 +156,7 @@ If the operation has a single result the following will be generated:
/// of the operation. The caller will remove the operation and use that
/// result instead.
///
-OpFoldResult MyOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult MyOp::fold(FoldAdaptor adaptor) {
...
}
```
@@ -178,19 +178,19 @@ Otherwise, the following is generated:
/// the operation and use those results instead.
///
/// Note that this mechanism cannot be used to remove 0-result operations.
-LogicalResult MyOp::fold(ArrayRef<Attribute> operands,
+LogicalResult MyOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
...
}
```
-In the above, for each method an `ArrayRef<Attribute>` is provided that
-corresponds to the constant attribute value of each of the operands. These
+In the above, for each method a `FoldAdaptor` is provided with getters for
+each of the operands, returning the corresponding constant attribute. These
operands are those that implement the `ConstantLike` trait. If any of the
operands are non-constant, a null `Attribute` value is provided instead. For
example, if MyOp provides three operands [`a`, `b`, `c`], but only `b` is
-constant then `operands` will be of the form [Attribute(), b-value,
-Attribute()].
+constant then `adaptor` will return Attribute() for `getA()` and `getC()`,
+and b-value for `getB()`.
Also above, is the use of `OpFoldResult`. This class represents the possible
result of folding an operation result: either an SSA `Value`, or an
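As a rough, MLIR-free illustration of the `FoldAdaptor` semantics described in the Canonicalization.md change, the sketch below models an adaptor for a hypothetical op `MyOp` with operands `a`, `b` and `c`. It assumes `std::optional<int64_t>` as a stand-in for `Attribute`, with `std::nullopt` playing the role of the null `Attribute()` handed over for non-constant operands. `MyOpFoldAdaptor` and `foldMyOp` are illustrative names, not generated code.

```cpp
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Stand-in for mlir::Attribute: std::nullopt models a null Attribute(),
// i.e. an operand that is not produced by a ConstantLike op.
using FakeAttr = std::optional<int64_t>;

// Hypothetical adaptor for an op with operands [a, b, c]. Like a generated
// FoldAdaptor, it exposes named getters instead of positional indexing.
class MyOpFoldAdaptor {
public:
  explicit MyOpFoldAdaptor(std::vector<FakeAttr> operands)
      : operands(std::move(operands)) {}
  FakeAttr getA() const { return operands[0]; }
  FakeAttr getB() const { return operands[1]; }
  FakeAttr getC() const { return operands[2]; }

private:
  std::vector<FakeAttr> operands;
};

// Hypothetical folder: folds a + b + c only if every operand is constant.
// Returning std::nullopt models a failed fold (a null OpFoldResult).
FakeAttr foldMyOp(const MyOpFoldAdaptor &adaptor) {
  FakeAttr a = adaptor.getA(), b = adaptor.getB(), c = adaptor.getC();
  if (!a || !b || !c)
    return std::nullopt;
  return *a + *b + *c;
}
```

With only `b` constant, `MyOpFoldAdaptor({std::nullopt, 7, std::nullopt})` yields null from `getA()` and `getC()` and `7` from `getB()`, so `foldMyOp` declines to fold, whereas `{1, 2, 3}` folds to `6`.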
diff --git a/mlir/docs/DefiningDialects/_index.md b/mlir/docs/DefiningDialects/_index.md
index 0f2b9455d7723..ca07c33dc7e94 100644
--- a/mlir/docs/DefiningDialects/_index.md
+++ b/mlir/docs/DefiningDialects/_index.md
@@ -255,6 +255,31 @@ LogicalResult MyDialect::verifyRegionResultAttribute(Operation *op, unsigned reg
unsigned argIndex, NamedAttribute attribute);
```
+#### `useFoldAPI`
+
+There are currently two possible values that are allowed to be assigned to this
+field:
+* `kEmitFoldAdaptorFolder` generates a `fold` method making use of the op's
+ `FoldAdaptor` to allow access to operands via convenient getters.
+
+ Generated code example:
+ ```cpp
+ OpFoldResult fold(FoldAdaptor adaptor);
+ // or
+ LogicalResult fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult>& results);
+ ```
+* `kEmitRawAttributesFolder` generates the deprecated legacy `fold`
+ method, containing `ArrayRef<Attribute>` in the parameter list instead of
+ the op's `FoldAdaptor`. This API is scheduled for removal and should not be
+ used by new dialects.
+
+ Generated code example:
+ ```cpp
+ OpFoldResult fold(ArrayRef<Attribute> operands);
+ // or
+ LogicalResult fold(ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult>& results);
+ ```
+
### Operation Interface Fallback
Some dialects have an open ecosystem and don't register all of the possible operations. In such
diff --git a/mlir/docs/Tutorials/Toy/Ch-7.md b/mlir/docs/Tutorials/Toy/Ch-7.md
index 6148b92dadd85..2114bf6e039fb 100644
--- a/mlir/docs/Tutorials/Toy/Ch-7.md
+++ b/mlir/docs/Tutorials/Toy/Ch-7.md
@@ -458,16 +458,16 @@ method.
```c++
/// Fold constants.
-OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }
+OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return value(); }
/// Fold struct constants.
-OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) {
return value();
}
/// Fold simple struct access operations that access into a constant.
-OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
- auto structAttr = operands.front().dyn_cast_or_null<mlir::ArrayAttr>();
+OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
+ auto structAttr = adaptor.getInput().dyn_cast_or_null<mlir::ArrayAttr>();
if (!structAttr)
return nullptr;
diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td
index 08671a7347c19..504d316f2f937 100644
--- a/mlir/examples/toy/Ch7/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch7/include/toy/Ops.td
@@ -33,6 +33,8 @@ def Toy_Dialect : Dialect {
// We set this bit to generate the declarations for the dialect's type parsing
// and printing hooks.
let useDefaultTypePrinterParser = 1;
+
+ let useFoldAPI = kEmitFoldAdaptorFolder;
}
// Base class for toy dialect operations. This operation inherits from the base
diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
index 36ba04bdc5c41..62b00d99476a0 100644
--- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
@@ -24,18 +24,14 @@ namespace {
} // namespace
/// Fold constants.
-OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
- return getValue();
-}
+OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
/// Fold struct constants.
-OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) {
- return getValue();
-}
+OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
/// Fold simple struct access operations that access into a constant.
-OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
- auto structAttr = operands.front().dyn_cast_or_null<mlir::ArrayAttr>();
+OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
+ auto structAttr = adaptor.getInput().dyn_cast_or_null<mlir::ArrayAttr>();
if (!structAttr)
return nullptr;
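The `StructAccessOp` folder above follows a common pattern: read the operand's constant value through the adaptor, bail out if it is null, otherwise index into it. The standalone sketch below restates that pattern under simplifying assumptions: `std::optional<std::vector<int64_t>>` stands in for the result of `adaptor.getInput().dyn_cast_or_null<mlir::ArrayAttr>()`, `index` plays the role of the op's element index, and `foldStructAccess` is a hypothetical name. Unlike the excerpt shown, this sketch also rejects out-of-range indices instead of assuming they are valid.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Models the constant struct seen through the adaptor: std::nullopt
// corresponds to a null (non-constant or non-ArrayAttr) input.
using FakeStructAttr = std::optional<std::vector<int64_t>>;

// Hypothetical analogue of StructAccessOp::fold(FoldAdaptor).
std::optional<int64_t> foldStructAccess(const FakeStructAttr &input,
                                        std::size_t index) {
  // Mirrors `if (!structAttr) return nullptr;`: nothing to fold.
  // The bounds check is an extra safety guard added in this sketch.
  if (!input || index >= input->size())
    return std::nullopt;
  // Otherwise the access folds to the selected element.
  return (*input)[index];
}
```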
diff --git a/mlir/include/mlir/IR/DialectBase.td b/mlir/include/mlir/IR/DialectBase.td
index ab14e3405e5df..043019e60e9b5 100644
--- a/mlir/include/mlir/IR/DialectBase.td
+++ b/mlir/include/mlir/IR/DialectBase.td
@@ -17,6 +17,14 @@
// Dialect definitions
//===----------------------------------------------------------------------===//
+// Generate 'fold' method with 'ArrayRef<Attribute>' parameter.
+// New code should prefer using 'kEmitFoldAdaptorFolder' and
+// consider 'kEmitRawAttributesFolder' deprecated and to be
+// removed in the future.
+defvar kEmitRawAttributesFolder = 0;
+// Generate 'fold' method with 'FoldAdaptor' parameter.
+defvar kEmitFoldAdaptorFolder = 1;
+
class Dialect {
// The name of the dialect.
string name = ?;
@@ -85,6 +93,9 @@ class Dialect {
// If this dialect can be extended at runtime with new operations or types.
bit isExtensible = 0;
+
+ // Fold API to use for operations in this dialect.
+ int useFoldAPI = kEmitRawAttributesFolder;
}
#endif // DIALECTBASE_TD
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index c185750e5bcee..96b6e174f5081 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1686,18 +1686,35 @@ class Op : public OpState, public Traits<ConcreteType>... {
private:
/// Trait to check if T provides a 'fold' method for a single result op.
template <typename T, typename... Args>
- using has_single_result_fold =
+ using has_single_result_fold_t =
decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
template <typename T>
- using detect_has_single_result_fold =
- llvm::is_detected<has_single_result_fold, T>;
+ constexpr static bool has_single_result_fold_v =
+ llvm::is_detected<has_single_result_fold_t, T>::value;
/// Trait to check if T provides a general 'fold' method.
template <typename T, typename... Args>
- using has_fold = decltype(std::declval<T>().fold(
+ using has_fold_t = decltype(std::declval<T>().fold(
std::declval<ArrayRef<Attribute>>(),
std::declval<SmallVectorImpl<OpFoldResult> &>()));
template <typename T>
- using detect_has_fold = llvm::is_detected<has_fold, T>;
+ constexpr static bool has_fold_v = llvm::is_detected<has_fold_t, T>::value;
+ /// Trait to check if T provides a 'fold' method with a FoldAdaptor for a
+ /// single result op.
+ template <typename T, typename... Args>
+ using has_fold_adaptor_single_result_fold_t =
+ decltype(std::declval<T>().fold(std::declval<typename T::FoldAdaptor>()));
+ template <class T>
+ constexpr static bool has_fold_adaptor_single_result_v =
+ llvm::is_detected<has_fold_adaptor_single_result_fold_t, T>::value;
+ /// Trait to check if T provides a general 'fold' method with a FoldAdaptor.
+ template <typename T, typename... Args>
+ using has_fold_adaptor_fold_t = decltype(std::declval<T>().fold(
+ std::declval<typename T::FoldAdaptor>(),
+ std::declval<SmallVectorImpl<OpFoldResult> &>()));
+ template <class T>
+ constexpr static bool has_fold_adaptor_v =
+ llvm::is_detected<has_fold_adaptor_fold_t, T>::value;
+
/// Trait to check if T provides a 'print' method.
template <typename T, typename... Args>
using has_print =
@@ -1746,13 +1763,14 @@ class Op : public OpState, public Traits<ConcreteType>... {
// If the operation is single result and defines a `fold` method.
if constexpr (llvm::is_one_of<OpTrait::OneResult<ConcreteType>,
Traits<ConcreteType>...>::value &&
- detect_has_single_result_fold<ConcreteType>::value)
+ (has_single_result_fold_v<ConcreteType> ||
+ has_fold_adaptor_single_result_v<ConcreteType>))
return [](Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return foldSingleResultHook<ConcreteType>(op, operands, results);
};
// The operation is not single result and defines a `fold` method.
- if constexpr (detect_has_fold<ConcreteType>::value)
+ if constexpr (has_fold_v<ConcreteType> || has_fold_adaptor_v<ConcreteType>)
return [](Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return foldHook<ConcreteType>(op, operands, results);
@@ -1771,7 +1789,12 @@ class Op : public OpState, public Traits<ConcreteType>... {
static LogicalResult
foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
- OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);
+ OpFoldResult result;
+ if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>)
+ result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
+ operands, op->getAttrDictionary(), op->getRegions()));
+ else
+ result = cast<ConcreteOpT>(op).fold(operands);
// If the fold failed or was in-place, try to fold the traits of the
// operation.
@@ -1788,7 +1811,15 @@ class Op : public OpState, public Traits<ConcreteType>... {
template <typename ConcreteOpT>
static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
- LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);
+ auto result = LogicalResult::failure();
+ if constexpr (has_fold_adaptor_v<ConcreteOpT>) {
+ result = cast<ConcreteOpT>(op).fold(
+ typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
+ op->getRegions()),
+ results);
+ } else {
+ result = cast<ConcreteOpT>(op).fold(operands, results);
+ }
// If the fold failed or was in-place, try to fold the traits of the
// operation.
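The `OpDefinition.h` change relies on the detection idiom (`llvm::is_detected`) to find out at compile time which `fold` overload an op defines, and then dispatches with `if constexpr`. The sketch below reproduces that mechanism with the standard library only. It assumes a minimal `is_detected` built on `std::void_t` (LLVM ships its own, not reproduced here), placeholder `Attribute` and `FoldAdaptor` types, and hypothetical ops `LegacyOp` and `AdaptorOp`. `foldHookSketch` mirrors the shape of `foldSingleResultHook` but is not that function.

```cpp
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// Minimal detection idiom, analogous in spirit to llvm::is_detected.
template <typename AlwaysVoid, template <typename...> class Op,
          typename... Args>
struct detector : std::false_type {};
template <template <typename...> class Op, typename... Args>
struct detector<std::void_t<Op<Args...>>, Op, Args...> : std::true_type {};
template <template <typename...> class Op, typename... Args>
using is_detected = detector<void, Op, Args...>;

struct Attribute {};                        // placeholder for mlir::Attribute
using RawOperands = std::vector<Attribute>; // placeholder for ArrayRef<Attribute>

// Hypothetical op using the legacy signature.
struct LegacyOp {
  std::string fold(RawOperands) { return "legacy"; }
};
// Hypothetical op using the adaptor signature.
struct AdaptorOp {
  struct FoldAdaptor {
    explicit FoldAdaptor(RawOperands ops) : operands(std::move(ops)) {}
    RawOperands operands;
  };
  std::string fold(FoldAdaptor) { return "adaptor"; }
};

// Is `T::fold(RawOperands)` / `T::fold(T::FoldAdaptor)` well-formed?
template <typename T>
using has_raw_fold_t = decltype(std::declval<T>().fold(std::declval<RawOperands>()));
template <typename T>
using has_adaptor_fold_t =
    decltype(std::declval<T>().fold(std::declval<typename T::FoldAdaptor>()));
template <typename T>
constexpr bool has_raw_fold_v = is_detected<has_raw_fold_t, T>::value;
template <typename T>
constexpr bool has_adaptor_fold_v = is_detected<has_adaptor_fold_t, T>::value;

static_assert(has_raw_fold_v<LegacyOp> && !has_adaptor_fold_v<LegacyOp>);
static_assert(has_adaptor_fold_v<AdaptorOp> && !has_raw_fold_v<AdaptorOp>);

// Prefer the FoldAdaptor overload; otherwise pass the raw operands through.
// The discarded branch is never instantiated for the other kind of op.
template <typename ConcreteOpT>
std::string foldHookSketch(ConcreteOpT op, RawOperands operands) {
  if constexpr (has_adaptor_fold_v<ConcreteOpT>)
    return op.fold(typename ConcreteOpT::FoldAdaptor(std::move(operands)));
  else
    return op.fold(std::move(operands));
}
```

Because the hook itself still receives raw operands, the adaptor is built inside the hook only for ops that opt in, which is why existing ops keep compiling unchanged.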
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index d85342c742e2e..8fd519fa63681 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -86,6 +86,15 @@ class Dialect {
/// operations or types.
bool isExtensible() const;
+ enum class FolderAPI {
+ RawAttributes = 0, /// fold method with ArrayRef<Attribute>.
+ FolderAdaptor = 1, /// fold method with the operation's FoldAdaptor.
+ };
+
+ /// Returns the folder API that should be emitted for operations in this
+ /// dialect.
+ FolderAPI getFolderAPI() const;
+
// Returns whether two dialects are equal by checking the equality of the
// underlying record.
bool operator==(const Dialect &other) const;
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index 99c3eed731b17..f4a475d60700b 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -314,6 +314,8 @@ class Operator {
/// Returns the remove name for the accessor of `name`.
std::string getRemoverName(StringRef name) const;
+ bool hasFolder() const;
+
private:
/// Populates the vectors containing operands, attributes, results and traits.
void populateOpStructure();
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 2bbca963ae4d5..8d6b047f14ffd 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -102,6 +102,16 @@ bool Dialect::isExtensible() const {
return def->getValueAsBit("isExtensible");
}
+Dialect::FolderAPI Dialect::getFolderAPI() const {
+ int64_t value = def->getValueAsInt("useFoldAPI");
+ if (value < static_cast<int64_t>(FolderAPI::RawAttributes) ||
+ value > static_cast<int64_t>(FolderAPI::FolderAdaptor))
+ llvm::PrintFatalError(def->getLoc(),
+ "Invalid value for dialect field `useFoldAPI`");
+
+ return static_cast<FolderAPI>(value);
+}
+
bool Dialect::operator==(const Dialect &other) const {
return def == other.def;
}
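The range check in `Dialect::getFolderAPI()` can be modelled in plain C++ as follows. This is a simplified stand-in: it takes the raw `useFoldAPI` integer directly instead of reading a TableGen record, and it reports invalid values through `std::optional` rather than calling `llvm::PrintFatalError`, so the behaviour exercised by `has-fold-invalid-values.td` (the value `3` is rejected) can be checked without aborting. `parseFolderAPI` is a hypothetical name.

```cpp
#include <cstdint>
#include <optional>

// Same enumerators and values as Dialect::FolderAPI in the patch.
enum class FolderAPI {
  RawAttributes = 0, // kEmitRawAttributesFolder
  FolderAdaptor = 1, // kEmitFoldAdaptorFolder
};

// Hypothetical helper mirroring the bounds check in Dialect::getFolderAPI().
// Returns std::nullopt where the real code emits a fatal TableGen error.
std::optional<FolderAPI> parseFolderAPI(int64_t value) {
  if (value < static_cast<int64_t>(FolderAPI::RawAttributes) ||
      value > static_cast<int64_t>(FolderAPI::FolderAdaptor))
    return std::nullopt;
  return static_cast<FolderAPI>(value);
}
```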
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 44177052aa61c..150c385fd2d00 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -745,3 +745,5 @@ std::string Operator::getSetterName(StringRef name) const {
std::string Operator::getRemoverName(StringRef name) const {
return "remove" + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
}
+
+bool Operator::hasFolder() const { return def.getValueAsBit("hasFolder"); }
diff --git a/mlir/test/IR/test-fold-adaptor.mlir b/mlir/test/IR/test-fold-adaptor.mlir
new file mode 100644
index 0000000000000..7815e729f55ad
--- /dev/null
+++ b/mlir/test/IR/test-fold-adaptor.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
+
+func.func @test() -> i32 {
+ %c5 = "test.constant"() {value = 5 : i32} : () -> i32
+ %c1 = "test.constant"() {value = 1 : i32} : () -> i32
+ %c2 = "test.constant"() {value = 2 : i32} : () -> i32
+ %c3 = "test.constant"() {value = 3 : i32} : () -> i32
+ %res = test.fold_with_fold_adaptor %c5, [ %c1, %c2], { (%c3), (%c3) } {
+ %c0 = "test.constant"() {value = 0 : i32} : () -> i32
+ }
+ return %res : i32
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-NEXT: %[[C:.*]] = "test.constant"() {value = 33 : i32}
+// CHECK-NEXT: return %[[C]]
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index e710f03bc251d..0f48bfcf6b229 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -33,6 +33,8 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
+#include <numeric>
+
// Include this before the using namespace lines below to
// test that we don't have namespace dependencies.
#include "TestOpsDialect.cpp.inc"
@@ -1126,6 +1128,25 @@ OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
return getOperand();
}
+OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
+ int64_t sum = 0;
+ if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
+ sum += value.getValue().getSExtValue();
+
+ for (Attribute attr : adaptor.getVariadic())
+ if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
+ sum += 2 * value.getValue().getSExtValue();
+
+ for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar())
+ for (Attribute attr : attrs)
+ if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
+ sum += 3 * value.getValue().getSExtValue();
+
+ sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end());
+
+ return IntegerAttr::get(getType(), sum);
+}
+
LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
MLIRContext *, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
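Working `TestOpFoldWithFoldAdaptor::fold` through by hand for `test-fold-adaptor.mlir`: `$op` is `%c5`, `$variadic` is `[%c1, %c2]`, `$var_of_var` has two groups each holding `%c3`, and `$body` has one block, so the result is 5 + 2*(1 + 2) + 3*(3 + 3) + 4*1 = 33, matching the `CHECK-NEXT` line. The standalone sketch below restates the same weighting without MLIR; `std::optional<int64_t>` stands in for a possibly-null `IntegerAttr`, a plain block count replaces iterating over the region, and `weightedFoldSum` is a hypothetical name.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using MaybeInt = std::optional<int64_t>; // nullopt models a non-constant operand

// Hypothetical restatement of the weights used by the test folder:
// 1x for $op, 2x for each $variadic operand, 3x for each operand in every
// $var_of_var group, and 4x per block in $body. Null operands are skipped.
int64_t weightedFoldSum(MaybeInt op, const std::vector<MaybeInt> &variadic,
                        const std::vector<std::vector<MaybeInt>> &varOfVar,
                        std::size_t numBodyBlocks) {
  int64_t sum = 0;
  if (op)
    sum += *op;
  for (MaybeInt v : variadic)
    if (v)
      sum += 2 * *v;
  for (const auto &group : varOfVar)
    for (MaybeInt v : group)
      if (v)
        sum += 3 * *v;
  sum += 4 * static_cast<int64_t>(numBodyBlocks);
  return sum;
}
```

Using distinct weights per operand group makes the test sensitive to the adaptor mapping operands to the wrong segment: a mix-up would change the folded constant.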
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index d027db4863a21..7d661368c0e82 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1297,6 +1297,31 @@ def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
}];
}
+def TestOpFoldWithFoldAdaptor
+ : TEST_Op<"fold_with_fold_adaptor",
+ [AttrSizedOperandSegments, NoTerminator]> {
+ let arguments = (ins
+ I32:$op,
+ DenseI32ArrayAttr:$attr,
+ Variadic<I32>:$variadic,
+ VariadicOfVariadic<I32, "attr">:$var_of_var
+ );
+
+ let results = (outs I32:$res);
+
+ let regions = (region AnyRegion:$body);
+
+ let assemblyFormat = [{
+ $op `,` `[` $variadic `]` `,` `{` $var_of_var `}` $body attr-dict-with-keyword
+ }];
+
+ let hasFolder = 0;
+
+ let extraClassDeclaration = [{
+ ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
+ }];
+}
+
// An op that always fold itself.
def TestPassthroughFold : TEST_Op<"passthrough_fold"> {
let arguments = (ins AnyType:$op);
diff --git a/mlir/test/mlir-tblgen/has-fold-invalid-values.td b/mlir/test/mlir-tblgen/has-fold-invalid-values.td
new file mode 100644
index 0000000000000..09149a534ab60
--- /dev/null
+++ b/mlir/test/mlir-tblgen/has-fold-invalid-values.td
@@ -0,0 +1,15 @@
+// RUN: not mlir-tblgen -gen-op-decls -I %S/../../include %s 2>&1 | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+ let name = "test";
+ let cppNamespace = "NS";
+ let useFoldAPI = 3;
+}
+
+def InvalidValue_Op : Op<Test_Dialect, "invalid_op"> {
+ let hasFolder = 1;
+}
+
+// CHECK: Invalid value for dialect field `useFoldAPI`
diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index 3a5af1268fcfe..6f9d24dc07c30 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -317,6 +317,29 @@ def NS_LOp : NS_Op<"op_with_same_operands_and_result_types_unwrapped_attr", [Sam
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
+def TestWithNewFold_Dialect : Dialect {
+ let name = "test";
+ let cppNamespace = "::mlir::testWithFold";
+ let useFoldAPI = kEmitFoldAdaptorFolder;
+}
+
+def NS_MOp : Op<TestWithNewFold_Dialect, "op_with_single_result_and_fold_adaptor_fold", []> {
+ let results = (outs AnyType:$res);
+
+ let hasFolder = 1;
+}
+
+// CHECK-LABEL: class MOp :
+// CHECK: ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
+
+def NS_NOp : Op<TestWithNewFold_Dialect, "op_with_multiple_results_and_fold_adaptor_fold", []> {
+ let results = (outs AnyType:$res1, AnyType:$res2);
+
+ let hasFolder = 1;
+}
+
+// CHECK-LABEL: class NOp :
+// CHECK: ::mlir::LogicalResult fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results);
// Test that type defs have the proper namespaces when used as a constraint.
// ---
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 3b45bb5b1cb96..2483378f691bb 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2326,25 +2326,29 @@ void OpEmitter::genCanonicalizerDecls() {
}
void OpEmitter::genFolderDecls() {
+ if (!op.hasFolder())
+ return;
+
+ Dialect::FolderAPI folderApi = op.getDialect().getFolderAPI();
+ SmallVector<MethodParameter> paramList;
+ if (folderApi == Dialect::FolderAPI::RawAttributes)
+ paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
+ else
+ paramList.emplace_back("FoldAdaptor", "adaptor");
+
+ StringRef retType;
bool hasSingleResult =
op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
-
- if (def.getValueAsBit("hasFolder")) {
- if (hasSingleResult) {
- auto *m = opClass.declareMethod(
- "::mlir::OpFoldResult", "fold",
- MethodParameter("::llvm::ArrayRef<::mlir::Attribute>", "operands"));
- ERROR_IF_PRUNED(m, "operands", op);
- } else {
- SmallVector<MethodParameter> paramList;
- paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
- paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
- "results");
- auto *m = opClass.declareMethod("::mlir::LogicalResult", "fold",
- std::move(paramList));
- ERROR_IF_PRUNED(m, "fold", op);
- }
+ if (hasSingleResult) {
+ retType = "::mlir::OpFoldResult";
+ } else {
+ paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
+ "results");
+ retType = "::mlir::LogicalResult";
}
+
+ auto *m = opClass.declareMethod(retType, "fold", std::move(paramList));
+ ERROR_IF_PRUNED(m, "fold", op);
}
void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
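The restructured `genFolderDecls` picks the first parameter from the dialect's folder API and the return type (plus the extra `results` parameter) from the op's result count. The sketch below restates that decision as a pure function returning the declaration text, so it can be compared with the `CHECK` lines in `op-decl-and-defs.td`. `foldDeclFor` is a hypothetical helper; the real emitter calls `opClass.declareMethod` rather than concatenating strings.

```cpp
#include <string>

enum class FolderAPI { RawAttributes = 0, FolderAdaptor = 1 };

// Hypothetical pure-function restatement of OpEmitter::genFolderDecls():
// returns the `fold` declaration that would be emitted for an op.
std::string foldDeclFor(FolderAPI api, unsigned numResults,
                        unsigned numVariableLengthResults) {
  // The first parameter depends only on the dialect's useFoldAPI setting.
  std::string params = api == FolderAPI::RawAttributes
                           ? "::llvm::ArrayRef<::mlir::Attribute> operands"
                           : "FoldAdaptor adaptor";
  // Single-result form only for exactly one, fixed-length result.
  bool hasSingleResult = numResults == 1 && numVariableLengthResults == 0;
  std::string retType;
  if (hasSingleResult) {
    retType = "::mlir::OpFoldResult";
  } else {
    params += ", ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results";
    retType = "::mlir::LogicalResult";
  }
  return retType + " fold(" + params + ");";
}
```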