[llvm] [mlir] [NFCI][LLVM/MLIR] Adopt `TrailingObjects` convenience API (PR #138554)

Rahul Joshi via llvm-commits llvm-commits at lists.llvm.org
Sat May 10 19:39:00 PDT 2025


https://github.com/jurahul updated https://github.com/llvm/llvm-project/pull/138554

>From b966795256d4cbf9acfe6730f5ccc20e18dc13a6 Mon Sep 17 00:00:00 2001
From: Rahul Joshi <rjoshi at nvidia.com>
Date: Thu, 1 May 2025 00:30:52 -0700
Subject: [PATCH] [NFCI][LLVM/MLIR] Adopt `TrailingObjects` convenience API

Adopt the `TrailingObjects` convenience API that was added in
https://github.com/llvm/llvm-project/pull/138970 across LLVM and MLIR code.
---
 llvm/include/llvm/DebugInfo/BTF/BTF.h      |  6 +--
 llvm/include/llvm/IR/DataLayout.h          | 12 +++--
 llvm/include/llvm/TableGen/Record.h        | 22 ++++-----
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp  |  4 +-
 llvm/lib/IR/AttributeImpl.h                | 27 ++++-------
 llvm/lib/IR/Attributes.cpp                 |  4 +-
 llvm/lib/Support/TrieRawHashMap.cpp        |  4 +-
 llvm/lib/TableGen/Record.cpp               | 12 ++---
 llvm/lib/Transforms/IPO/LowerTypeTests.cpp | 12 ++---
 mlir/include/mlir/IR/Operation.h           |  5 +--
 mlir/include/mlir/Tools/PDLL/AST/Nodes.h   | 52 ++++++++++------------
 mlir/lib/IR/AffineMapDetail.h              |  8 ++--
 mlir/lib/IR/Location.cpp                   | 11 ++---
 mlir/lib/IR/TypeDetail.h                   |  9 ++--
 14 files changed, 81 insertions(+), 107 deletions(-)

diff --git a/llvm/include/llvm/DebugInfo/BTF/BTF.h b/llvm/include/llvm/DebugInfo/BTF/BTF.h
index d88af2ff30bdf..bd666e4202985 100644
--- a/llvm/include/llvm/DebugInfo/BTF/BTF.h
+++ b/llvm/include/llvm/DebugInfo/BTF/BTF.h
@@ -304,14 +304,12 @@ enum PatchableRelocKind : uint32_t {
 // For CommonType sub-types that are followed by a single entry of
 // some type in the binary format.
 #define BTF_DEFINE_TAIL(Type, Accessor)                                        \
-  const Type &Accessor() const { return *getTrailingObjects<Type>(); }
+  const Type &Accessor() const { return *getTrailingObjects(); }
 
 // For CommonType sub-types that are followed by CommonType::getVlen()
 // number of entries of some type in the binary format.
 #define BTF_DEFINE_TAIL_ARR(Type, Accessor)                                    \
-  ArrayRef<Type> Accessor() const {                                            \
-    return ArrayRef<Type>(getTrailingObjects<Type>(), getVlen());              \
-  }
+  ArrayRef<Type> Accessor() const { return getTrailingObjects(getVlen()); }
 
 struct ArrayType final : CommonType,
                          private TrailingObjects<ArrayType, BTFArray> {
diff --git a/llvm/include/llvm/IR/DataLayout.h b/llvm/include/llvm/IR/DataLayout.h
index 2ad080e6d0cd2..d83fe1299237b 100644
--- a/llvm/include/llvm/IR/DataLayout.h
+++ b/llvm/include/llvm/IR/DataLayout.h
@@ -564,7 +564,9 @@ inline LLVMTargetDataRef wrap(const DataLayout *P) {
 
 /// Used to lazily calculate structure layout information for a target machine,
 /// based on the DataLayout structure.
-class StructLayout final : public TrailingObjects<StructLayout, TypeSize> {
+class StructLayout final : private TrailingObjects<StructLayout, TypeSize> {
+  friend TrailingObjects;
+
   TypeSize StructSize;
   Align StructAlignment;
   unsigned IsPadded : 1;
@@ -586,11 +588,11 @@ class StructLayout final : public TrailingObjects<StructLayout, TypeSize> {
   unsigned getElementContainingOffset(uint64_t FixedOffset) const;
 
   MutableArrayRef<TypeSize> getMemberOffsets() {
-    return llvm::MutableArrayRef(getTrailingObjects<TypeSize>(), NumElements);
+    return getTrailingObjects(NumElements);
   }
 
   ArrayRef<TypeSize> getMemberOffsets() const {
-    return llvm::ArrayRef(getTrailingObjects<TypeSize>(), NumElements);
+    return getTrailingObjects(NumElements);
   }
 
   TypeSize getElementOffset(unsigned Idx) const {
@@ -606,10 +608,6 @@ class StructLayout final : public TrailingObjects<StructLayout, TypeSize> {
   friend class DataLayout; // Only DataLayout can create this class
 
   StructLayout(StructType *ST, const DataLayout &DL);
-
-  size_t numTrailingObjects(OverloadToken<TypeSize>) const {
-    return NumElements;
-  }
 };
 
 // The implementation of this method is provided inline as it is particularly
diff --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h
index 982cc255553a2..687980cf5e0e4 100644
--- a/llvm/include/llvm/TableGen/Record.h
+++ b/llvm/include/llvm/TableGen/Record.h
@@ -258,7 +258,7 @@ class RecordRecTy final : public RecTy,
   void Profile(FoldingSetNodeID &ID) const;
 
   ArrayRef<const Record *> getClasses() const {
-    return ArrayRef(getTrailingObjects<const Record *>(), NumClasses);
+    return getTrailingObjects(NumClasses);
   }
 
   using const_record_iterator = const Record *const *;
@@ -632,9 +632,7 @@ class BitsInit final : public TypedInit,
 
   const Init *resolveReferences(Resolver &R) const override;
 
-  ArrayRef<const Init *> getBits() const {
-    return ArrayRef(getTrailingObjects<const Init *>(), NumBits);
-  }
+  ArrayRef<const Init *> getBits() const { return getTrailingObjects(NumBits); }
 
   const Init *getBit(unsigned Bit) const override { return getBits()[Bit]; }
 };
@@ -783,7 +781,7 @@ class ListInit final : public TypedInit,
   void Profile(FoldingSetNodeID &ID) const;
 
   ArrayRef<const Init *> getValues() const {
-    return ArrayRef(getTrailingObjects<const Init *>(), NumValues);
+    return ArrayRef(getTrailingObjects(), NumValues);
   }
 
   const Init *getElement(unsigned Index) const { return getValues()[Index]; }
@@ -1026,10 +1024,6 @@ class CondOpInit final : public TypedInit,
   CondOpInit(ArrayRef<const Init *> Conds, ArrayRef<const Init *> Values,
              const RecTy *Type);
 
-  size_t numTrailingObjects(OverloadToken<Init *>) const {
-    return 2*NumConds;
-  }
-
 public:
   CondOpInit(const CondOpInit &) = delete;
   CondOpInit &operator=(const CondOpInit &) = delete;
@@ -1053,11 +1047,11 @@ class CondOpInit final : public TypedInit,
   const Init *getVal(unsigned Num) const { return getVals()[Num]; }
 
   ArrayRef<const Init *> getConds() const {
-    return ArrayRef(getTrailingObjects<const Init *>(), NumConds);
+    return getTrailingObjects(NumConds);
   }
 
   ArrayRef<const Init *> getVals() const {
-    return ArrayRef(getTrailingObjects<const Init *>() + NumConds, NumConds);
+    return ArrayRef(getTrailingObjects() + NumConds, NumConds);
   }
 
   const Init *Fold(const Record *CurRec) const;
@@ -1375,7 +1369,7 @@ class VarDefInit final
   bool           args_empty() const { return NumArgs == 0; }
 
   ArrayRef<const ArgumentInit *> args() const {
-    return ArrayRef(getTrailingObjects<const ArgumentInit *>(), NumArgs);
+    return getTrailingObjects(NumArgs);
   }
 
   const Init *getBit(unsigned Bit) const override {
@@ -1488,11 +1482,11 @@ class DagInit final
   }
 
   ArrayRef<const Init *> getArgs() const {
-    return ArrayRef(getTrailingObjects<const Init *>(), NumArgs);
+    return getTrailingObjects<const Init *>(NumArgs);
   }
 
   ArrayRef<const StringInit *> getArgNames() const {
-    return ArrayRef(getTrailingObjects<const StringInit *>(), NumArgs);
+    return getTrailingObjects<const StringInit *>(NumArgs);
   }
 
   const Init *resolveReferences(Resolver &R) const override;
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 1e07f060d72cb..64f963814e1cc 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -542,7 +542,7 @@ class BitcodeConstant final : public Value,
       : Value(Ty, SubclassID), Opcode(Info.Opcode), Flags(Info.Flags),
         NumOperands(OpIDs.size()), BlockAddressBB(Info.BlockAddressBB),
         SrcElemTy(Info.SrcElemTy), InRange(Info.InRange) {
-    llvm::uninitialized_copy(OpIDs, getTrailingObjects<unsigned>());
+    llvm::uninitialized_copy(OpIDs, getTrailingObjects());
   }
 
   BitcodeConstant &operator=(const BitcodeConstant &) = delete;
@@ -559,7 +559,7 @@ class BitcodeConstant final : public Value,
   static bool classof(const Value *V) { return V->getValueID() == SubclassID; }
 
   ArrayRef<unsigned> getOperandIDs() const {
-    return ArrayRef(getTrailingObjects<unsigned>(), NumOperands);
+    return ArrayRef(getTrailingObjects(), NumOperands);
   }
 
   std::optional<ConstantRange> getInRange() const {
diff --git a/llvm/lib/IR/AttributeImpl.h b/llvm/lib/IR/AttributeImpl.h
index 98d1bad7680ab..707c8205ee1f9 100644
--- a/llvm/lib/IR/AttributeImpl.h
+++ b/llvm/lib/IR/AttributeImpl.h
@@ -195,15 +195,12 @@ class StringAttributeImpl final
 
   unsigned KindSize;
   unsigned ValSize;
-  size_t numTrailingObjects(OverloadToken<char>) const {
-    return KindSize + 1 + ValSize + 1;
-  }
 
 public:
   StringAttributeImpl(StringRef Kind, StringRef Val = StringRef())
       : AttributeImpl(StringAttrEntry), KindSize(Kind.size()),
         ValSize(Val.size()) {
-    char *TrailingString = getTrailingObjects<char>();
+    char *TrailingString = getTrailingObjects();
     // Some users rely on zero-termination.
     llvm::copy(Kind, TrailingString);
     TrailingString[KindSize] = '\0';
@@ -212,10 +209,10 @@ class StringAttributeImpl final
   }
 
   StringRef getStringKind() const {
-    return StringRef(getTrailingObjects<char>(), KindSize);
+    return StringRef(getTrailingObjects(), KindSize);
   }
   StringRef getStringValue() const {
-    return StringRef(getTrailingObjects<char>() + KindSize + 1, ValSize);
+    return StringRef(getTrailingObjects() + KindSize + 1, ValSize);
   }
 
   static size_t totalSizeToAlloc(StringRef Kind, StringRef Val) {
@@ -250,25 +247,22 @@ class ConstantRangeListAttributeImpl final
   friend TrailingObjects;
 
   unsigned Size;
-  size_t numTrailingObjects(OverloadToken<ConstantRange>) const { return Size; }
 
 public:
   ConstantRangeListAttributeImpl(Attribute::AttrKind Kind,
                                  ArrayRef<ConstantRange> Val)
       : EnumAttributeImpl(ConstantRangeListAttrEntry, Kind), Size(Val.size()) {
     assert(Size > 0);
-    ConstantRange *TrailingCR = getTrailingObjects<ConstantRange>();
-    llvm::uninitialized_copy(Val, TrailingCR);
+    llvm::uninitialized_copy(Val, getTrailingObjects());
   }
 
   ~ConstantRangeListAttributeImpl() {
-    ConstantRange *TrailingCR = getTrailingObjects<ConstantRange>();
-    for (unsigned I = 0; I != Size; ++I)
-      TrailingCR[I].~ConstantRange();
+    for (ConstantRange &CR : getTrailingObjects(Size))
+      CR.~ConstantRange();
   }
 
   ArrayRef<ConstantRange> getConstantRangeListValue() const {
-    return ArrayRef(getTrailingObjects<ConstantRange>(), Size);
+    return getTrailingObjects(Size);
   }
 
   static size_t totalSizeToAlloc(ArrayRef<ConstantRange> Val) {
@@ -353,7 +347,7 @@ class AttributeSetNode final
 
   using iterator = const Attribute *;
 
-  iterator begin() const { return getTrailingObjects<Attribute>(); }
+  iterator begin() const { return getTrailingObjects(); }
   iterator end() const { return begin() + NumAttrs; }
 
   void Profile(FoldingSetNodeID &ID) const {
@@ -383,9 +377,6 @@ class AttributeListImpl final
   /// Union of enum attributes available at any index.
   AttributeBitSet AvailableSomewhereAttrs;
 
-  // Helper fn for TrailingObjects class.
-  size_t numTrailingObjects(OverloadToken<AttributeSet>) { return NumAttrSets; }
-
 public:
   AttributeListImpl(ArrayRef<AttributeSet> Sets);
 
@@ -407,7 +398,7 @@ class AttributeListImpl final
 
   using iterator = const AttributeSet *;
 
-  iterator begin() const { return getTrailingObjects<AttributeSet>(); }
+  iterator begin() const { return getTrailingObjects(); }
   iterator end() const { return begin() + NumAttrSets; }
 
   void Profile(FoldingSetNodeID &ID) const;
diff --git a/llvm/lib/IR/Attributes.cpp b/llvm/lib/IR/Attributes.cpp
index 33ac8bfaf4e7c..5b0ceb34381a9 100644
--- a/llvm/lib/IR/Attributes.cpp
+++ b/llvm/lib/IR/Attributes.cpp
@@ -1237,7 +1237,7 @@ LLVM_DUMP_METHOD void AttributeSet::dump() const {
 AttributeSetNode::AttributeSetNode(ArrayRef<Attribute> Attrs)
     : NumAttrs(Attrs.size()) {
   // There's memory after the node where we can store the entries in.
-  llvm::copy(Attrs, getTrailingObjects<Attribute>());
+  llvm::copy(Attrs, getTrailingObjects());
 
   for (const auto &I : *this) {
     if (I.isStringAttribute())
@@ -1423,7 +1423,7 @@ AttributeListImpl::AttributeListImpl(ArrayRef<AttributeSet> Sets)
   assert(!Sets.empty() && "pointless AttributeListImpl");
 
   // There's memory after the node where we can store the entries in.
-  llvm::copy(Sets, getTrailingObjects<AttributeSet>());
+  llvm::copy(Sets, getTrailingObjects());
 
   // Initialize AvailableFunctionAttrs and AvailableSomewhereAttrs
   // summary bitsets.
diff --git a/llvm/lib/Support/TrieRawHashMap.cpp b/llvm/lib/Support/TrieRawHashMap.cpp
index 11d79a62d011d..bb779fe87ae62 100644
--- a/llvm/lib/Support/TrieRawHashMap.cpp
+++ b/llvm/lib/Support/TrieRawHashMap.cpp
@@ -62,7 +62,7 @@ class TrieSubtrie final
 public:
   using Slot = LazyAtomicPointer<TrieNode>;
 
-  Slot &get(size_t I) { return getTrailingObjects<Slot>()[I]; }
+  Slot &get(size_t I) { return getTrailingObjects()[I]; }
   TrieNode *load(size_t I) { return get(I).load(); }
 
   unsigned size() const { return Size; }
@@ -190,7 +190,7 @@ class ThreadSafeTrieRawHashMapBase::ImplType final
   }
 
   // Get the root which is the trailing object.
-  TrieSubtrie *getRoot() { return getTrailingObjects<TrieSubtrie>(); }
+  TrieSubtrie *getRoot() { return getTrailingObjects(); }
 
   static void *operator new(size_t Size) { return ::operator new(Size); }
   void operator delete(void *Ptr) { ::operator delete(Ptr); }
diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index f3d54e6083e48..e09ea4902fa5d 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -240,7 +240,7 @@ static void ProfileRecordRecTy(FoldingSetNodeID &ID,
 
 RecordRecTy::RecordRecTy(RecordKeeper &RK, ArrayRef<const Record *> Classes)
     : RecTy(RecordRecTyKind, RK), NumClasses(Classes.size()) {
-  llvm::uninitialized_copy(Classes, getTrailingObjects<const Record *>());
+  llvm::uninitialized_copy(Classes, getTrailingObjects());
 }
 
 const RecordRecTy *RecordRecTy::get(RecordKeeper &RK,
@@ -473,7 +473,7 @@ static void ProfileBitsInit(FoldingSetNodeID &ID,
 BitsInit::BitsInit(RecordKeeper &RK, ArrayRef<const Init *> Bits)
     : TypedInit(IK_BitsInit, BitsRecTy::get(RK, Bits.size())),
       NumBits(Bits.size()) {
-  llvm::uninitialized_copy(Bits, getTrailingObjects<const Init *>());
+  llvm::uninitialized_copy(Bits, getTrailingObjects());
 }
 
 BitsInit *BitsInit::get(RecordKeeper &RK, ArrayRef<const Init *> Bits) {
@@ -493,7 +493,7 @@ BitsInit *BitsInit::get(RecordKeeper &RK, ArrayRef<const Init *> Bits) {
 }
 
 void BitsInit::Profile(FoldingSetNodeID &ID) const {
-  ProfileBitsInit(ID, ArrayRef(getTrailingObjects<const Init *>(), NumBits));
+  ProfileBitsInit(ID, getBits());
 }
 
 const Init *BitsInit::convertInitializerTo(const RecTy *Ty) const {
@@ -706,7 +706,7 @@ static void ProfileListInit(FoldingSetNodeID &ID, ArrayRef<const Init *> Range,
 ListInit::ListInit(ArrayRef<const Init *> Elements, const RecTy *EltTy)
     : TypedInit(IK_ListInit, ListRecTy::get(EltTy)),
       NumValues(Elements.size()) {
-  llvm::uninitialized_copy(Elements, getTrailingObjects<const Init *>());
+  llvm::uninitialized_copy(Elements, getTrailingObjects());
 }
 
 const ListInit *ListInit::get(ArrayRef<const Init *> Elements,
@@ -2432,7 +2432,7 @@ VarDefInit::VarDefInit(SMLoc Loc, const Record *Class,
                        ArrayRef<const ArgumentInit *> Args)
     : TypedInit(IK_VarDefInit, RecordRecTy::get(Class)), Loc(Loc), Class(Class),
       NumArgs(Args.size()) {
-  llvm::uninitialized_copy(Args, getTrailingObjects<const ArgumentInit *>());
+  llvm::uninitialized_copy(Args, getTrailingObjects());
 }
 
 const VarDefInit *VarDefInit::get(SMLoc Loc, const Record *Class,
@@ -2616,7 +2616,7 @@ static void ProfileCondOpInit(FoldingSetNodeID &ID,
 CondOpInit::CondOpInit(ArrayRef<const Init *> Conds,
                        ArrayRef<const Init *> Values, const RecTy *Type)
     : TypedInit(IK_CondOpInit, Type), NumConds(Conds.size()), ValType(Type) {
-  auto *TrailingObjects = getTrailingObjects<const Init *>();
+  const Init **TrailingObjects = getTrailingObjects();
   llvm::uninitialized_copy(Conds, TrailingObjects);
   llvm::uninitialized_copy(Values, TrailingObjects + NumConds);
 }
diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
index d855647095550..ebabece067db2 100644
--- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
+++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
@@ -285,8 +285,6 @@ class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
   // module and its jumptable entry needs to be exported to thinlto backends.
   bool IsExported;
 
-  size_t numTrailingObjects(OverloadToken<MDNode *>) const { return NTypes; }
-
 public:
   static GlobalTypeMember *create(BumpPtrAllocator &Alloc, GlobalObject *GO,
                                   bool IsJumpTableCanonical, bool IsExported,
@@ -297,7 +295,7 @@ class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
     GTM->NTypes = Types.size();
     GTM->IsJumpTableCanonical = IsJumpTableCanonical;
     GTM->IsExported = IsExported;
-    llvm::copy(Types, GTM->getTrailingObjects<MDNode *>());
+    llvm::copy(Types, GTM->getTrailingObjects());
     return GTM;
   }
 
@@ -313,9 +311,7 @@ class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
     return IsExported;
   }
 
-  ArrayRef<MDNode *> types() const {
-    return ArrayRef(getTrailingObjects<MDNode *>(), NTypes);
-  }
+  ArrayRef<MDNode *> types() const { return getTrailingObjects(NTypes); }
 };
 
 struct ICallBranchFunnel final
@@ -329,13 +325,13 @@ struct ICallBranchFunnel final
     Call->CI = CI;
     Call->UniqueId = UniqueId;
     Call->NTargets = Targets.size();
-    llvm::copy(Targets, Call->getTrailingObjects<GlobalTypeMember *>());
+    llvm::copy(Targets, Call->getTrailingObjects());
     return Call;
   }
 
   CallInst *CI;
   ArrayRef<GlobalTypeMember *> targets() const {
-    return ArrayRef(getTrailingObjects<GlobalTypeMember *>(), NTargets);
+    return getTrailingObjects(NTargets);
   }
 
   unsigned UniqueId;
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 95d944170732e..68ab1527b480a 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -679,8 +679,7 @@ class alignas(8) Operation final
     if (numRegions == 0)
       return MutableArrayRef<Region>();
 
-    auto *regions = getTrailingObjects<Region>();
-    return {regions, numRegions};
+    return getTrailingObjects<Region>(numRegions);
   }
 
   /// Returns the region held by this operation at position 'index'.
@@ -694,7 +693,7 @@ class alignas(8) Operation final
   //===--------------------------------------------------------------------===//
 
   MutableArrayRef<BlockOperand> getBlockOperands() {
-    return {getTrailingObjects<BlockOperand>(), numSuccs};
+    return getTrailingObjects<BlockOperand>(numSuccs);
   }
 
   // Successor iteration.
diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
index f174ac2f476f6..9ad94839890b7 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -183,10 +183,10 @@ class CompoundStmt final : public Node::NodeBase<CompoundStmt, Stmt>,
 
   /// Return the children of this compound statement.
   MutableArrayRef<Stmt *> getChildren() {
-    return {getTrailingObjects<Stmt *>(), numChildren};
+    return getTrailingObjects(numChildren);
   }
   ArrayRef<Stmt *> getChildren() const {
-    return const_cast<CompoundStmt *>(this)->getChildren();
+    return getTrailingObjects(numChildren);
   }
   ArrayRef<Stmt *>::iterator begin() const { return getChildren().begin(); }
   ArrayRef<Stmt *>::iterator end() const { return getChildren().end(); }
@@ -275,10 +275,10 @@ class ReplaceStmt final : public Node::NodeBase<ReplaceStmt, OpRewriteStmt>,
 
   /// Return the replacement values of this statement.
   MutableArrayRef<Expr *> getReplExprs() {
-    return {getTrailingObjects<Expr *>(), numReplExprs};
+    return getTrailingObjects(numReplExprs);
   }
   ArrayRef<Expr *> getReplExprs() const {
-    return const_cast<ReplaceStmt *>(this)->getReplExprs();
+    return getTrailingObjects(numReplExprs);
   }
 
 private:
@@ -400,12 +400,8 @@ class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
   Expr *getCallableExpr() const { return callable; }
 
   /// Return the arguments of this call.
-  MutableArrayRef<Expr *> getArguments() {
-    return {getTrailingObjects<Expr *>(), numArgs};
-  }
-  ArrayRef<Expr *> getArguments() const {
-    return const_cast<CallExpr *>(this)->getArguments();
-  }
+  MutableArrayRef<Expr *> getArguments() { return getTrailingObjects(numArgs); }
+  ArrayRef<Expr *> getArguments() const { return getTrailingObjects(numArgs); }
 
   /// Returns whether the result of this call is to be negated.
   bool getIsNegated() const { return isNegated; }
@@ -534,10 +530,10 @@ class OperationExpr final
 
   /// Return the operands of this operation.
   MutableArrayRef<Expr *> getOperands() {
-    return {getTrailingObjects<Expr *>(), numOperands};
+    return getTrailingObjects<Expr *>(numOperands);
   }
   ArrayRef<Expr *> getOperands() const {
-    return const_cast<OperationExpr *>(this)->getOperands();
+    return getTrailingObjects<Expr *>(numOperands);
   }
 
   /// Return the result types of this operation.
@@ -550,10 +546,10 @@ class OperationExpr final
 
   /// Return the attributes of this operation.
   MutableArrayRef<NamedAttributeDecl *> getAttributes() {
-    return {getTrailingObjects<NamedAttributeDecl *>(), numAttributes};
+    return getTrailingObjects<NamedAttributeDecl *>(numAttributes);
   }
-  MutableArrayRef<NamedAttributeDecl *> getAttributes() const {
-    return const_cast<OperationExpr *>(this)->getAttributes();
+  ArrayRef<NamedAttributeDecl *> getAttributes() const {
+    return getTrailingObjects<NamedAttributeDecl *>(numAttributes);
   }
 
 private:
@@ -594,10 +590,10 @@ class RangeExpr final : public Node::NodeBase<RangeExpr, Expr>,
 
   /// Return the element expressions of this range.
   MutableArrayRef<Expr *> getElements() {
-    return {getTrailingObjects<Expr *>(), numElements};
+    return getTrailingObjects(numElements);
   }
   ArrayRef<Expr *> getElements() const {
-    return const_cast<RangeExpr *>(this)->getElements();
+    return getTrailingObjects(numElements);
   }
 
   /// Return the range result type of this expression.
@@ -627,10 +623,10 @@ class TupleExpr final : public Node::NodeBase<TupleExpr, Expr>,
 
   /// Return the element expressions of this tuple.
   MutableArrayRef<Expr *> getElements() {
-    return {getTrailingObjects<Expr *>(), getType().size()};
+    return getTrailingObjects(getType().size());
   }
   ArrayRef<Expr *> getElements() const {
-    return const_cast<TupleExpr *>(this)->getElements();
+    return getTrailingObjects(getType().size());
   }
 
   /// Return the tuple result type of this expression.
@@ -916,10 +912,10 @@ class UserConstraintDecl final
 
   /// Return the input arguments of this constraint.
   MutableArrayRef<VariableDecl *> getInputs() {
-    return {getTrailingObjects<VariableDecl *>(), numInputs};
+    return getTrailingObjects<VariableDecl *>(numInputs);
   }
   ArrayRef<VariableDecl *> getInputs() const {
-    return const_cast<UserConstraintDecl *>(this)->getInputs();
+    return getTrailingObjects<VariableDecl *>(numInputs);
   }
 
   /// Return the explicit native type to use for the given input. Returns
@@ -1126,16 +1122,16 @@ class UserRewriteDecl final
 
   /// Return the input arguments of this rewrite.
   MutableArrayRef<VariableDecl *> getInputs() {
-    return {getTrailingObjects<VariableDecl *>(), numInputs};
+    return getTrailingObjects(numInputs);
   }
   ArrayRef<VariableDecl *> getInputs() const {
-    return const_cast<UserRewriteDecl *>(this)->getInputs();
+    return getTrailingObjects(numInputs);
   }
 
   /// Return the explicit results of the rewrite declaration. May be empty,
   /// even if the rewrite has results (e.g. in the case of inferred results).
   MutableArrayRef<VariableDecl *> getResults() {
-    return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
+    return {getTrailingObjects() + numInputs, numResults};
   }
   ArrayRef<VariableDecl *> getResults() const {
     return const_cast<UserRewriteDecl *>(this)->getResults();
@@ -1257,10 +1253,10 @@ class VariableDecl final
 
   /// Return the constraints of this variable.
   MutableArrayRef<ConstraintRef> getConstraints() {
-    return {getTrailingObjects<ConstraintRef>(), numConstraints};
+    return getTrailingObjects(numConstraints);
   }
   ArrayRef<ConstraintRef> getConstraints() const {
-    return const_cast<VariableDecl *>(this)->getConstraints();
+    return getTrailingObjects(numConstraints);
   }
 
   /// Return the initializer expression of this statement, or nullptr if there
@@ -1304,10 +1300,10 @@ class Module final : public Node::NodeBase<Module, Node>,
 
   /// Return the children of this module.
   MutableArrayRef<Decl *> getChildren() {
-    return {getTrailingObjects<Decl *>(), numChildren};
+    return getTrailingObjects(numChildren);
   }
   ArrayRef<Decl *> getChildren() const {
-    return const_cast<Module *>(this)->getChildren();
+    return getTrailingObjects(numChildren);
   }
 
 private:
diff --git a/mlir/lib/IR/AffineMapDetail.h b/mlir/lib/IR/AffineMapDetail.h
index 32c9734f23a36..b306462357a97 100644
--- a/mlir/lib/IR/AffineMapDetail.h
+++ b/mlir/lib/IR/AffineMapDetail.h
@@ -24,7 +24,9 @@ namespace detail {
 
 struct AffineMapStorage final
     : public StorageUniquer::BaseStorage,
-      public llvm::TrailingObjects<AffineMapStorage, AffineExpr> {
+      private llvm::TrailingObjects<AffineMapStorage, AffineExpr> {
+  friend llvm::TrailingObjects<AffineMapStorage, AffineExpr>;
+
   /// The hash key used for uniquing.
   using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>>;
 
@@ -36,7 +38,7 @@ struct AffineMapStorage final
 
   /// The affine expressions for this (multi-dimensional) map.
   ArrayRef<AffineExpr> results() const {
-    return {getTrailingObjects<AffineExpr>(), numResults};
+    return getTrailingObjects(numResults);
   }
 
   bool operator==(const KeyTy &key) const {
@@ -56,7 +58,7 @@ struct AffineMapStorage final
     res->numDims = std::get<0>(key);
     res->numSymbols = std::get<1>(key);
     res->numResults = results.size();
-    llvm::uninitialized_copy(results, res->getTrailingObjects<AffineExpr>());
+    llvm::uninitialized_copy(results, res->getTrailingObjects());
     return res;
   }
 };
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index 8ae33022be24f..f897546f36ba7 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -34,7 +34,8 @@ using namespace mlir::detail;
 namespace mlir::detail {
 struct FileLineColRangeAttrStorage final
     : public ::mlir::AttributeStorage,
-      public llvm::TrailingObjects<FileLineColRangeAttrStorage, unsigned> {
+      private llvm::TrailingObjects<FileLineColRangeAttrStorage, unsigned> {
+  friend llvm::TrailingObjects<FileLineColRangeAttrStorage, unsigned>;
   using PointerPair = llvm::PointerIntPair<StringAttr, 2>;
   using KeyTy = std::tuple<StringAttr, ::llvm::ArrayRef<unsigned>>;
 
@@ -62,7 +63,7 @@ struct FileLineColRangeAttrStorage final
       result->startLine = elements[0];
       // Copy in the element types into the trailing storage.
       llvm::uninitialized_copy(elements.drop_front(),
-                               result->getTrailingObjects<unsigned>());
+                               result->getTrailingObjects());
     }
     return result;
   }
@@ -74,12 +75,12 @@ struct FileLineColRangeAttrStorage final
     return (filenameAndTrailing.getPointer() == std::get<0>(tblgenKey)) &&
            (size() == std::get<1>(tblgenKey).size()) &&
            (startLine == std::get<1>(tblgenKey)[0]) &&
-           (ArrayRef<unsigned>{getTrailingObjects<unsigned>(), size() - 1} ==
-            ArrayRef<unsigned>{std::get<1>(tblgenKey)}.drop_front());
+           (getTrailingObjects(size() - 1) ==
+            std::get<1>(tblgenKey).drop_front());
   }
 
   unsigned getLineCols(unsigned index) const {
-    return getTrailingObjects<unsigned>()[index - 1];
+    return getTrailingObjects()[index - 1];
   }
 
   unsigned getStartLine() const { return startLine; }
diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
index 19f3690c3d2dc..0e952d5c14c7e 100644
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -102,7 +102,8 @@ struct FunctionTypeStorage : public TypeStorage {
 /// A type representing a collection of other types.
 struct TupleTypeStorage final
     : public TypeStorage,
-      public llvm::TrailingObjects<TupleTypeStorage, Type> {
+      private llvm::TrailingObjects<TupleTypeStorage, Type> {
+  friend llvm::TrailingObjects<TupleTypeStorage, Type>;
   using KeyTy = TypeRange;
 
   TupleTypeStorage(unsigned numTypes) : numElements(numTypes) {}
@@ -116,7 +117,7 @@ struct TupleTypeStorage final
     auto *result = ::new (rawMem) TupleTypeStorage(key.size());
 
     // Copy in the element types into the trailing storage.
-    llvm::uninitialized_copy(key, result->getTrailingObjects<Type>());
+    llvm::uninitialized_copy(key, result->getTrailingObjects());
     return result;
   }
 
@@ -126,9 +127,7 @@ struct TupleTypeStorage final
   unsigned size() const { return numElements; }
 
   /// Return the held types.
-  ArrayRef<Type> getTypes() const {
-    return {getTrailingObjects<Type>(), size()};
-  }
+  ArrayRef<Type> getTypes() const { return getTrailingObjects(size()); }
 
   KeyTy getAsKey() const { return getTypes(); }
 



More information about the llvm-commits mailing list