[Mlir-commits] [mlir] 225f11c - [mlir] Partially revert removal of old `fold` method
Markus Böck
llvmlistbot at llvm.org
Wed Feb 22 16:06:53 PST 2023
Author: Markus Böck
Date: 2023-02-23T00:55:35+01:00
New Revision: 225f11cff7fb1983cd849fbd253c459a804ce525
URL: https://github.com/llvm/llvm-project/commit/225f11cff7fb1983cd849fbd253c459a804ce525
DIFF: https://github.com/llvm/llvm-project/commit/225f11cff7fb1983cd849fbd253c459a804ce525.diff
LOG: [mlir] Partially revert removal of old `fold` method
Mehdi noted in https://reviews.llvm.org/D144391 that given the low cost of keeping the old `fold` method signature working and the difficulty of writing a `FoldAdaptor` oneself, it'd be nice to keep the support for the sake of Ops written manually in C++.
This patch therefore partially reverts the removal of the old `fold` method by still allowing the old signature to be used. The active use of it is still discouraged and ODS will always generate the new method using `FoldAdaptor`s.
I'd also like to note that the previous patch ought to have broken some manually defined `fold` methods in-tree that are defined here: https://github.com/llvm/llvm-project/blob/23bcd6b86271f1c219a69183a5d90654faca64b8/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h#L245 It seems like these are not covered by the regression tests, however...
Differential Revision: https://reviews.llvm.org/D144591
Added:
mlir/test/IR/test-manual-cpp-fold.mlir
Modified:
mlir/include/mlir/IR/OpDefinition.h
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestDialect.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index fe2bdd5386439..f7d8436dcc1c9 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1688,17 +1688,33 @@ class Op : public OpState, public Traits<ConcreteType>... {
/// Trait to check if T provides a 'fold' method for a single result op.
template <typename T, typename... Args>
using has_single_result_fold_t =
- decltype(std::declval<T>().fold(std::declval<typename T::FoldAdaptor>()));
+ decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
template <typename T>
constexpr static bool has_single_result_fold_v =
llvm::is_detected<has_single_result_fold_t, T>::value;
/// Trait to check if T provides a general 'fold' method.
template <typename T, typename... Args>
using has_fold_t = decltype(std::declval<T>().fold(
- std::declval<typename T::FoldAdaptor>(),
+ std::declval<ArrayRef<Attribute>>(),
std::declval<SmallVectorImpl<OpFoldResult> &>()));
template <typename T>
constexpr static bool has_fold_v = llvm::is_detected<has_fold_t, T>::value;
+ /// Trait to check if T provides a single-result 'fold' method taking a
+ /// FoldAdaptor; the fold hooks prefer it over the ArrayRef<Attribute> form.
+ template <typename T, typename... Args>
+ using has_fold_adaptor_single_result_fold_t =
+ decltype(std::declval<T>().fold(std::declval<typename T::FoldAdaptor>()));
+ template <typename T>
+ constexpr static bool has_fold_adaptor_single_result_v =
+ llvm::is_detected<has_fold_adaptor_single_result_fold_t, T>::value;
+ /// Trait to check if T provides a general 'fold' method taking a FoldAdaptor.
+ template <typename T, typename... Args>
+ using has_fold_adaptor_fold_t = decltype(std::declval<T>().fold(
+ std::declval<typename T::FoldAdaptor>(),
+ std::declval<SmallVectorImpl<OpFoldResult> &>()));
+ template <typename T>
+ constexpr static bool has_fold_adaptor_v =
+ llvm::is_detected<has_fold_adaptor_fold_t, T>::value;
/// Trait to check if T provides a 'print' method.
template <typename T, typename... Args>
@@ -1748,13 +1764,14 @@ class Op : public OpState, public Traits<ConcreteType>... {
// If the operation is single result and defines a `fold` method.
if constexpr (llvm::is_one_of<OpTrait::OneResult<ConcreteType>,
Traits<ConcreteType>...>::value &&
- has_single_result_fold_v<ConcreteType>)
+ (has_single_result_fold_v<ConcreteType> ||
+ has_fold_adaptor_single_result_v<ConcreteType>))
return [](Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return foldSingleResultHook<ConcreteType>(op, operands, results);
};
// The operation is not single result and defines a `fold` method.
- if constexpr (has_fold_v<ConcreteType>)
+ if constexpr (has_fold_v<ConcreteType> || has_fold_adaptor_v<ConcreteType>)
return [](Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return foldHook<ConcreteType>(op, operands, results);
@@ -1773,9 +1790,12 @@ class Op : public OpState, public Traits<ConcreteType>... {
static LogicalResult
foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
- OpFoldResult result =
- cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
- operands, op->getAttrDictionary(), op->getRegions()));
+ OpFoldResult result;
+ if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>)
+ result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
+ operands, op->getAttrDictionary(), op->getRegions()));
+ else
+ result = cast<ConcreteOpT>(op).fold(operands);
// If the fold failed or was in-place, try to fold the traits of the
// operation.
@@ -1792,10 +1812,15 @@ class Op : public OpState, public Traits<ConcreteType>... {
template <typename ConcreteOpT>
static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
- LogicalResult result = cast<ConcreteOpT>(op).fold(
- typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
- op->getRegions()),
- results);
+ auto result = LogicalResult::failure();
+ if constexpr (has_fold_adaptor_v<ConcreteOpT>) {
+ result = cast<ConcreteOpT>(op).fold(
+ typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
+ op->getRegions()),
+ results);
+ } else {
+ result = cast<ConcreteOpT>(op).fold(operands, results);
+ }
// If the fold failed or was in-place, try to fold the traits of the
// operation.
diff --git a/mlir/test/IR/test-manual-cpp-fold.mlir b/mlir/test/IR/test-manual-cpp-fold.mlir
new file mode 100644
index 0000000000000..592b949f0a139
--- /dev/null
+++ b/mlir/test/IR/test-manual-cpp-fold.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
+
+// The hand-written C++ op below only implements the legacy fold signature.
+// Canonicalization is expected to fold it away, leaving only the constant.
+func.func @test() -> i32 {
+ %cst = "test.constant"() {value = 5 : i32} : () -> i32
+ %folded = "test.manual_cpp_op_with_fold"(%cst) : (i32) -> i32
+ return %folded : i32
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-NEXT: %[[C:.*]] = "test.constant"() {value = 5 : i32}
+// CHECK-NEXT: return %[[C]]
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index e62d5a81c84d0..dc5f629610d90 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -358,6 +358,7 @@ void TestDialect::initialize() {
#define GET_OP_LIST
#include "TestOps.cpp.inc"
>();
+ addOperations<ManualCppOpWithFold>();
registerDynamicOp(getDynamicGenericOp(this));
registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
registerDynamicOp(getDynamicCustomParserPrinterOp(this));
@@ -1634,6 +1635,14 @@ void TestReflectBoundsOp::inferResultRanges(
setResultRanges(getResult(), range);
}
+OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
+ // Test-only legacy fold: forward the first operand attribute handed in by
+ // the fold hook, or report "no fold" (null) when there is none.
+ if (attributes.empty())
+ return nullptr;
+ return attributes.front();
+}
+
#include "TestOpEnums.cpp.inc"
#include "TestOpInterfaces.cpp.inc"
#include "TestTypeInterfaces.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index ceb9dc6f4c933..ad3ef2a9f1cdd 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -58,6 +58,23 @@ class RewritePatternSet;
#include "TestOps.h.inc"
namespace test {
+
+// Op written by hand in C++ rather than generated from ODS, so that the
+// legacy `fold(ArrayRef<Attribute>)` signature stays exercised by tests.
+class ManualCppOpWithFold
+ : public mlir::Op<ManualCppOpWithFold, mlir::OpTrait::OneResult> {
+public:
+ using Op::Op;
+
+ static llvm::StringRef getOperationName() {
+ return "test.manual_cpp_op_with_fold";
+ }
+
+ static llvm::ArrayRef<llvm::StringRef> getAttributeNames() { return {}; }
+
+ mlir::OpFoldResult fold(llvm::ArrayRef<mlir::Attribute> attributes);
+};
+
void registerTestDialect(::mlir::DialectRegistry &registry);
void populateTestReductionPatterns(::mlir::RewritePatternSet &patterns);
} // namespace test
More information about the Mlir-commits
mailing list