[Mlir-commits] [mlir] 475bbea - [mlir] Completely remove old `fold` API

Markus Böck llvmlistbot at llvm.org
Wed Feb 22 12:20:26 PST 2023


Author: Markus Böck
Date: 2023-02-22T21:20:09+01:00
New Revision: 475bbea5be04146bb9cdd3e0f3c7bc328c744d50

URL: https://github.com/llvm/llvm-project/commit/475bbea5be04146bb9cdd3e0f3c7bc328c744d50
DIFF: https://github.com/llvm/llvm-project/commit/475bbea5be04146bb9cdd3e0f3c7bc328c744d50.diff

LOG: [mlir] Completely remove old `fold` API

Last part of https://discourse.llvm.org/t/rfc-a-better-fold-api-using-more-generic-adaptors/67374

All active users that I am aware of have already switched. Any remaining users will be forced to adapt their code after this patch has landed.

Differential Revision: https://reviews.llvm.org/D144391

Added: 
    

Modified: 
    mlir/include/mlir/IR/DialectBase.td
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/TableGen/Dialect.h
    mlir/lib/TableGen/Dialect.cpp
    mlir/test/mlir-tblgen/op-decl-and-defs.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    mlir/test/mlir-tblgen/has-fold-invalid-values.td


################################################################################
diff  --git a/mlir/include/mlir/IR/DialectBase.td b/mlir/include/mlir/IR/DialectBase.td
index 4b8e6c6e3b461..3d665da7f04f9 100644
--- a/mlir/include/mlir/IR/DialectBase.td
+++ b/mlir/include/mlir/IR/DialectBase.td
@@ -35,19 +35,6 @@ class CppDeprecated<string reason> {
 // Dialect definitions
 //===----------------------------------------------------------------------===//
 
-class EmitFolderBase;
-// Generate 'fold' method with 'ArrayRef<Attribute>' parameter.
-// New code should prefer using 'kEmitFoldAdaptorFolder' and
-// consider 'kEmitRawAttributesFolder' deprecated and to be
-// removed in the future.
-def kEmitRawAttributesFolder : EmitFolderBase, Deprecated<
-  "'useFoldAPI' of 'kEmitRawAttributesFolder' (default) has been deprecated "
-  # "and is pending removal. Please switch to 'kEmitFoldAdaptorFolder'. See "
-  # "https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618"
-> {}
-// Generate 'fold' method with 'FoldAdaptor' parameter.
-def kEmitFoldAdaptorFolder : EmitFolderBase {}
-
 class Dialect {
   // The name of the dialect.
   string name = ?;
@@ -116,9 +103,6 @@ class Dialect {
 
   // If this dialect can be extended at runtime with new operations or types.
   bit isExtensible = 0;
-
-  // Fold API to use for operations in this dialect.
-  EmitFolderBase useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 #endif // DIALECTBASE_TD

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index f7d8436dcc1c9..fe2bdd5386439 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1688,33 +1688,17 @@ class Op : public OpState, public Traits<ConcreteType>... {
   /// Trait to check if T provides a 'fold' method for a single result op.
   template <typename T, typename... Args>
   using has_single_result_fold_t =
-      decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
+      decltype(std::declval<T>().fold(std::declval<typename T::FoldAdaptor>()));
   template <typename T>
   constexpr static bool has_single_result_fold_v =
       llvm::is_detected<has_single_result_fold_t, T>::value;
   /// Trait to check if T provides a general 'fold' method.
   template <typename T, typename... Args>
   using has_fold_t = decltype(std::declval<T>().fold(
-      std::declval<ArrayRef<Attribute>>(),
+      std::declval<typename T::FoldAdaptor>(),
       std::declval<SmallVectorImpl<OpFoldResult> &>()));
   template <typename T>
   constexpr static bool has_fold_v = llvm::is_detected<has_fold_t, T>::value;
-  /// Trait to check if T provides a 'fold' method with a FoldAdaptor for a
-  /// single result op.
-  template <typename T, typename... Args>
-  using has_fold_adaptor_single_result_fold_t =
-      decltype(std::declval<T>().fold(std::declval<typename T::FoldAdaptor>()));
-  template <class T>
-  constexpr static bool has_fold_adaptor_single_result_v =
-      llvm::is_detected<has_fold_adaptor_single_result_fold_t, T>::value;
-  /// Trait to check if T provides a general 'fold' method with a FoldAdaptor.
-  template <typename T, typename... Args>
-  using has_fold_adaptor_fold_t = decltype(std::declval<T>().fold(
-      std::declval<typename T::FoldAdaptor>(),
-      std::declval<SmallVectorImpl<OpFoldResult> &>()));
-  template <class T>
-  constexpr static bool has_fold_adaptor_v =
-      llvm::is_detected<has_fold_adaptor_fold_t, T>::value;
 
   /// Trait to check if T provides a 'print' method.
   template <typename T, typename... Args>
@@ -1764,14 +1748,13 @@ class Op : public OpState, public Traits<ConcreteType>... {
     // If the operation is single result and defines a `fold` method.
     if constexpr (llvm::is_one_of<OpTrait::OneResult<ConcreteType>,
                                   Traits<ConcreteType>...>::value &&
-                  (has_single_result_fold_v<ConcreteType> ||
-                   has_fold_adaptor_single_result_v<ConcreteType>))
+                  has_single_result_fold_v<ConcreteType>)
       return [](Operation *op, ArrayRef<Attribute> operands,
                 SmallVectorImpl<OpFoldResult> &results) {
         return foldSingleResultHook<ConcreteType>(op, operands, results);
       };
     // The operation is not single result and defines a `fold` method.
-    if constexpr (has_fold_v<ConcreteType> || has_fold_adaptor_v<ConcreteType>)
+    if constexpr (has_fold_v<ConcreteType>)
       return [](Operation *op, ArrayRef<Attribute> operands,
                 SmallVectorImpl<OpFoldResult> &results) {
         return foldHook<ConcreteType>(op, operands, results);
@@ -1790,12 +1773,9 @@ class Op : public OpState, public Traits<ConcreteType>... {
   static LogicalResult
   foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
                        SmallVectorImpl<OpFoldResult> &results) {
-    OpFoldResult result;
-    if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>)
-      result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
-          operands, op->getAttrDictionary(), op->getRegions()));
-    else
-      result = cast<ConcreteOpT>(op).fold(operands);
+    OpFoldResult result =
+        cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
+            operands, op->getAttrDictionary(), op->getRegions()));
 
     // If the fold failed or was in-place, try to fold the traits of the
     // operation.
@@ -1812,15 +1792,10 @@ class Op : public OpState, public Traits<ConcreteType>... {
   template <typename ConcreteOpT>
   static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
                                 SmallVectorImpl<OpFoldResult> &results) {
-    auto result = LogicalResult::failure();
-    if constexpr (has_fold_adaptor_v<ConcreteOpT>) {
-      result = cast<ConcreteOpT>(op).fold(
-          typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
-                                            op->getRegions()),
-          results);
-    } else {
-      result = cast<ConcreteOpT>(op).fold(operands, results);
-    }
+    LogicalResult result = cast<ConcreteOpT>(op).fold(
+        typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
+                                          op->getRegions()),
+        results);
 
     // If the fold failed or was in-place, try to fold the traits of the
     // operation.

diff  --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 8fd519fa63681..d85342c742e2e 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -86,15 +86,6 @@ class Dialect {
   /// operations or types.
   bool isExtensible() const;
 
-  enum class FolderAPI {
-    RawAttributes = 0, /// fold method with ArrayRef<Attribute>.
-    FolderAdaptor = 1, /// fold method with the operation's FoldAdaptor.
-  };
-
-  /// Returns the folder API that should be emitted for operations in this
-  /// dialect.
-  FolderAPI getFolderAPI() const;
-
   // Returns whether two dialects are equal by checking the equality of the
   // underlying record.
   bool operator==(const Dialect &other) const;

diff  --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index e41e2e7209ed9..ec2c31dab440c 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -103,21 +103,6 @@ bool Dialect::isExtensible() const {
   return def->getValueAsBit("isExtensible");
 }
 
-Dialect::FolderAPI Dialect::getFolderAPI() const {
-  llvm::Record *value = def->getValueAsDef("useFoldAPI");
-  auto converted =
-      llvm::StringSwitch<std::optional<Dialect::FolderAPI>>(value->getName())
-          .Case("kEmitRawAttributesFolder", FolderAPI::RawAttributes)
-          .Case("kEmitFoldAdaptorFolder", FolderAPI::FolderAdaptor)
-          .Default(std::nullopt);
-
-  if (!converted)
-    llvm::PrintFatalError(def->getLoc(),
-                          "Invalid value for dialect field `useFoldAPI`");
-
-  return *converted;
-}
-
 bool Dialect::operator==(const Dialect &other) const {
   return def == other.def;
 }

diff  --git a/mlir/test/mlir-tblgen/has-fold-invalid-values.td b/mlir/test/mlir-tblgen/has-fold-invalid-values.td
deleted file mode 100644
index f61284dd3822d..0000000000000
--- a/mlir/test/mlir-tblgen/has-fold-invalid-values.td
+++ /dev/null
@@ -1,17 +0,0 @@
-// RUN: not mlir-tblgen -gen-op-decls -I %S/../../include %s 2>&1 | FileCheck %s
-
-include "mlir/IR/OpBase.td"
-
-def Bad : EmitFolderBase;
-
-def Test_Dialect : Dialect {
-  let name = "test";
-  let cppNamespace = "NS";
-  let useFoldAPI = Bad;
-}
-
-def InvalidValue_Op : Op<Test_Dialect, "invalid_op"> {
-  let hasFolder = 1;
-}
-
-// CHECK: Invalid value for dialect field `useFoldAPI`

diff  --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index 3a486a12bb015..ed54ef2238019 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -12,7 +12,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 def Test_Dialect : Dialect {
   let name = "test";
   let cppNamespace = "NS";
-  let useFoldAPI = kEmitRawAttributesFolder;
 }
 class NS_Op<string mnemonic, list<Trait> traits> :
     Op<Test_Dialect, mnemonic, traits>;
@@ -120,7 +119,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 // CHECK:   void print(::mlir::OpAsmPrinter &p);
 // CHECK:   ::mlir::LogicalResult verifyInvariants();
 // CHECK:   static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context);
-// CHECK:   ::mlir::LogicalResult fold(::llvm::ArrayRef<::mlir::Attribute> operands, ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results);
+// CHECK:   ::mlir::LogicalResult fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results);
 // CHECK:   // Display a graph for debugging purposes.
 // CHECK:   void displayGraph();
 // CHECK: };
@@ -322,12 +321,7 @@ def NS_LOp : NS_Op<"op_with_same_operands_and_result_types_unwrapped_attr", [Sam
 // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 
-def TestWithNewFold_Dialect : Dialect {
-  let name = "test";
-  let cppNamespace = "::mlir::testWithFold";
-}
-
-def NS_MOp : Op<TestWithNewFold_Dialect, "op_with_single_result_and_fold_adaptor_fold", []> {
+def NS_MOp : NS_Op<"op_with_single_result_and_fold_adaptor_fold", []> {
   let results = (outs AnyType:$res);
 
   let hasFolder = 1;
@@ -336,15 +330,6 @@ def NS_MOp : Op<TestWithNewFold_Dialect, "op_with_single_result_and_fold_adaptor
 // CHECK-LABEL: class MOp :
 // CHECK: ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
 
-def NS_NOp : Op<TestWithNewFold_Dialect, "op_with_multiple_results_and_fold_adaptor_fold", []> {
-  let results = (outs AnyType:$res1, AnyType:$res2);
-
-  let hasFolder = 1;
-}
-
-// CHECK-LABEL: class NOp :
-// CHECK: ::mlir::LogicalResult fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results);
-
 // Test that type defs have the proper namespaces when used as a constraint.
 // ---
 

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index ef430bfe1b054..5a7e3a01b6780 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2342,12 +2342,8 @@ void OpEmitter::genFolderDecls() {
   if (!op.hasFolder())
     return;
 
-  Dialect::FolderAPI folderApi = op.getDialect().getFolderAPI();
   SmallVector<MethodParameter> paramList;
-  if (folderApi == Dialect::FolderAPI::RawAttributes)
-    paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
-  else
-    paramList.emplace_back("FoldAdaptor", "adaptor");
+  paramList.emplace_back("FoldAdaptor", "adaptor");
 
   StringRef retType;
   bool hasSingleResult =


        


More information about the Mlir-commits mailing list