[Mlir-commits] [mlir] 225f11c - [mlir] Partially revert removal of old `fold` method

Markus Böck llvmlistbot at llvm.org
Wed Feb 22 16:06:53 PST 2023


Author: Markus Böck
Date: 2023-02-23T00:55:35+01:00
New Revision: 225f11cff7fb1983cd849fbd253c459a804ce525

URL: https://github.com/llvm/llvm-project/commit/225f11cff7fb1983cd849fbd253c459a804ce525
DIFF: https://github.com/llvm/llvm-project/commit/225f11cff7fb1983cd849fbd253c459a804ce525.diff

LOG: [mlir] Partially revert removal of old `fold` method

Mehdi noted in https://reviews.llvm.org/D144391 that, given the low cost of keeping the old `fold` method signature working and the difficulty of writing a `FoldAdaptor` by hand, it would be nice to keep supporting it for the sake of Ops written manually in C++.
This patch therefore partially reverts the removal of the old `fold` method by still allowing the old signature to be used. Active use of it is still discouraged, and ODS will always generate the new method using `FoldAdaptor`s.

I'd also like to note that the previous patch ought to have broken some manually defined in-tree `fold` methods, which are defined here: https://github.com/llvm/llvm-project/blob/23bcd6b86271f1c219a69183a5d90654faca64b8/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h#L245 It seems, however, that these are not covered by the regression tests...

Differential Revision: https://reviews.llvm.org/D144591

Added: 
    mlir/test/IR/test-manual-cpp-fold.mlir

Modified: 
    mlir/include/mlir/IR/OpDefinition.h
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestDialect.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index fe2bdd5386439..f7d8436dcc1c9 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1688,17 +1688,33 @@ class Op : public OpState, public Traits<ConcreteType>... {
   /// Trait to check if T provides a 'fold' method for a single result op.
   template <typename T, typename... Args>
   using has_single_result_fold_t =
-      decltype(std::declval<T>().fold(std::declval<typename T::FoldAdaptor>()));
+      decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
   template <typename T>
   constexpr static bool has_single_result_fold_v =
       llvm::is_detected<has_single_result_fold_t, T>::value;
   /// Trait to check if T provides a general 'fold' method.
   template <typename T, typename... Args>
   using has_fold_t = decltype(std::declval<T>().fold(
-      std::declval<typename T::FoldAdaptor>(),
+      std::declval<ArrayRef<Attribute>>(),
       std::declval<SmallVectorImpl<OpFoldResult> &>()));
   template <typename T>
   constexpr static bool has_fold_v = llvm::is_detected<has_fold_t, T>::value;
+  /// Trait to check if T provides a 'fold' method with a FoldAdaptor for a
+  /// single result op.
+  template <typename T, typename... Args>
+  using has_fold_adaptor_single_result_fold_t =
+      decltype(std::declval<T>().fold(std::declval<typename T::FoldAdaptor>()));
+  template <class T>
+  constexpr static bool has_fold_adaptor_single_result_v =
+      llvm::is_detected<has_fold_adaptor_single_result_fold_t, T>::value;
+  /// Trait to check if T provides a general 'fold' method with a FoldAdaptor.
+  template <typename T, typename... Args>
+  using has_fold_adaptor_fold_t = decltype(std::declval<T>().fold(
+      std::declval<typename T::FoldAdaptor>(),
+      std::declval<SmallVectorImpl<OpFoldResult> &>()));
+  template <class T>
+  constexpr static bool has_fold_adaptor_v =
+      llvm::is_detected<has_fold_adaptor_fold_t, T>::value;
 
   /// Trait to check if T provides a 'print' method.
   template <typename T, typename... Args>
@@ -1748,13 +1764,14 @@ class Op : public OpState, public Traits<ConcreteType>... {
     // If the operation is single result and defines a `fold` method.
     if constexpr (llvm::is_one_of<OpTrait::OneResult<ConcreteType>,
                                   Traits<ConcreteType>...>::value &&
-                  has_single_result_fold_v<ConcreteType>)
+                  (has_single_result_fold_v<ConcreteType> ||
+                   has_fold_adaptor_single_result_v<ConcreteType>))
       return [](Operation *op, ArrayRef<Attribute> operands,
                 SmallVectorImpl<OpFoldResult> &results) {
         return foldSingleResultHook<ConcreteType>(op, operands, results);
       };
     // The operation is not single result and defines a `fold` method.
-    if constexpr (has_fold_v<ConcreteType>)
+    if constexpr (has_fold_v<ConcreteType> || has_fold_adaptor_v<ConcreteType>)
       return [](Operation *op, ArrayRef<Attribute> operands,
                 SmallVectorImpl<OpFoldResult> &results) {
         return foldHook<ConcreteType>(op, operands, results);
@@ -1773,9 +1790,12 @@ class Op : public OpState, public Traits<ConcreteType>... {
   static LogicalResult
   foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
                        SmallVectorImpl<OpFoldResult> &results) {
-    OpFoldResult result =
-        cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
-            operands, op->getAttrDictionary(), op->getRegions()));
+    OpFoldResult result;
+    if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>)
+      result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
+          operands, op->getAttrDictionary(), op->getRegions()));
+    else
+      result = cast<ConcreteOpT>(op).fold(operands);
 
     // If the fold failed or was in-place, try to fold the traits of the
     // operation.
@@ -1792,10 +1812,15 @@ class Op : public OpState, public Traits<ConcreteType>... {
   template <typename ConcreteOpT>
   static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
                                 SmallVectorImpl<OpFoldResult> &results) {
-    LogicalResult result = cast<ConcreteOpT>(op).fold(
-        typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
-                                          op->getRegions()),
-        results);
+    auto result = LogicalResult::failure();
+    if constexpr (has_fold_adaptor_v<ConcreteOpT>) {
+      result = cast<ConcreteOpT>(op).fold(
+          typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
+                                            op->getRegions()),
+          results);
+    } else {
+      result = cast<ConcreteOpT>(op).fold(operands, results);
+    }
 
     // If the fold failed or was in-place, try to fold the traits of the
     // operation.

diff  --git a/mlir/test/IR/test-manual-cpp-fold.mlir b/mlir/test/IR/test-manual-cpp-fold.mlir
new file mode 100644
index 0000000000000..592b949f0a139
--- /dev/null
+++ b/mlir/test/IR/test-manual-cpp-fold.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
+
+func.func @test() -> i32 {
+  %c5 = "test.constant"() {value = 5 : i32} : () -> i32
+  %res = "test.manual_cpp_op_with_fold"(%c5) : (i32) -> i32
+  return %res : i32
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-NEXT: %[[C:.*]] = "test.constant"() {value = 5 : i32}
+// CHECK-NEXT: return %[[C]]

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index e62d5a81c84d0..dc5f629610d90 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -358,6 +358,7 @@ void TestDialect::initialize() {
 #define GET_OP_LIST
 #include "TestOps.cpp.inc"
       >();
+  addOperations<ManualCppOpWithFold>();
   registerDynamicOp(getDynamicGenericOp(this));
   registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
   registerDynamicOp(getDynamicCustomParserPrinterOp(this));
@@ -1634,6 +1635,14 @@ void TestReflectBoundsOp::inferResultRanges(
   setResultRanges(getResult(), range);
 }
 
+OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
+  // Just a simple fold for testing purposes that reads an operands constant
+  // value and returns it.
+  if (!attributes.empty())
+    return attributes.front();
+  return nullptr;
+}
+
 #include "TestOpEnums.cpp.inc"
 #include "TestOpInterfaces.cpp.inc"
 #include "TestTypeInterfaces.cpp.inc"

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index ceb9dc6f4c933..ad3ef2a9f1cdd 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -58,6 +58,23 @@ class RewritePatternSet;
 #include "TestOps.h.inc"
 
 namespace test {
+
+// Op deliberately defined in C++ code rather than ODS to test that C++
+// Ops can still use the old `fold` method.
+class ManualCppOpWithFold
+    : public mlir::Op<ManualCppOpWithFold, mlir::OpTrait::OneResult> {
+public:
+  using Op::Op;
+
+  static llvm::StringRef getOperationName() {
+    return "test.manual_cpp_op_with_fold";
+  }
+
+  static llvm::ArrayRef<llvm::StringRef> getAttributeNames() { return {}; }
+
+  mlir::OpFoldResult fold(llvm::ArrayRef<mlir::Attribute> attributes);
+};
+
 void registerTestDialect(::mlir::DialectRegistry &registry);
 void populateTestReductionPatterns(::mlir::RewritePatternSet &patterns);
 } // namespace test


        


More information about the Mlir-commits mailing list