[llvm] Correctly set pointer bit for aggregate values in SelectionDAGBuilder to fix CCIfPtr (PR #70554)

Camil Staps via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 28 01:39:40 PST 2024


https://github.com/camilstaps updated https://github.com/llvm/llvm-project/pull/70554

>From 9a5e0e837aefd365b64d4793511f351ecc873cb3 Mon Sep 17 00:00:00 2001
From: Camil Staps <info at camilstaps.nl>
Date: Fri, 3 May 2024 12:54:21 +0200
Subject: [PATCH] Correctly set pointer bit for aggregate values to fix CCIfPtr

---
 llvm/include/llvm/CodeGen/Analysis.h          | 13 +++
 llvm/lib/CodeGen/Analysis.cpp                 | 51 ++++++++++++
 llvm/lib/CodeGen/SelectionDAG/FastISel.cpp    | 11 ++-
 .../SelectionDAG/SelectionDAGBuilder.cpp      | 79 ++++++++++---------
 llvm/lib/CodeGen/TargetLoweringBase.cpp       | 16 ++--
 5 files changed, 123 insertions(+), 47 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/Analysis.h b/llvm/include/llvm/CodeGen/Analysis.h
index 362cc30bbd06a1..f32fb9f2eaec35 100644
--- a/llvm/include/llvm/CodeGen/Analysis.h
+++ b/llvm/include/llvm/CodeGen/Analysis.h
@@ -55,6 +55,19 @@ inline unsigned ComputeLinearIndex(Type *Ty,
   return ComputeLinearIndex(Ty, Indices.begin(), Indices.end(), CurIndex);
 }
 
+/// ComputeValueTypes - Given an LLVM IR type, compute a sequence of
+/// Types that represent all the individual underlying
+/// non-aggregate types that comprise it.
+///
+/// If Offsets is non-null, it points to a vector to be filled in
+/// with the in-memory offsets of each of the individual values.
+///
+void ComputeValueTypes(const TargetLowering &TLI, const DataLayout &DL,
+                       Type *Ty,
+                       SmallVectorImpl<Type *> &ValueTys,
+                       SmallVectorImpl<TypeSize> *Offsets = nullptr,
+                       TypeSize StartingOffset = TypeSize::getZero());
+
 /// ComputeValueVTs - Given an LLVM IR type, compute a sequence of
 /// EVTs that represent all the individual underlying
 /// non-aggregate types that comprise it.
diff --git a/llvm/lib/CodeGen/Analysis.cpp b/llvm/lib/CodeGen/Analysis.cpp
index e7b9417de8c9f7..353fcd45ab3b53 100644
--- a/llvm/lib/CodeGen/Analysis.cpp
+++ b/llvm/lib/CodeGen/Analysis.cpp
@@ -69,6 +69,57 @@ unsigned llvm::ComputeLinearIndex(Type *Ty,
   return CurIndex + 1;
 }
 
+/// ComputeValueTypes - Given an LLVM IR type, compute a sequence of
+/// Types that represent all the individual underlying
+/// non-aggregate types that comprise it.
+///
+/// If Offsets is non-null, it points to a vector to be filled in
+/// with the in-memory offsets of each of the individual values.
+///
+void llvm::ComputeValueTypes(const TargetLowering &TLI, const DataLayout &DL,
+                             Type *Ty,
+                             SmallVectorImpl<Type *> &ValueTys,
+                             SmallVectorImpl<TypeSize> *Offsets,
+                             TypeSize StartingOffset) {
+  assert((Ty->isScalableTy() == StartingOffset.isScalable() ||
+          StartingOffset.isZero()) &&
+         "Offset/TypeSize mismatch!");
+  // Given a struct type, recursively traverse the elements.
+  if (StructType *STy = dyn_cast<StructType>(Ty)) {
+    // If the Offsets aren't needed, don't query the struct layout. This allows
+    // us to support structs with scalable vectors for operations that don't
+    // need offsets.
+    const StructLayout *SL = Offsets ? DL.getStructLayout(STy) : nullptr;
+    for (StructType::element_iterator EB = STy->element_begin(),
+                                      EI = EB,
+                                      EE = STy->element_end();
+         EI != EE; ++EI) {
+      // Don't compute the element offset if we didn't get a StructLayout above.
+      TypeSize EltOffset =
+          SL ? SL->getElementOffset(EI - EB) : TypeSize::getZero();
+      ComputeValueTypes(TLI, DL, *EI, ValueTys, Offsets,
+                        StartingOffset + EltOffset);
+    }
+    return;
+  }
+  // Given an array type, recursively traverse the elements.
+  if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
+    Type *EltTy = ATy->getElementType();
+    TypeSize EltSize = DL.getTypeAllocSize(EltTy);
+    for (unsigned i = 0, e = ATy->getNumElements(); i != e; ++i)
+      ComputeValueTypes(TLI, DL, EltTy, ValueTys, Offsets,
+                        StartingOffset + i * EltSize);
+    return;
+  }
+  // Interpret void as zero return values.
+  if (Ty->isVoidTy())
+    return;
+  // Base case: this LLVM IR type cannot be flattened any further.
+  ValueTys.push_back(Ty);
+  if (Offsets)
+    Offsets->push_back(StartingOffset);
+}
+
 /// ComputeValueVTs - Given an LLVM IR type, compute a sequence of
 /// EVTs that represent all the individual underlying
 /// non-aggregate types that comprise it.
diff --git a/llvm/lib/CodeGen/SelectionDAG/FastISel.cpp b/llvm/lib/CodeGen/SelectionDAG/FastISel.cpp
index b431ecc8472671..69cb3d49a15c3d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/FastISel.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/FastISel.cpp
@@ -994,8 +994,8 @@ bool FastISel::lowerCallTo(const CallInst *CI, MCSymbol *Symbol,
 bool FastISel::lowerCallTo(CallLoweringInfo &CLI) {
   // Handle the incoming return values from the call.
   CLI.clearIns();
-  SmallVector<EVT, 4> RetTys;
-  ComputeValueVTs(TLI, DL, CLI.RetTy, RetTys);
+  SmallVector<Type *, 4> RetTys;
+  ComputeValueTypes(TLI, DL, CLI.RetTy, RetTys);
 
   SmallVector<ISD::OutputArg, 4> Outs;
   GetReturnInfo(CLI.CallConv, CLI.RetTy, getReturnAttrs(CLI), Outs, TLI, DL);
@@ -1007,7 +1007,8 @@ bool FastISel::lowerCallTo(CallLoweringInfo &CLI) {
   if (!CanLowerReturn)
     return false;
 
-  for (EVT VT : RetTys) {
+  for (Type *RetTy : RetTys) {
+    EVT VT = TLI.getValueType(DL, RetTy);
     MVT RegisterVT = TLI.getRegisterType(CLI.RetTy->getContext(), VT);
     unsigned NumRegs = TLI.getNumRegisters(CLI.RetTy->getContext(), VT);
     for (unsigned i = 0; i != NumRegs; ++i) {
@@ -1015,6 +1016,10 @@ bool FastISel::lowerCallTo(CallLoweringInfo &CLI) {
       MyFlags.VT = RegisterVT;
       MyFlags.ArgVT = VT;
       MyFlags.Used = CLI.IsReturnValueUsed;
+      if (auto *PtrTy = dyn_cast<PointerType>(RetTy)) {
+        MyFlags.Flags.setPointer();
+        MyFlags.Flags.setPointerAddrSpace(PtrTy->getAddressSpace());
+      }
       if (CLI.RetSExt)
         MyFlags.Flags.setSExt();
       if (CLI.RetZExt)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 9d729d448502d8..49c5487e49878c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -2224,9 +2224,9 @@ void SelectionDAGBuilder::visitRet(const ReturnInst &I) {
     Chain = DAG.getNode(ISD::TokenFactor, getCurSDLoc(),
                         MVT::Other, Chains);
   } else if (I.getNumOperands() != 0) {
-    SmallVector<EVT, 4> ValueVTs;
-    ComputeValueVTs(TLI, DL, I.getOperand(0)->getType(), ValueVTs);
-    unsigned NumValues = ValueVTs.size();
+    SmallVector<Type *, 4> ValueTys;
+    ComputeValueTypes(TLI, DL, I.getOperand(0)->getType(), ValueTys);
+    unsigned NumValues = ValueTys.size();
     if (NumValues) {
       SDValue RetOp = getValue(I.getOperand(0));
 
@@ -2246,7 +2246,7 @@ void SelectionDAGBuilder::visitRet(const ReturnInst &I) {
       bool RetInReg = F->getAttributes().hasRetAttr(Attribute::InReg);
 
       for (unsigned j = 0; j != NumValues; ++j) {
-        EVT VT = ValueVTs[j];
+        EVT VT = TLI.getValueType(DL, ValueTys[j]);
 
         if (ExtendKind != ISD::ANY_EXTEND && VT.isInteger())
           VT = TLI.getTypeForExtReturn(Context, VT, ExtendKind);
@@ -2265,10 +2265,9 @@ void SelectionDAGBuilder::visitRet(const ReturnInst &I) {
         if (RetInReg)
           Flags.setInReg();
 
-        if (I.getOperand(0)->getType()->isPointerTy()) {
+        if (auto *PtrTy = dyn_cast<PointerType>(ValueTys[j])) {
           Flags.setPointer();
-          Flags.setPointerAddrSpace(
-              cast<PointerType>(I.getOperand(0)->getType())->getAddressSpace());
+          Flags.setPointerAddrSpace(PtrTy->getAddressSpace());
         }
 
         if (NeedsRegBlock) {
@@ -10971,25 +10970,29 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
   // Handle the incoming return values from the call.
   CLI.Ins.clear();
   Type *OrigRetTy = CLI.RetTy;
-  SmallVector<EVT, 4> RetTys;
+  SmallVector<Type *, 4> RetTys;
   SmallVector<TypeSize, 4> Offsets;
   auto &DL = CLI.DAG.getDataLayout();
-  ComputeValueVTs(*this, DL, CLI.RetTy, RetTys, &Offsets);
+  ComputeValueTypes(*this, DL, CLI.RetTy, RetTys, &Offsets);
+
+  SmallVector<EVT, 4> RetVTs;
+  for (Type *Ty : RetTys)
+    RetVTs.push_back(getValueType(DL, Ty));
 
   if (CLI.IsPostTypeLegalization) {
     // If we are lowering a libcall after legalization, split the return type.
-    SmallVector<EVT, 4> OldRetTys;
+    SmallVector<EVT, 4> OldRetVTs;
     SmallVector<TypeSize, 4> OldOffsets;
-    RetTys.swap(OldRetTys);
+    RetVTs.swap(OldRetVTs);
     Offsets.swap(OldOffsets);
 
-    for (size_t i = 0, e = OldRetTys.size(); i != e; ++i) {
-      EVT RetVT = OldRetTys[i];
+    for (size_t i = 0, e = OldRetVTs.size(); i != e; ++i) {
+      EVT RetVT = OldRetVTs[i];
       uint64_t Offset = OldOffsets[i];
       MVT RegisterVT = getRegisterType(CLI.RetTy->getContext(), RetVT);
       unsigned NumRegs = getNumRegisters(CLI.RetTy->getContext(), RetVT);
       unsigned RegisterVTByteSZ = RegisterVT.getSizeInBits() / 8;
-      RetTys.append(NumRegs, RegisterVT);
+      RetVTs.append(NumRegs, RegisterVT);
       for (unsigned j = 0; j != NumRegs; ++j)
         Offsets.push_back(TypeSize::getFixed(Offset + j * RegisterVTByteSZ));
     }
@@ -11051,7 +11054,7 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
         if (I == RetTys.size() - 1)
           Flags.setInConsecutiveRegsLast();
       }
-      EVT VT = RetTys[I];
+      EVT VT = RetVTs[I];
       MVT RegisterVT = getRegisterTypeForCallingConv(CLI.RetTy->getContext(),
                                                      CLI.CallConv, VT);
       unsigned NumRegs = getNumRegistersForCallingConv(CLI.RetTy->getContext(),
@@ -11062,10 +11065,13 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
         MyFlags.VT = RegisterVT;
         MyFlags.ArgVT = VT;
         MyFlags.Used = CLI.IsReturnValueUsed;
-        if (CLI.RetTy->isPointerTy()) {
+        // Skip the pointer flag for libcalls lowered after type legalization:
+        // there RetVTs is split per register and no longer parallels RetTys.
+        PointerType *PtrTy;
+        if (!CLI.IsPostTypeLegalization &&
+            (PtrTy = dyn_cast<PointerType>(RetTys[I]))) {
           MyFlags.Flags.setPointer();
-          MyFlags.Flags.setPointerAddrSpace(
-              cast<PointerType>(CLI.RetTy)->getAddressSpace());
+          MyFlags.Flags.setPointerAddrSpace(PtrTy->getAddressSpace());
         }
         if (CLI.RetSExt)
           MyFlags.Flags.setSExt();
@@ -11096,17 +11102,17 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
   CLI.Outs.clear();
   CLI.OutVals.clear();
   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
-    SmallVector<EVT, 4> ValueVTs;
-    ComputeValueVTs(*this, DL, Args[i].Ty, ValueVTs);
+    SmallVector<Type *, 4> ValueTys;
+    ComputeValueTypes(*this, DL, Args[i].Ty, ValueTys);
     // FIXME: Split arguments if CLI.IsPostTypeLegalization
     Type *FinalType = Args[i].Ty;
     if (Args[i].IsByVal)
       FinalType = Args[i].IndirectType;
     bool NeedsRegBlock = functionArgumentNeedsConsecutiveRegisters(
         FinalType, CLI.CallConv, CLI.IsVarArg, DL);
-    for (unsigned Value = 0, NumValues = ValueVTs.size(); Value != NumValues;
+    for (unsigned Value = 0, NumValues = ValueTys.size(); Value != NumValues;
          ++Value) {
-      EVT VT = ValueVTs[Value];
+      EVT VT = getValueType(DL, ValueTys[Value]);
       Type *ArgTy = VT.getTypeForEVT(CLI.RetTy->getContext());
       SDValue Op = SDValue(Args[i].Node.getNode(),
                            Args[i].Node.getResNo() + Value);
@@ -11118,10 +11124,9 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
       const Align OriginalAlignment(getABIAlignmentForCallingConv(ArgTy, DL));
       Flags.setOrigAlign(OriginalAlignment);
 
-      if (Args[i].Ty->isPointerTy()) {
+      if (auto *PtrTy = dyn_cast<PointerType>(ValueTys[Value])) {
         Flags.setPointer();
-        Flags.setPointerAddrSpace(
-            cast<PointerType>(Args[i].Ty)->getAddressSpace());
+        Flags.setPointerAddrSpace(PtrTy->getAddressSpace());
       }
       if (Args[i].IsZExt)
         Flags.setZExt();
@@ -11215,7 +11220,7 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
                 (CLI.RetTy->isPointerTy() && Args[i].Ty->isPointerTy() &&
                  CLI.RetTy->getPointerAddressSpace() ==
                      Args[i].Ty->getPointerAddressSpace())) &&
-               RetTys.size() == NumValues && "unexpected use of 'returned'");
+               RetVTs.size() == NumValues && "unexpected use of 'returned'");
         // Before passing 'returned' to the target lowering code, ensure that
         // either the register MVT and the actual EVT are the same size or that
         // the return value and argument are extended in the same way; in these
@@ -11303,7 +11308,7 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
     assert(PVTs.size() == 1 && "Pointers should fit in one register");
     EVT PtrVT = PVTs[0];
 
-    unsigned NumValues = RetTys.size();
+    unsigned NumValues = RetVTs.size();
     ReturnValues.resize(NumValues);
     SmallVector<SDValue, 4> Chains(NumValues);
 
@@ -11317,7 +11322,7 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
                           CLI.DAG.getConstant(Offsets[i], CLI.DL, PtrVT),
                           SDNodeFlags::NoUnsignedWrap);
       SDValue L = CLI.DAG.getLoad(
-          RetTys[i], CLI.DL, CLI.Chain, Add,
+          RetVTs[i], CLI.DL, CLI.Chain, Add,
           MachinePointerInfo::getFixedStack(CLI.DAG.getMachineFunction(),
                                             DemoteStackIdx, Offsets[i]),
           HiddenSRetAlign);
@@ -11335,7 +11340,7 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
     else if (CLI.RetZExt)
       AssertOp = ISD::AssertZext;
     unsigned CurReg = 0;
-    for (EVT VT : RetTys) {
+    for (EVT VT : RetVTs) {
       MVT RegisterVT = getRegisterTypeForCallingConv(CLI.RetTy->getContext(),
                                                      CLI.CallConv, VT);
       unsigned NumRegs = getNumRegistersForCallingConv(CLI.RetTy->getContext(),
@@ -11355,7 +11360,7 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
   }
 
   SDValue Res = CLI.DAG.getNode(ISD::MERGE_VALUES, CLI.DL,
-                                CLI.DAG.getVTList(RetTys), ReturnValues);
+                                CLI.DAG.getVTList(RetVTs), ReturnValues);
   return std::make_pair(Res, CLI.Chain);
 }
 
@@ -11647,8 +11652,8 @@ void SelectionDAGISel::LowerArguments(const Function &F) {
   // Set up the incoming argument description vector.
   for (const Argument &Arg : F.args()) {
     unsigned ArgNo = Arg.getArgNo();
-    SmallVector<EVT, 4> ValueVTs;
-    ComputeValueVTs(*TLI, DAG.getDataLayout(), Arg.getType(), ValueVTs);
+    SmallVector<Type *, 4> ValueTys;
+    ComputeValueTypes(*TLI, DAG.getDataLayout(), Arg.getType(), ValueTys);
     bool isArgValueUsed = !Arg.use_empty();
     unsigned PartBase = 0;
     Type *FinalType = Arg.getType();
@@ -11656,17 +11661,15 @@ void SelectionDAGISel::LowerArguments(const Function &F) {
       FinalType = Arg.getParamByValType();
     bool NeedsRegBlock = TLI->functionArgumentNeedsConsecutiveRegisters(
         FinalType, F.getCallingConv(), F.isVarArg(), DL);
-    for (unsigned Value = 0, NumValues = ValueVTs.size();
+    for (unsigned Value = 0, NumValues = ValueTys.size();
          Value != NumValues; ++Value) {
-      EVT VT = ValueVTs[Value];
+      EVT VT = TLI->getValueType(DAG.getDataLayout(), ValueTys[Value]);
       Type *ArgTy = VT.getTypeForEVT(*DAG.getContext());
       ISD::ArgFlagsTy Flags;
 
-
-      if (Arg.getType()->isPointerTy()) {
+      if (auto *PtrTy = dyn_cast<PointerType>(ValueTys[Value])) {
         Flags.setPointer();
-        Flags.setPointerAddrSpace(
-            cast<PointerType>(Arg.getType())->getAddressSpace());
+        Flags.setPointerAddrSpace(PtrTy->getAddressSpace());
       }
       if (Arg.hasAttribute(Attribute::ZExt))
         Flags.setZExt();
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index 392cfbdd21273d..5232082f00e924 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -1659,13 +1659,12 @@ void llvm::GetReturnInfo(CallingConv::ID CC, Type *ReturnType,
                          AttributeList attr,
                          SmallVectorImpl<ISD::OutputArg> &Outs,
                          const TargetLowering &TLI, const DataLayout &DL) {
-  SmallVector<EVT, 4> ValueVTs;
-  ComputeValueVTs(TLI, DL, ReturnType, ValueVTs);
-  unsigned NumValues = ValueVTs.size();
-  if (NumValues == 0) return;
+  SmallVector<Type *, 4> ValueTys;
+  ComputeValueTypes(TLI, DL, ReturnType, ValueTys);
+  if (ValueTys.size() == 0) return;
 
-  for (unsigned j = 0, f = NumValues; j != f; ++j) {
-    EVT VT = ValueVTs[j];
+  for (Type *Ty : ValueTys) {
+    EVT VT = TLI.getValueType(DL, Ty);
     ISD::NodeType ExtendKind = ISD::ANY_EXTEND;
 
     if (attr.hasRetAttr(Attribute::SExt))
@@ -1686,6 +1685,11 @@ void llvm::GetReturnInfo(CallingConv::ID CC, Type *ReturnType,
     if (attr.hasRetAttr(Attribute::InReg))
       Flags.setInReg();
 
+    if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
+      Flags.setPointer();
+      Flags.setPointerAddrSpace(PtrTy->getAddressSpace());
+    }
+
     // Propagate extension type if any
     if (attr.hasRetAttr(Attribute::SExt))
       Flags.setSExt();



More information about the llvm-commits mailing list