[llvm-branch-commits] [llvm] [NFCI][IR] Thread `DataLayout` through `ConstantFold`; fix CAZ extraction and aggregate collapse (PR #183209)

Shilei Tian via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Feb 24 15:55:35 PST 2026


https://github.com/shiltian created https://github.com/llvm/llvm-project/pull/183209

Prepare the constant folding infrastructure for the `ConstantPointerNull`
semantic change, where null may have a non-zero bit pattern.

Thread `const DataLayout *DL = nullptr` through `ConstantFoldCastInstruction`,
`ConstantFoldCompareInstruction`, and `ConstantFoldGetElementPtr`. When DL is
present and the null pointer is not zero for the relevant address space, the
null-pointer cast and compare folds (e.g., ptrtoint null -> 0,
icmp uge X null -> true) are skipped so that the DL-aware folder can handle
them, instead of producing incorrect results. `ConstantFoldGetElementPtr` only
accepts DL for now; none of its folds consult it yet. Without DL, behavior is
unchanged.

Fix `ConstantAggregateZero` element extraction (`getSequentialElement`,
`getStructElement`, and `Constant::getSplatValue`) to return `getZeroValue`
(not `getNullValue`), ensuring CAZ always yields all-zero-bit elements
regardless of the address space's null pointer value.

Make aggregate collapse checks use `isZeroValue()` instead of `isNullValue()`,
so that only elements whose bit pattern is all zeros collapse into
`ConstantAggregateZero`. Aggregates of FP -0.0 (non-zero bit pattern) still do
not collapse (now covered by tests), and non-zero-null `ConstantPointerNull`
will not be collapsed incorrectly after the semantic change.

>From a899485cd96fa2ba4b95663404df032a2e716628 Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Sat, 14 Feb 2026 19:27:47 -0500
Subject: [PATCH] [NFCI][IR] Thread `DataLayout` through `ConstantFold`; fix
 CAZ extraction and aggregate collapse

Prepare the constant folding infrastructure for the `ConstantPointerNull`
semantic change, where null may have a non-zero bit pattern.

Thread `const DataLayout *DL = nullptr` through `ConstantFoldCastInstruction`,
`ConstantFoldCompareInstruction`, and `ConstantFoldGetElementPtr`. When DL is
present and the null pointer is not zero for the relevant address space, the
null-pointer cast and compare folds (e.g., ptrtoint null -> 0,
icmp uge X null -> true) are skipped so that the DL-aware folder can handle
them, instead of producing incorrect results. `ConstantFoldGetElementPtr` only
accepts DL for now; none of its folds consult it yet. Without DL, behavior is
unchanged.

Fix `ConstantAggregateZero` element extraction (`getSequentialElement`,
`getStructElement`, and `Constant::getSplatValue`) to return `getZeroValue`
(not `getNullValue`), ensuring CAZ always yields all-zero-bit elements
regardless of the address space's null pointer value.

Make aggregate collapse checks use `isZeroValue()` instead of `isNullValue()`,
so that only elements whose bit pattern is all zeros collapse into
`ConstantAggregateZero`. Aggregates of FP -0.0 (non-zero bit pattern) still do
not collapse (now covered by tests), and non-zero-null `ConstantPointerNull`
will not be collapsed incorrectly after the semantic change.
---
 llvm/include/llvm/IR/ConstantFold.h           |  15 +-
 llvm/lib/Analysis/ConstantFolding.cpp         |  20 +-
 llvm/lib/Analysis/InstructionSimplify.cpp     |   2 +-
 llvm/lib/IR/ConstantFold.cpp                  |  70 +++++--
 llvm/lib/IR/Constants.cpp                     |  22 +--
 .../RISCV/RISCVGatherScatterLowering.cpp      |   6 +-
 llvm/unittests/IR/ConstantsTest.cpp           | 187 +++++++++++++++++-
 7 files changed, 278 insertions(+), 44 deletions(-)

diff --git a/llvm/include/llvm/IR/ConstantFold.h b/llvm/include/llvm/IR/ConstantFold.h
index 4056f1feb4dd3..3eb8c66d0aacf 100644
--- a/llvm/include/llvm/IR/ConstantFold.h
+++ b/llvm/include/llvm/IR/ConstantFold.h
@@ -29,14 +29,15 @@ namespace llvm {
 template <typename T> class ArrayRef;
 class Value;
 class Constant;
+class DataLayout;
 class Type;
 
 // Constant fold various types of instruction...
 LLVM_ABI Constant *
 ConstantFoldCastInstruction(unsigned opcode, ///< The opcode of the cast
                             Constant *V,     ///< The source constant
-                            Type *DestTy     ///< The destination type
-);
+                            Type *DestTy,    ///< The destination type
+                            const DataLayout *DL = nullptr);
 
 /// Attempt to constant fold a select instruction with the specified
 /// operands. The constant result is returned if successful; if not, null is
@@ -80,12 +81,12 @@ LLVM_ABI Constant *ConstantFoldInsertValueInstruction(Constant *Agg,
 LLVM_ABI Constant *ConstantFoldUnaryInstruction(unsigned Opcode, Constant *V);
 LLVM_ABI Constant *ConstantFoldBinaryInstruction(unsigned Opcode, Constant *V1,
                                                  Constant *V2);
-LLVM_ABI Constant *ConstantFoldCompareInstruction(CmpInst::Predicate Predicate,
-                                                  Constant *C1, Constant *C2);
 LLVM_ABI Constant *
-ConstantFoldGetElementPtr(Type *Ty, Constant *C,
-                          std::optional<ConstantRange> InRange,
-                          ArrayRef<Value *> Idxs);
+ConstantFoldCompareInstruction(CmpInst::Predicate Predicate, Constant *C1,
+                               Constant *C2, const DataLayout *DL = nullptr);
+LLVM_ABI Constant *ConstantFoldGetElementPtr(
+    Type *Ty, Constant *C, std::optional<ConstantRange> InRange,
+    ArrayRef<Value *> Idxs, const DataLayout *DL = nullptr);
 } // namespace llvm
 
 #endif
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 7573afe423ec9..ae4554b0ab837 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1308,7 +1308,7 @@ Constant *llvm::ConstantFoldCompareInstOperands(
       return nullptr;
   }
 
-  return ConstantFoldCompareInstruction(Predicate, Ops0, Ops1);
+  return ConstantFoldCompareInstruction(Predicate, Ops0, Ops1, &DL);
 }
 
 Constant *llvm::ConstantFoldUnaryOpOperand(unsigned Opcode, Constant *Op,
@@ -1579,9 +1579,25 @@ Constant *llvm::ConstantFoldCastOperand(unsigned Opcode, Constant *C,
     return FoldBitCast(C, DestTy, DL);
   }
 
+  // DL-aware null folding for pointer casts. ConstantExpr::getCast below does
+  // not have DataLayout, so handle the null case here to ensure casts involving
+  // null pointers (e.g., inttoptr(0) -> null, ptrtoint(null) -> 0) still fold
+  // correctly when DataLayout confirms null is zero for the address space.
+  if (C->isNullValue() && !DestTy->isX86_AMXTy() &&
+      Opcode != Instruction::AddrSpaceCast) {
+    bool SrcIsPtr = C->getType()->isPtrOrPtrVectorTy();
+    bool DstIsPtr = DestTy->isPtrOrPtrVectorTy();
+    if (SrcIsPtr || DstIsPtr) {
+      unsigned AS = SrcIsPtr ? C->getType()->getPointerAddressSpace()
+                             : DestTy->getPointerAddressSpace();
+      if (DL.isNullPointerAllZeroes(AS))
+        return Constant::getNullValue(DestTy);
+    }
+  }
+
   if (ConstantExpr::isDesirableCastOp(Opcode))
     return ConstantExpr::getCast(Opcode, C, DestTy);
-  return ConstantFoldCastInstruction(Opcode, C, DestTy);
+  return ConstantFoldCastInstruction(Opcode, C, DestTy, &DL);
 }
 
 Constant *llvm::ConstantFoldIntegerCast(Constant *C, Type *DestTy,
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 3d5ee74c0e2e8..94145f74d8531 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -5353,7 +5353,7 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
 
   if (!ConstantExpr::isSupportedGetElementPtr(SrcTy))
     return ConstantFoldGetElementPtr(SrcTy, cast<Constant>(Ptr), std::nullopt,
-                                     Indices);
+                                     Indices, &Q.DL);
 
   auto *CE =
       ConstantExpr::getGetElementPtr(SrcTy, cast<Constant>(Ptr), Indices, NW);
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index 87a70391fbec4..ba47600dcf0e1 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -21,6 +21,7 @@
 #include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalAlias.h"
@@ -122,14 +123,16 @@ static Constant *FoldBitCast(Constant *V, Type *DestTy) {
 }
 
 static Constant *foldMaybeUndesirableCast(unsigned opc, Constant *V,
-                                          Type *DestTy) {
+                                          Type *DestTy,
+                                          const DataLayout *DL = nullptr) {
   return ConstantExpr::isDesirableCastOp(opc)
              ? ConstantExpr::getCast(opc, V, DestTy)
-             : ConstantFoldCastInstruction(opc, V, DestTy);
+             : ConstantFoldCastInstruction(opc, V, DestTy, DL);
 }
 
 Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
-                                            Type *DestTy) {
+                                            Type *DestTy,
+                                            const DataLayout *DL) {
   if (isa<PoisonValue>(V))
     return PoisonValue::get(DestTy);
 
@@ -144,8 +147,22 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
   }
 
   if (V->isNullValue() && !DestTy->isX86_AMXTy() &&
-      opc != Instruction::AddrSpaceCast)
+      opc != Instruction::AddrSpaceCast) {
+    // If the source or destination involves pointers and DL tells us that
+    // null is not zero for the relevant address space, we cannot fold here.
+    // Defer to the DL-aware folding in Analysis/ConstantFolding.cpp.
+    if (DL) {
+      bool SrcIsPtr = V->getType()->isPtrOrPtrVectorTy();
+      bool DstIsPtr = DestTy->isPtrOrPtrVectorTy();
+      if (SrcIsPtr || DstIsPtr) {
+        unsigned AS = SrcIsPtr ? V->getType()->getPointerAddressSpace()
+                               : DestTy->getPointerAddressSpace();
+        if (!DL->isNullPointerAllZeroes(AS))
+          return nullptr;
+      }
+    }
     return Constant::getNullValue(DestTy);
+  }
 
   // If the cast operand is a constant expression, there's a few things we can
   // do to try to simplify it.
@@ -153,7 +170,7 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
     if (CE->isCast()) {
       // Try hard to fold cast of cast because they are often eliminable.
       if (unsigned newOpc = foldConstantCastPair(opc, CE, DestTy))
-        return foldMaybeUndesirableCast(newOpc, CE->getOperand(0), DestTy);
+        return foldMaybeUndesirableCast(newOpc, CE->getOperand(0), DestTy, DL);
     }
   }
 
@@ -167,7 +184,7 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
     Type *DstEltTy = DestVecTy->getElementType();
     // Fast path for splatted constants.
     if (Constant *Splat = V->getSplatValue()) {
-      Constant *Res = foldMaybeUndesirableCast(opc, Splat, DstEltTy);
+      Constant *Res = foldMaybeUndesirableCast(opc, Splat, DstEltTy, DL);
       if (!Res)
         return nullptr;
       return ConstantVector::getSplat(
@@ -181,7 +198,7 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
                   e = cast<FixedVectorType>(V->getType())->getNumElements();
          i != e; ++i) {
       Constant *C = ConstantExpr::getExtractElement(V, ConstantInt::get(Ty, i));
-      Constant *Casted = foldMaybeUndesirableCast(opc, C, DstEltTy);
+      Constant *Casted = foldMaybeUndesirableCast(opc, C, DstEltTy, DL);
       if (!Casted)
         return nullptr;
       res.push_back(Casted);
@@ -1101,7 +1118,8 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2) {
 }
 
 Constant *llvm::ConstantFoldCompareInstruction(CmpInst::Predicate Predicate,
-                                               Constant *C1, Constant *C2) {
+                                               Constant *C1, Constant *C2,
+                                               const DataLayout *DL) {
   Type *ResultTy;
   if (VectorType *VT = dyn_cast<VectorType>(C1->getType()))
     ResultTy = VectorType::get(Type::getInt1Ty(C1->getContext()),
@@ -1139,14 +1157,25 @@ Constant *llvm::ConstantFoldCompareInstruction(CmpInst::Predicate Predicate,
   }
 
   if (C2->isNullValue()) {
-    // The caller is expected to commute the operands if the constant expression
-    // is C2.
-    // C1 >= 0 --> true
-    if (Predicate == ICmpInst::ICMP_UGE)
-      return Constant::getAllOnesValue(ResultTy);
-    // C1 < 0 --> false
-    if (Predicate == ICmpInst::ICMP_ULT)
-      return Constant::getNullValue(ResultTy);
+    // If DL tells us that null is not zero for this pointer's address space,
+    // we cannot rely on the null value being the unsigned minimum. Defer.
+    bool CanFoldNullCmp = true;
+    if (DL && C2->getType()->isPtrOrPtrVectorTy()) {
+      unsigned AS = C2->getType()->getPointerAddressSpace();
+      if (!DL->isNullPointerAllZeroes(AS))
+        CanFoldNullCmp = false;
+    }
+
+    if (CanFoldNullCmp) {
+      // The caller is expected to commute the operands if the constant
+      // expression is C2.
+      // C1 >= 0 --> true
+      if (Predicate == ICmpInst::ICMP_UGE)
+        return Constant::getAllOnesValue(ResultTy);
+      // C1 < 0 --> false
+      if (Predicate == ICmpInst::ICMP_ULT)
+        return Constant::getNullValue(ResultTy);
+    }
   }
 
   // If the comparison is a comparison between two i1's, simplify it.
@@ -1177,7 +1206,7 @@ Constant *llvm::ConstantFoldCompareInstruction(CmpInst::Predicate Predicate,
     if (Constant *C1Splat = C1->getSplatValue())
       if (Constant *C2Splat = C2->getSplatValue())
         if (Constant *Elt =
-                ConstantFoldCompareInstruction(Predicate, C1Splat, C2Splat))
+                ConstantFoldCompareInstruction(Predicate, C1Splat, C2Splat, DL))
           return ConstantVector::getSplat(C1VTy->getElementCount(), Elt);
 
     // Do not iterate on scalable vector. The number of elements is unknown at
@@ -1196,7 +1225,7 @@ Constant *llvm::ConstantFoldCompareInstruction(CmpInst::Predicate Predicate,
           ConstantExpr::getExtractElement(C1, ConstantInt::get(Ty, I));
       Constant *C2E =
           ConstantExpr::getExtractElement(C2, ConstantInt::get(Ty, I));
-      Constant *Elt = ConstantFoldCompareInstruction(Predicate, C1E, C2E);
+      Constant *Elt = ConstantFoldCompareInstruction(Predicate, C1E, C2E, DL);
       if (!Elt)
         return nullptr;
 
@@ -1308,7 +1337,7 @@ Constant *llvm::ConstantFoldCompareInstruction(CmpInst::Predicate Predicate,
       // other way if possible.
       // Also, if C1 is null and C2 isn't, flip them around.
       Predicate = ICmpInst::getSwappedPredicate(Predicate);
-      return ConstantFoldCompareInstruction(Predicate, C2, C1);
+      return ConstantFoldCompareInstruction(Predicate, C2, C1, DL);
     }
   }
   return nullptr;
@@ -1316,7 +1345,8 @@ Constant *llvm::ConstantFoldCompareInstruction(CmpInst::Predicate Predicate,
 
 Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C,
                                           std::optional<ConstantRange> InRange,
-                                          ArrayRef<Value *> Idxs) {
+                                          ArrayRef<Value *> Idxs,
+                                          const DataLayout *DL) {
   if (Idxs.empty()) return C;
 
   Type *GEPTy = GetElementPtrInst::getGEPReturnType(
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 5a011f50ef94b..53a627bd6a980 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -1163,12 +1163,12 @@ void ConstantFP::destroyConstantImpl() {
 
 Constant *ConstantAggregateZero::getSequentialElement() const {
   if (auto *AT = dyn_cast<ArrayType>(getType()))
-    return Constant::getNullValue(AT->getElementType());
-  return Constant::getNullValue(cast<VectorType>(getType())->getElementType());
+    return Constant::getZeroValue(AT->getElementType());
+  return Constant::getZeroValue(cast<VectorType>(getType())->getElementType());
 }
 
 Constant *ConstantAggregateZero::getStructElement(unsigned Elt) const {
-  return Constant::getNullValue(getType()->getStructElementType(Elt));
+  return Constant::getZeroValue(getType()->getStructElementType(Elt));
 }
 
 Constant *ConstantAggregateZero::getElementValue(Constant *C) const {
@@ -1368,7 +1368,7 @@ Constant *ConstantArray::getImpl(ArrayType *Ty, ArrayRef<Constant*> V) {
   if (isa<UndefValue>(C) && rangeOnlyContains(V.begin(), V.end(), C))
     return UndefValue::get(Ty);
 
-  if (C->isNullValue() && rangeOnlyContains(V.begin(), V.end(), C))
+  if (C->isZeroValue() && rangeOnlyContains(V.begin(), V.end(), C))
     return ConstantAggregateZero::get(Ty);
 
   // Check to see if all of the elements are ConstantFP or ConstantInt and if
@@ -1419,11 +1419,11 @@ Constant *ConstantStruct::get(StructType *ST, ArrayRef<Constant*> V) {
   if (!V.empty()) {
     isUndef = isa<UndefValue>(V[0]);
     isPoison = isa<PoisonValue>(V[0]);
-    isZero = V[0]->isNullValue();
+    isZero = V[0]->isZeroValue();
     // PoisonValue inherits UndefValue, so its check is not necessary.
     if (isUndef || isZero) {
       for (Constant *C : V) {
-        if (!C->isNullValue())
+        if (!C->isZeroValue())
           isZero = false;
         if (!isa<PoisonValue>(C))
           isPoison = false;
@@ -1464,7 +1464,7 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
   // If this is an all-undef or all-zero vector, return a
   // ConstantAggregateZero or UndefValue.
   Constant *C = V[0];
-  bool isZero = C->isNullValue();
+  bool isZero = C->isZeroValue();
   bool isUndef = isa<UndefValue>(C);
   bool isPoison = isa<PoisonValue>(C);
   bool isSplatFP = UseConstantFPForFixedLengthSplat && isa<ConstantFP>(C);
@@ -1535,7 +1535,7 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
 
   Type *VTy = VectorType::get(V->getType(), EC);
 
-  if (V->isNullValue())
+  if (V->isZeroValue())
     return ConstantAggregateZero::get(VTy);
   if (isa<PoisonValue>(V))
     return PoisonValue::get(VTy);
@@ -1745,7 +1745,7 @@ Constant *Constant::getSplatValue(bool AllowPoison) const {
   if (isa<PoisonValue>(this))
     return PoisonValue::get(cast<VectorType>(getType())->getElementType());
   if (isa<ConstantAggregateZero>(this))
-    return getNullValue(cast<VectorType>(getType())->getElementType());
+    return getZeroValue(cast<VectorType>(getType())->getElementType());
   if (auto *CI = dyn_cast<ConstantInt>(this))
     return ConstantInt::get(getContext(), CI->getValue());
   if (auto *CFP = dyn_cast<ConstantFP>(this))
@@ -3350,7 +3350,7 @@ Value *ConstantArray::handleOperandChangeImpl(Value *From, Value *To) {
     AllSame &= Val == ToC;
   }
 
-  if (AllSame && ToC->isNullValue())
+  if (AllSame && ToC->isZeroValue())
     return ConstantAggregateZero::get(getType());
 
   if (AllSame && isa<UndefValue>(ToC))
@@ -3390,7 +3390,7 @@ Value *ConstantStruct::handleOperandChangeImpl(Value *From, Value *To) {
     AllSame &= Val == ToC;
   }
 
-  if (AllSame && ToC->isNullValue())
+  if (AllSame && ToC->isZeroValue())
     return ConstantAggregateZero::get(getType());
 
   if (AllSame && isa<UndefValue>(ToC))
diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 25b5af8324e64..36088fe96d3a9 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -421,9 +421,11 @@ RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
     if (!VecIndexC)
       return std::make_pair(nullptr, nullptr);
     if (VecIndex->getType()->getScalarSizeInBits() > VecIntPtrTy->getScalarSizeInBits())
-      VecIndex = ConstantFoldCastInstruction(Instruction::Trunc, VecIndexC, VecIntPtrTy);
+      VecIndex = ConstantFoldCastInstruction(Instruction::Trunc, VecIndexC,
+                                             VecIntPtrTy, DL);
     else
-      VecIndex = ConstantFoldCastInstruction(Instruction::SExt, VecIndexC, VecIntPtrTy);
+      VecIndex = ConstantFoldCastInstruction(Instruction::SExt, VecIndexC,
+                                             VecIntPtrTy, DL);
   }
 
   // Handle the non-recursive case.  This is what we see if the vectorizer
diff --git a/llvm/unittests/IR/ConstantsTest.cpp b/llvm/unittests/IR/ConstantsTest.cpp
index 96730dbf16758..dcf9d980c7413 100644
--- a/llvm/unittests/IR/ConstantsTest.cpp
+++ b/llvm/unittests/IR/ConstantsTest.cpp
@@ -989,7 +989,192 @@ TEST(ConstantsTest, ZeroValueAPIs) {
             Constant::getNullValue(StructTy));
 
   // TODO: getNullValue slow path for aggregates with non-zero-null pointers is
-  // deferred to PR 3 testing (requires aggregate collapse fix).
+  // deferred to PR 4 testing (requires ConstantPointerNull semantic change).
+}
+
+TEST(ConstantsTest, AggregateCollapseAndCAZExtraction) {
+  LLVMContext Context;
+  Type *Int32Ty = Type::getInt32Ty(Context);
+  Type *FloatTy = Type::getFloatTy(Context);
+  PointerType *PtrTy = PointerType::get(Context, 0);
+
+  // --- ConstantAggregateZero element extraction returns getZeroValue ---
+  auto *ArrTy = ArrayType::get(Int32Ty, 3);
+  auto *CAZ = ConstantAggregateZero::get(ArrTy);
+  Constant *Elt = CAZ->getSequentialElement();
+  EXPECT_EQ(Elt, Constant::getZeroValue(Int32Ty));
+  // For pointer element types.
+  auto *PtrArrTy = ArrayType::get(PtrTy, 2);
+  auto *PtrCAZ = ConstantAggregateZero::get(PtrArrTy);
+  Constant *PtrElt = PtrCAZ->getSequentialElement();
+  EXPECT_EQ(PtrElt, Constant::getZeroValue(PtrTy));
+
+  // Struct element extraction.
+  auto *StructTy = StructType::get(Int32Ty, PtrTy, FloatTy);
+  auto *StructCAZ = ConstantAggregateZero::get(StructTy);
+  EXPECT_EQ(StructCAZ->getStructElement(0), Constant::getZeroValue(Int32Ty));
+  EXPECT_EQ(StructCAZ->getStructElement(1), Constant::getZeroValue(PtrTy));
+  EXPECT_EQ(StructCAZ->getStructElement(2), Constant::getZeroValue(FloatTy));
+
+  // --- Zero-valued aggregates collapse to ConstantAggregateZero ---
+  Constant *ZeroI32 = Constant::getZeroValue(Int32Ty);
+  Constant *ZeroFloat = Constant::getZeroValue(FloatTy);
+  Constant *ZeroPtr = Constant::getZeroValue(PtrTy);
+
+  // Array of zero ints collapses.
+  Constant *ZeroArr = ConstantArray::get(ArrTy, {ZeroI32, ZeroI32, ZeroI32});
+  EXPECT_TRUE(isa<ConstantAggregateZero>(ZeroArr));
+
+  // Vector of zero ints collapses.
+  Constant *ZeroVec = ConstantVector::get({ZeroI32, ZeroI32, ZeroI32, ZeroI32});
+  EXPECT_TRUE(isa<ConstantAggregateZero>(ZeroVec));
+
+  // Struct of zeros collapses.
+  Constant *ZeroStruct =
+      ConstantStruct::get(StructTy, {ZeroI32, ZeroPtr, ZeroFloat});
+  EXPECT_TRUE(isa<ConstantAggregateZero>(ZeroStruct));
+
+  // Splat of zero collapses.
+  Constant *SplatZero =
+      ConstantVector::getSplat(ElementCount::getFixed(4), ZeroI32);
+  EXPECT_TRUE(isa<ConstantAggregateZero>(SplatZero));
+
+  // --- FP -0.0 does NOT collapse to ConstantAggregateZero ---
+  // -0.0 has a non-zero bit pattern (sign bit set), so it must not collapse.
+  Constant *NegZeroFP = ConstantFP::get(
+      FloatTy, APFloat::getZero(APFloat::IEEEsingle(), /*Negative=*/true));
+  EXPECT_NE(NegZeroFP, Constant::getZeroValue(FloatTy));
+
+  auto *FloatArrTy = ArrayType::get(FloatTy, 2);
+  Constant *NegZeroArr = ConstantArray::get(FloatArrTy, {NegZeroFP, NegZeroFP});
+  EXPECT_FALSE(isa<ConstantAggregateZero>(NegZeroArr));
+
+  Constant *NegZeroVec = ConstantVector::get({NegZeroFP, NegZeroFP});
+  EXPECT_FALSE(isa<ConstantAggregateZero>(NegZeroVec));
+
+  auto *FloatStructTy = StructType::get(FloatTy, FloatTy);
+  Constant *NegZeroStruct =
+      ConstantStruct::get(FloatStructTy, {NegZeroFP, NegZeroFP});
+  EXPECT_FALSE(isa<ConstantAggregateZero>(NegZeroStruct));
+
+  Constant *NegZeroSplat =
+      ConstantVector::getSplat(ElementCount::getFixed(4), NegZeroFP);
+  EXPECT_FALSE(isa<ConstantAggregateZero>(NegZeroSplat));
+
+  // --- getSplatValue for CAZ returns getZeroValue ---
+  auto *IntVecTy = FixedVectorType::get(Int32Ty, 4);
+  auto *IntVecCAZ = ConstantAggregateZero::get(IntVecTy);
+  Constant *SplatVal = IntVecCAZ->getSplatValue();
+  EXPECT_EQ(SplatVal, Constant::getZeroValue(Int32Ty));
+
+  auto *PtrVecTy = FixedVectorType::get(PtrTy, 2);
+  auto *PtrVecCAZ = ConstantAggregateZero::get(PtrVecTy);
+  Constant *PtrSplatVal = PtrVecCAZ->getSplatValue();
+  EXPECT_EQ(PtrSplatVal, Constant::getZeroValue(PtrTy));
+}
+
+TEST(ConstantsTest, ConstantFoldCastWithDL) {
+  LLVMContext Context;
+  // A DataLayout where AS 1 has all-ones null pointer.
+  DataLayout AllOnesDL("e-po1:64:64");
+  // A DataLayout where all address spaces have zero null (the default).
+  DataLayout DefaultDL("e-p:64:64");
+
+  Type *Int64Ty = Type::getInt64Ty(Context);
+  PointerType *PtrTy0 = PointerType::get(Context, 0);
+  PointerType *PtrTy1 = PointerType::get(Context, 1);
+
+  // --- Without DL, null pointer casts fold normally ---
+  Constant *NullPtr0 = ConstantPointerNull::get(PtrTy0);
+  Constant *NullPtr1 = ConstantPointerNull::get(PtrTy1);
+
+  // ptrtoint(null AS0) -> 0 (no DL)
+  Constant *Result =
+      ConstantFoldCastInstruction(Instruction::PtrToInt, NullPtr0, Int64Ty);
+  ASSERT_NE(Result, nullptr);
+  EXPECT_TRUE(Result->isNullValue());
+
+  // ptrtoint(null AS1) -> 0 (no DL, backward compat)
+  Result =
+      ConstantFoldCastInstruction(Instruction::PtrToInt, NullPtr1, Int64Ty);
+  ASSERT_NE(Result, nullptr);
+  EXPECT_TRUE(Result->isNullValue());
+
+  // --- With DefaultDL, null pointer casts still fold (AS 0 is zero null) ---
+  Result = ConstantFoldCastInstruction(Instruction::PtrToInt, NullPtr0, Int64Ty,
+                                       &DefaultDL);
+  ASSERT_NE(Result, nullptr);
+  EXPECT_TRUE(Result->isNullValue());
+
+  // --- With AllOnesDL, AS 1 null cast is deferred ---
+  // ptrtoint(null AS1) should return nullptr (defer to DL-aware folder).
+  Result = ConstantFoldCastInstruction(Instruction::PtrToInt, NullPtr1, Int64Ty,
+                                       &AllOnesDL);
+  EXPECT_EQ(Result, nullptr);
+
+  // inttoptr(0, AS1) should also be deferred.
+  Constant *ZeroI64 = ConstantInt::get(Int64Ty, 0);
+  Result = ConstantFoldCastInstruction(Instruction::IntToPtr, ZeroI64, PtrTy1,
+                                       &AllOnesDL);
+  EXPECT_EQ(Result, nullptr);
+
+  // But AS 0 with AllOnesDL still folds fine.
+  Result = ConstantFoldCastInstruction(Instruction::PtrToInt, NullPtr0, Int64Ty,
+                                       &AllOnesDL);
+  ASSERT_NE(Result, nullptr);
+  EXPECT_TRUE(Result->isNullValue());
+}
+
+TEST(ConstantsTest, ConstantFoldCompareWithDL) {
+  LLVMContext Context;
+  DataLayout AllOnesDL("e-po1:64:64");
+  DataLayout DefaultDL("e-p:64:64");
+
+  PointerType *PtrTy0 = PointerType::get(Context, 0);
+  PointerType *PtrTy1 = PointerType::get(Context, 1);
+
+  Constant *NullPtr0 = ConstantPointerNull::get(PtrTy0);
+  Constant *NullPtr1 = ConstantPointerNull::get(PtrTy1);
+
+  // Create a non-null pointer constant expression for comparison.
+  Type *Int64Ty = Type::getInt64Ty(Context);
+  Constant *One = ConstantInt::get(Int64Ty, 1);
+  Constant *NonNullPtr0 = ConstantExpr::getIntToPtr(One, PtrTy0);
+  Constant *NonNullPtr1 = ConstantExpr::getIntToPtr(One, PtrTy1);
+
+  // --- Without DL, unsigned null comparisons fold ---
+  // ptr >= null -> true (always, since null is the unsigned minimum)
+  Constant *Result =
+      ConstantFoldCompareInstruction(CmpInst::ICMP_UGE, NonNullPtr0, NullPtr0);
+  ASSERT_NE(Result, nullptr);
+  EXPECT_TRUE(Result->isAllOnesValue());
+
+  // ptr < null -> false
+  Result =
+      ConstantFoldCompareInstruction(CmpInst::ICMP_ULT, NonNullPtr0, NullPtr0);
+  ASSERT_NE(Result, nullptr);
+  EXPECT_TRUE(Result->isNullValue());
+
+  // --- With AllOnesDL, AS 1 unsigned null comparisons are deferred ---
+  Result = ConstantFoldCompareInstruction(CmpInst::ICMP_UGE, NonNullPtr1,
+                                          NullPtr1, &AllOnesDL);
+  EXPECT_EQ(Result, nullptr);
+
+  Result = ConstantFoldCompareInstruction(CmpInst::ICMP_ULT, NonNullPtr1,
+                                          NullPtr1, &AllOnesDL);
+  EXPECT_EQ(Result, nullptr);
+
+  // --- With AllOnesDL, AS 0 still folds (zero null) ---
+  Result = ConstantFoldCompareInstruction(CmpInst::ICMP_UGE, NonNullPtr0,
+                                          NullPtr0, &AllOnesDL);
+  ASSERT_NE(Result, nullptr);
+  EXPECT_TRUE(Result->isAllOnesValue());
+
+  // --- With DefaultDL, everything folds normally ---
+  Result = ConstantFoldCompareInstruction(CmpInst::ICMP_UGE, NonNullPtr0,
+                                          NullPtr0, &DefaultDL);
+  ASSERT_NE(Result, nullptr);
+  EXPECT_TRUE(Result->isAllOnesValue());
 }
 
 } // end anonymous namespace



More information about the llvm-branch-commits mailing list