[Mlir-commits] [mlir] 920daed - [mlir][SubElements] Remove the ability to override implementations

River Riddle llvmlistbot at llvm.org
Mon Jan 30 22:55:20 PST 2023


Author: River Riddle
Date: 2023-01-30T22:42:26-08:00
New Revision: 920daed78398efa8ed979de85564454e60d5cc3e

URL: https://github.com/llvm/llvm-project/commit/920daed78398efa8ed979de85564454e60d5cc3e
DIFF: https://github.com/llvm/llvm-project/commit/920daed78398efa8ed979de85564454e60d5cc3e.diff

LOG: [mlir][SubElements] Remove the ability to override implementations

It's much cleaner and simpler to drive wacky configs via the
AttrTypeSubElementHandler interface instead of overriding the implementations.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/include/mlir/IR/AttrTypeSubElements.h
    mlir/include/mlir/IR/StorageUniquerSupport.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
    mlir/test/lib/Dialect/Test/TestTypes.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index a44448e18b45c..9ae0ba6365dfa 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -180,6 +180,7 @@ class LLVMStructType
                               StringRef, bool);
   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                               ArrayRef<Type> types, bool);
+  using Base::verify;
 
   /// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a
   /// DataLayout instance and query it instead.
@@ -197,11 +198,6 @@ class LLVMStructType
 
   LogicalResult verifyEntries(DataLayoutEntryListRef entries,
                               Location loc) const;
-
-  void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
-                                function_ref<void(Type)> walkTypesFn) const;
-  Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                   ArrayRef<Type> replTypes) const;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/AttrTypeSubElements.h b/mlir/include/mlir/IR/AttrTypeSubElements.h
index 23b8d533b7cd2..fe3f4cd24effd 100644
--- a/mlir/include/mlir/IR/AttrTypeSubElements.h
+++ b/mlir/include/mlir/IR/AttrTypeSubElements.h
@@ -459,21 +459,29 @@ auto replaceImmediateSubElementsImpl(T derived, ArrayRef<Attribute> &replAttrs,
 
       // Otherwise, we need to replace any necessary sub-elements.
     } else {
+      // Functor used to build the replacement on success.
+      auto buildReplacement = [&](auto newKey, MLIRContext *ctx) {
+        if constexpr (is_tuple<decltype(key)>::value) {
+          return std::apply(
+              [&](auto &&...params) {
+                return constructSubElementReplacement<T>(
+                    ctx, std::forward<decltype(params)>(params)...);
+              },
+              newKey);
+        } else {
+          return constructSubElementReplacement<T>(ctx, newKey);
+        }
+      };
+
       AttrSubElementReplacements attrRepls(replAttrs);
       TypeSubElementReplacements typeRepls(replTypes);
       auto newKey = AttrTypeSubElementHandler<decltype(key)>::replace(
           key, attrRepls, typeRepls);
-      if constexpr (is_tuple<decltype(key)>::value) {
-        return std::apply(
-            [&](auto &&...params) {
-              return constructSubElementReplacement<T>(
-                  derived.getContext(),
-                  std::forward<decltype(params)>(params)...);
-            },
-            newKey);
-      } else {
-        return constructSubElementReplacement<T>(derived.getContext(), newKey);
-      }
+      MLIRContext *ctx = derived.getContext();
+      if constexpr (std::is_convertible_v<decltype(newKey), LogicalResult>)
+        return succeeded(newKey) ? buildReplacement(*newKey, ctx) : nullptr;
+      else
+        return buildReplacement(newKey, ctx);
     }
   } else {
     return derived;

diff  --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index 061bd67e10904..333618bc72455 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -127,38 +127,13 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
     };
   }
 
-  /// Walk all of the immediately nested sub-attributes and sub-types. This
-  /// method does not recurse into sub elements.
-  void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
-                                function_ref<void(Type)> walkTypesFn) const {
-    ::mlir::detail::walkImmediateSubElementsImpl(
-        *static_cast<const ConcreteT *>(this), walkAttrsFn, walkTypesFn);
-  }
-
-  /// Replace the immediately nested sub-attributes and sub-types with those
-  /// provided. The order of the provided elements is derived from the order of
-  /// the elements returned by the callbacks of `walkImmediateSubElements`. The
-  /// element at index 0 would replace the very first attribute given by
-  /// `walkImmediateSubElements`. On success, the new instance with the values
-  /// replaced is returned. If replacement fails, nullptr is returned.
-  ///
-  /// Note that replacing the sub-elements of mutable types or attributes is
-  /// not currently supported by the interface. If an implementing type or
-  /// attribute is mutable, it should return `nullptr` if it has no mechanism
-  /// for replacing sub elements.
-  auto replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                   ArrayRef<Type> replTypes) const {
-    return ::mlir::detail::replaceImmediateSubElementsImpl(
-        *static_cast<const ConcreteT *>(this), replAttrs, replTypes);
-  }
-
   /// Returns a function that walks immediate sub elements of a given instance
   /// of the storage user.
   static auto getWalkImmediateSubElementsFn() {
     return [](auto instance, function_ref<void(Attribute)> walkAttrsFn,
               function_ref<void(Type)> walkTypesFn) {
-      cast<ConcreteT>(instance).walkImmediateSubElements(walkAttrsFn,
-                                                         walkTypesFn);
+      ::mlir::detail::walkImmediateSubElementsImpl(cast<ConcreteT>(instance),
+                                                   walkAttrsFn, walkTypesFn);
     };
   }
 
@@ -167,8 +142,8 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   static auto getReplaceImmediateSubElementsFn() {
     return [](auto instance, ArrayRef<Attribute> replAttrs,
               ArrayRef<Type> replTypes) {
-      return cast<ConcreteT>(instance).replaceImmediateSubElements(replAttrs,
-                                                                   replTypes);
+      return ::mlir::detail::replaceImmediateSubElementsImpl(
+          cast<ConcreteT>(instance), replAttrs, replTypes);
     };
   }
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index ce28aa09b8214..2e6bf86022fa3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -643,23 +643,6 @@ LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
   return mlir::success();
 }
 
-void LLVMStructType::walkImmediateSubElements(
-    function_ref<void(Attribute)> walkAttrsFn,
-    function_ref<void(Type)> walkTypesFn) const {
-  for (Type type : getBody())
-    walkTypesFn(type);
-}
-
-Type LLVMStructType::replaceImmediateSubElements(
-    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
-  if (isIdentified()) {
-    // TODO: It's not clear how we support replacing sub-elements of mutable
-    // types.
-    return nullptr;
-  }
-  return getLiteral(getContext(), replTypes, isPacked());
-}
-
 //===----------------------------------------------------------------------===//
 // Vector types.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
index d30e94c1eca1d..2040d0a06b2e3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
+++ b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
@@ -69,8 +69,9 @@ struct LLVMStructTypeStorage : public TypeStorage {
   class Key {
   public:
     /// Constructs a key for an identified struct.
-    Key(StringRef name, bool opaque)
-        : name(name), identified(true), packed(false), opaque(opaque) {}
+    Key(StringRef name, bool opaque, ArrayRef<Type> types = std::nullopt)
+        : types(types), name(name), identified(true), packed(false),
+          opaque(opaque) {}
     /// Constructs a key for a literal struct.
     Key(ArrayRef<Type> types, bool packed)
         : types(types), identified(false), packed(packed), opaque(false) {}
@@ -102,6 +103,13 @@ struct LLVMStructTypeStorage : public TypeStorage {
       return types;
     }
 
+    /// Returns the list of types contained in an identified struct.
+    ArrayRef<Type> getIdentifiedStructBody() const {
+      assert(isIdentified() &&
+             "requested struct body on a non-identified struct");
+      return types;
+    }
+
     /// Returns the hash value of the key. This combines various flags into a
     /// single value: the identified flag sets the first bit, and the packedness
     /// flag sets the second bit. Opacity bit is only used for construction and
@@ -220,7 +228,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
   }
 
   /// Hook into the type uniquing infrastructure.
-  bool operator==(const KeyTy &other) const { return getKey() == other; };
+  bool operator==(const KeyTy &other) const { return getAsKey() == other; };
   static llvm::hash_code hashKey(const KeyTy &key) { return key.hashValue(); }
   static LLVMStructTypeStorage *construct(TypeStorageAllocator &allocator,
                                           const KeyTy &key) {
@@ -251,6 +259,13 @@ struct LLVMStructTypeStorage : public TypeStorage {
     return success();
   }
 
+  /// Returns the key for the current storage.
+  Key getAsKey() const {
+    if (isIdentified())
+      return Key(getIdentifier(), isOpaque(), getIdentifiedStructBody());
+    return Key(getTypeList(), isPacked());
+  }
+
 private:
   /// Returns the number of elements in the key.
   unsigned keySize() const {
@@ -271,13 +286,6 @@ struct LLVMStructTypeStorage : public TypeStorage {
     llvm::Bitfield::set<MutableSize>(identifiedBodySizeAndFlags, value);
   }
 
-  /// Returns the key for the current storage.
-  Key getKey() const {
-    if (isIdentified())
-      return Key(getIdentifier(), isOpaque());
-    return Key(getTypeList(), isPacked());
-  }
-
   /// Bitfield elements for `keyAndSizeFlags`:
   ///   - bit 0: identified key flag;
   ///   - bit 1: packed key flag;
@@ -320,7 +328,35 @@ struct LLVMStructTypeStorage : public TypeStorage {
   /// mutable flags. Must only be used through the Mutable* bitfields.
   unsigned identifiedBodySizeAndFlags = 0;
 };
+} // end namespace detail
+} // end namespace LLVM
+
+/// Allow walking and replacing the subelements of a LLVMStructTypeStorage key.
+template <>
+struct AttrTypeSubElementHandler<LLVM::detail::LLVMStructTypeStorage::Key> {
+  static void walk(const LLVM::detail::LLVMStructTypeStorage::Key &param,
+                   AttrTypeImmediateSubElementWalker &walker) {
+    if (param.isIdentified())
+      walker.walkRange(param.getIdentifiedStructBody());
+    else
+      walker.walkRange(param.getTypeList());
+  }
+  static FailureOr<LLVM::detail::LLVMStructTypeStorage::Key>
+  replace(const LLVM::detail::LLVMStructTypeStorage::Key &param,
+          AttrSubElementReplacements &attrRepls,
+          TypeSubElementReplacements &typeRepls) {
+    // TODO: It's not clear how we support replacing sub-elements of mutable
+    // types.
+    if (param.isIdentified())
+      return failure();
+
+    return LLVM::detail::LLVMStructTypeStorage::Key(
+        typeRepls.take_front(param.getTypeList().size()), param.isPacked());
+  }
+};
 
+namespace LLVM {
+namespace detail {
 //===----------------------------------------------------------------------===//
 // LLVMTypeAndSizeStorage.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index c6c8f58ea552e..c7d169d020d56 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -146,18 +146,6 @@ class TestRecursiveType
 
   /// Name/key getter.
   ::llvm::StringRef getName() { return getImpl()->name; }
-
-  void walkImmediateSubElements(
-      ::llvm::function_ref<void(::mlir::Attribute)> walkAttrsFn,
-      ::llvm::function_ref<void(::mlir::Type)> walkTypesFn) const {
-    walkTypesFn(getBody());
-  }
-  Type replaceImmediateSubElements(llvm::ArrayRef<mlir::Attribute> replAttrs,
-                                   llvm::ArrayRef<mlir::Type> replTypes) const {
-    // TODO: It's not clear how we support replacing sub-elements of mutable
-    // types.
-    return nullptr;
-  }
 };
 
 } // namespace test


        


More information about the Mlir-commits mailing list