[clang] [llvm] [IR] Add getelementptr nusw and nuw flags (PR #90824)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Wed May 1 23:30:32 PDT 2024


https://github.com/nikic created https://github.com/llvm/llvm-project/pull/90824

This implements the `nusw` and `nuw` flags for `getelementptr` as proposed at https://discourse.llvm.org/t/rfc-add-nusw-and-nuw-flags-for-getelementptr/78672.

There are a number of places annotated with `TODO(gep_nowrap)` where I had to touch code but opted not to infer or precisely preserve the new flags. This keeps the patch as close to NFC as possible and ensures that any such changes get test coverage when they are made.

From eb27a1b94ec807323d204b51d5c01cc22056e1c7 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 2 May 2024 12:11:18 +0900
Subject: [PATCH 1/2] Add support for getelementptr nusw and nuw

---
 llvm/docs/LangRef.rst                         | 57 ++++++++++++------
 llvm/docs/ReleaseNotes.rst                    |  1 +
 llvm/include/llvm/AsmParser/LLToken.h         |  1 +
 llvm/include/llvm/Bitcode/LLVMBitCodes.h      |  8 +++
 llvm/include/llvm/IR/Instructions.h           | 14 +++++
 llvm/include/llvm/IR/Operator.h               | 26 +++++++-
 llvm/lib/AsmParser/LLLexer.cpp                |  1 +
 llvm/lib/AsmParser/LLParser.cpp               | 21 ++++++-
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp     | 20 +++++--
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp     | 11 +++-
 llvm/lib/IR/AsmWriter.cpp                     |  4 ++
 llvm/lib/IR/Instruction.cpp                   | 22 +++++--
 llvm/lib/IR/Instructions.cpp                  | 16 +++++
 llvm/lib/IR/Operator.cpp                      |  3 +-
 .../Scalar/SeparateConstOffsetFromGEP.cpp     | 13 ++++
 .../Transforms/Utils/FunctionComparator.cpp   |  6 ++
 llvm/lib/Transforms/Vectorize/VPlan.h         |  5 ++
 llvm/test/Assembler/flags.ll                  | 43 +++++++++++++
 llvm/test/Transforms/InstCombine/freeze.ll    | 22 +++++++
 llvm/test/Transforms/SimplifyCFG/HoistCode.ll | 60 +++++++++++++++++++
 llvm/test/tools/llvm-reduce/reduce-flags.ll   | 18 ++++--
 .../deltas/ReduceInstructionFlags.cpp         |  4 ++
 22 files changed, 338 insertions(+), 38 deletions(-)

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 6291a4e57919a5..7aeed82ab84df7 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -11180,6 +11180,8 @@ Syntax:
 
       <result> = getelementptr <ty>, ptr <ptrval>{, <ty> <idx>}*
       <result> = getelementptr inbounds <ty>, ptr <ptrval>{, <ty> <idx>}*
+      <result> = getelementptr nusw <ty>, ptr <ptrval>{, <ty> <idx>}*
+      <result> = getelementptr nuw <ty>, ptr <ptrval>{, <ty> <idx>}*
       <result> = getelementptr inrange(S,E) <ty>, ptr <ptrval>{, <ty> <idx>}*
       <result> = getelementptr <ty>, <N x ptr> <ptrval>, <vector index type> <idx>
 
@@ -11295,27 +11297,46 @@ memory though, even if it happens to point into allocated storage. See the
 :ref:`Pointer Aliasing Rules <pointeraliasing>` section for more
 information.
 
-If the ``inbounds`` keyword is present, the result value of a
-``getelementptr`` with any non-zero indices is a
-:ref:`poison value <poisonvalues>` if one of the following rules is violated:
-
-*  The base pointer has an *in bounds* address of an allocated object, which
+The ``getelementptr`` instruction may have a number of attributes that impose
+additional rules. If any of these rules is violated, the result value is a
+:ref:`poison value <poisonvalues>`. In cases where the base is a vector of
+pointers, the attributes apply to each computation element-wise.
+
+For ``nusw`` (no unsigned signed wrap):
+
+ * If the type of an index is larger than the pointer index type, the
+   truncation to the pointer index type preserves the signed value
+   (``trunc nsw``).
+ * The multiplication of an index by the type size does not wrap the pointer
+   index type in a signed sense (``mul nsw``).
+ * The successive addition of each offset (without adding the base address)
+   does not wrap the pointer index type in a signed sense (``add nsw``).
+ * The successive addition of the current address, truncated to the pointer index
+   type and interpreted as an unsigned number, and each offset, interpreted as
+   a signed number, does not wrap the pointer index type.
+
+For ``nuw`` (no unsigned wrap):
+
+ * If the type of an index is larger than the pointer index type, the
+   truncation to the pointer index type preserves the unsigned value
+   (``trunc nuw``).
+ * The multiplication of an index by the type size does not wrap the pointer
+   index type in an unsigned sense (``mul nuw``).
+ * The successive addition of each offset (without adding the base address)
+   does not wrap the pointer index type in an unsigned sense (``add nuw``).
+ * The successive addition of the current address, truncated to the pointer index
+   type and interpreted as an unsigned number, and each offset, also interpreted
+   as an unsigned number, does not wrap the pointer index type (``add nuw``).
+
+For ``inbounds``, all rules of the ``nusw`` attribute apply. Additionally,
+if the ``getelementptr`` has any non-zero indices, the following rules apply:
+
+ * The base pointer has an *in bounds* address of an allocated object, which
    means that it points into an allocated object, or to its end. Note that the
    object does not have to be live anymore; being in-bounds of a deallocated
    object is sufficient.
-*  If the type of an index is larger than the pointer index type, the
-   truncation to the pointer index type preserves the signed value.
-*  The multiplication of an index by the type size does not wrap the pointer
-   index type in a signed sense (``nsw``).
-*  The successive addition of each offset (without adding the base address) does
-   not wrap the pointer index type in a signed sense (``nsw``).
-*  The successive addition of the current address, interpreted as an unsigned
-   number, and each offset, interpreted as a signed number, does not wrap the
-   unsigned address space and remains *in bounds* of the allocated object.
-   As a corollary, if the added offset is non-negative, the addition does not
-   wrap in an unsigned sense (``nuw``).
-*  In cases where the base is a vector of pointers, the ``inbounds`` keyword
-   applies to each of the computations element-wise.
+ * During the successive addition of offsets to the address, the resulting
+   pointer must remain *in bounds* of the allocated object at each step.
 
 Note that ``getelementptr`` with all-zero indices is always considered to be
 ``inbounds``, even if the base pointer does not point to an allocated object.
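To make the new ``nusw`` and ``nuw`` rules concrete, here is a minimal standalone C++ model of the checks for a 64-bit pointer index type. It is a sketch under stated assumptions, not LLVM code. Indices are assumed to already have the index type, so the truncation rule is skipped. The extra ``inbounds`` rules are not modeled. `GEPStep`, `satisfiesNUSW` and `satisfiesNUW` are hypothetical names. The overflow checks use the GCC/Clang `__builtin_mul_overflow` and `__builtin_add_overflow` builtins.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical: one GEP index, assumed to already have the pointer index
// type (i64 here), together with the byte size of the type it scales.
struct GEPStep {
  int64_t Index;
  int64_t TypeSize;
};

// Hypothetical model of the nusw rules; false means the GEP is poison.
bool satisfiesNUSW(uint64_t Base, const std::vector<GEPStep> &Steps) {
  int64_t TotalOffset = 0; // running sum of offsets, base excluded
  uint64_t Addr = Base;    // running address, interpreted as unsigned
  for (const GEPStep &S : Steps) {
    int64_t Offset;
    // mul nsw: Index * TypeSize must not wrap in a signed sense.
    if (__builtin_mul_overflow(S.Index, S.TypeSize, &Offset))
      return false;
    // add nsw: the offsets summed without the base must not wrap.
    if (__builtin_add_overflow(TotalOffset, Offset, &TotalOffset))
      return false;
    // Unsigned address plus signed offset must stay within [0, 2^64).
    if (Offset >= 0) {
      uint64_t Up = static_cast<uint64_t>(Offset);
      if (Addr > UINT64_MAX - Up)
        return false;
      Addr += Up;
    } else {
      // Magnitude of a negative offset; well defined even for INT64_MIN.
      uint64_t Down = 0 - static_cast<uint64_t>(Offset);
      if (Down > Addr)
        return false;
      Addr -= Down;
    }
  }
  return true;
}

// Hypothetical model of the nuw rules: every quantity is unsigned.
bool satisfiesNUW(uint64_t Base, const std::vector<GEPStep> &Steps) {
  uint64_t TotalOffset = 0;
  uint64_t Addr = Base;
  for (const GEPStep &S : Steps) {
    uint64_t Offset;
    // mul nuw: the index is reinterpreted as an unsigned number.
    if (__builtin_mul_overflow(static_cast<uint64_t>(S.Index),
                               static_cast<uint64_t>(S.TypeSize), &Offset))
      return false;
    // add nuw on the offsets alone, then on the address plus each offset.
    if (__builtin_add_overflow(TotalOffset, Offset, &TotalOffset))
      return false;
    if (__builtin_add_overflow(Addr, Offset, &Addr))
      return false;
  }
  return true;
}
```

A negative index is usually fine under ``nusw`` but is poison under ``nuw``, because reinterpreting `-1` as unsigned makes the multiplication or the address addition wrap. The model also shows why the rules talk about *successive* additions. With `Base = 0x10`, the steps `{-3, 8}` and then `{3, 8}` net to zero. The first step would still take the address below zero, so the result is poison.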
diff --git a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst
index d8cc667723f554..412b85456cdbeb 100644
--- a/llvm/docs/ReleaseNotes.rst
+++ b/llvm/docs/ReleaseNotes.rst
@@ -51,6 +51,7 @@ Changes to the LLVM IR
 ----------------------
 
 * Added Memory Model Relaxation Annotations (MMRAs).
+* Added ``nusw`` and ``nuw`` flags to the ``getelementptr`` instruction.
 * Renamed ``llvm.experimental.vector.reverse`` intrinsic to ``llvm.vector.reverse``.
 * Renamed ``llvm.experimental.vector.splice`` intrinsic to ``llvm.vector.splice``.
 * Renamed ``llvm.experimental.vector.interleave2`` intrinsic to ``llvm.vector.interleave2``.
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 0cbcdcd9ffac77..df61ec6ed30e0b 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -109,6 +109,7 @@ enum Kind {
   kw_fast,
   kw_nuw,
   kw_nsw,
+  kw_nusw,
   kw_exact,
   kw_disjoint,
   kw_inbounds,
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 909eb833c601a9..1fce358a92e548 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -524,6 +524,14 @@ enum PossiblyExactOperatorOptionalFlags { PEO_EXACT = 0 };
 /// PossiblyDisjointInst's SubclassOptionalData contents.
 enum PossiblyDisjointInstOptionalFlags { PDI_DISJOINT = 0 };
 
+/// GetElementPtrOptionalFlags - Flags for serializing
+/// GEPOperator's SubclassOptionalData contents.
+enum GetElementPtrOptionalFlags {
+  GEP_INBOUNDS = 0,
+  GEP_NUSW = 1,
+  GEP_NUW = 2,
+};
+
 /// Encoded AtomicOrdering values.
 enum AtomicOrderingCodes {
   ORDERING_NOTATOMIC = 0,
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index d7ec3c16bec21c..8c0db7b7bfdb2e 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -1171,9 +1171,23 @@ class GetElementPtrInst : public Instruction {
   /// See LangRef.html for the meaning of inbounds on a getelementptr.
   void setIsInBounds(bool b = true);
 
+  /// Set or clear the nusw flag on this GEP instruction.
+  /// See LangRef.html for the meaning of nusw on a getelementptr.
+  void setHasNoUnsignedSignedWrap(bool B = true);
+
+  /// Set or clear the nuw flag on this GEP instruction.
+  /// See LangRef.html for the meaning of nuw on a getelementptr.
+  void setHasNoUnsignedWrap(bool B = true);
+
   /// Determine whether the GEP has the inbounds flag.
   bool isInBounds() const;
 
+  /// Determine whether the GEP has the nusw flag.
+  bool hasNoUnsignedSignedWrap() const;
+
+  /// Determine whether the GEP has the nuw flag.
+  bool hasNoUnsignedWrap() const;
+
   /// Accumulate the constant address offset of this GEP if possible.
   ///
   /// This routine accepts an APInt into which it will accumulate the constant
diff --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
index b2307948bbbc4f..637542397cd5d8 100644
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -405,11 +405,27 @@ class GEPOperator
 
   enum {
     IsInBounds = (1 << 0),
+    HasNoUnsignedSignedWrap = (1 << 1),
+    HasNoUnsignedWrap = (1 << 2),
   };
 
   void setIsInBounds(bool B) {
+    // Also set nusw when inbounds is set.
+    SubclassOptionalData = (SubclassOptionalData & ~IsInBounds) |
+                           (B * (IsInBounds | HasNoUnsignedSignedWrap));
+  }
+
+  void setHasNoUnsignedSignedWrap(bool B) {
+    // Also unset inbounds when nusw is unset.
+    if (B)
+      SubclassOptionalData |= HasNoUnsignedSignedWrap;
+    else
+      SubclassOptionalData &= ~(IsInBounds | HasNoUnsignedSignedWrap);
+  }
+
+  void setHasNoUnsignedWrap(bool B) {
     SubclassOptionalData =
-      (SubclassOptionalData & ~IsInBounds) | (B * IsInBounds);
+        (SubclassOptionalData & ~HasNoUnsignedWrap) | (B * HasNoUnsignedWrap);
   }
 
 public:
@@ -421,6 +437,14 @@ class GEPOperator
     return SubclassOptionalData & IsInBounds;
   }
 
+  bool hasNoUnsignedSignedWrap() const {
+    return SubclassOptionalData & HasNoUnsignedSignedWrap;
+  }
+
+  bool hasNoUnsignedWrap() const {
+    return SubclassOptionalData & HasNoUnsignedWrap;
+  }
+
   /// Returns the offset of the index with an inrange attachment, or
   /// std::nullopt if none.
   std::optional<ConstantRange> getInRange() const;
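The setters above encode "inbounds implies nusw" asymmetrically. Setting inbounds also sets nusw, and clearing nusw also clears inbounds. Clearing inbounds, however, leaves nusw in place. The following standalone sketch mirrors that bit logic. `GEPFlagsModel` is a hypothetical stand-in for `GEPOperator`'s `SubclassOptionalData` handling, not an LLVM class.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical mirror of the GEPOperator flag bits from Operator.h.
struct GEPFlagsModel {
  enum : uint8_t {
    IsInBounds = 1 << 0,
    HasNoUnsignedSignedWrap = 1 << 1,
    HasNoUnsignedWrap = 1 << 2,
  };
  uint8_t Bits = 0;

  // Setting inbounds also sets nusw; clearing inbounds leaves nusw alone.
  void setIsInBounds(bool B) {
    Bits = static_cast<uint8_t>(
        (Bits & ~IsInBounds) |
        (B ? (IsInBounds | HasNoUnsignedSignedWrap) : 0));
  }
  // Clearing nusw also clears inbounds, since inbounds implies nusw.
  void setHasNoUnsignedSignedWrap(bool B) {
    if (B)
      Bits |= HasNoUnsignedSignedWrap;
    else
      Bits &= static_cast<uint8_t>(~(IsInBounds | HasNoUnsignedSignedWrap));
  }
  // nuw is independent of the other two flags.
  void setHasNoUnsignedWrap(bool B) {
    Bits = static_cast<uint8_t>((Bits & ~HasNoUnsignedWrap) |
                                (B ? HasNoUnsignedWrap : 0));
  }

  bool isInBounds() const { return Bits & IsInBounds; }
  bool hasNoUnsignedSignedWrap() const { return Bits & HasNoUnsignedSignedWrap; }
  bool hasNoUnsignedWrap() const { return Bits & HasNoUnsignedWrap; }
};
```

Because of this asymmetry, code that drops only inbounds, such as `andIRFlags` intersecting `inbounds` with `nusw`, still ends up with a valid `nusw` GEP.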
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 8ded07ffd8bd25..20a1bd29577124 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -566,6 +566,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(fast);
   KEYWORD(nuw);
   KEYWORD(nsw);
+  KEYWORD(nusw);
   KEYWORD(exact);
   KEYWORD(disjoint);
   KEYWORD(inbounds);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 2902bd9fe17c48..976e19479396d9 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -8340,7 +8340,17 @@ int LLParser::parseGetElementPtr(Instruction *&Inst, PerFunctionState &PFS) {
   Value *Val = nullptr;
   LocTy Loc, EltLoc;
 
-  bool InBounds = EatIfPresent(lltok::kw_inbounds);
+  bool InBounds = false, NUSW = false, NUW = false;
+  while (true) {
+    if (EatIfPresent(lltok::kw_inbounds))
+      InBounds = true;
+    else if (EatIfPresent(lltok::kw_nusw))
+      NUSW = true;
+    else if (EatIfPresent(lltok::kw_nuw))
+      NUW = true;
+    else
+      break;
+  }
 
   Type *Ty = nullptr;
   if (parseType(Ty) ||
@@ -8393,9 +8403,14 @@ int LLParser::parseGetElementPtr(Instruction *&Inst, PerFunctionState &PFS) {
 
   if (!GetElementPtrInst::getIndexedType(Ty, Indices))
     return error(Loc, "invalid getelementptr indices");
-  Inst = GetElementPtrInst::Create(Ty, Ptr, Indices);
+  GetElementPtrInst *GEP = GetElementPtrInst::Create(Ty, Ptr, Indices);
+  Inst = GEP;
   if (InBounds)
-    cast<GetElementPtrInst>(Inst)->setIsInBounds(true);
+    GEP->setIsInBounds(true);
+  if (NUSW)
+    GEP->setHasNoUnsignedSignedWrap(true);
+  if (NUW)
+    GEP->setHasNoUnsignedWrap(true);
   return AteExtraComma ? InstExtraComma : InstNormal;
 }
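The parser change accepts the three GEP keywords in any order, and repeating one is harmless. This is why `getelementptr nuw nusw inbounds` in `flags.ll` round-trips. Below is a token-level sketch of that loop. `ParsedGEPFlags` and `parseGEPFlags` are hypothetical names, and the real parser works on `lltok` tokens rather than strings.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical result of parsing the optional GEP flag keywords.
struct ParsedGEPFlags {
  bool InBounds = false;
  bool NUSW = false;
  bool NUW = false;
};

// Consumes flag keywords in any order until the first non-flag token and
// returns the index of that token (the start of the source element type).
std::size_t parseGEPFlags(const std::vector<std::string> &Toks,
                          ParsedGEPFlags &F) {
  std::size_t I = 0;
  while (I < Toks.size()) {
    if (Toks[I] == "inbounds")
      F.InBounds = true;
    else if (Toks[I] == "nusw")
      F.NUSW = true;
    else if (Toks[I] == "nuw")
      F.NUW = true;
    else
      break;
    ++I;
  }
  return I;
}
```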
 
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index a0779f955cf28d..099f594d83b4da 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -5061,10 +5061,17 @@ Error BitcodeReader::parseFunctionBody(Function *F) {
 
       unsigned TyID;
       Type *Ty;
-      bool InBounds;
+      bool InBounds = false, NUSW = false, NUW = false;
 
       if (BitCode == bitc::FUNC_CODE_INST_GEP) {
-        InBounds = Record[OpNum++];
+        uint64_t Flags = Record[OpNum++];
+        if (Flags & (1 << bitc::GEP_INBOUNDS))
+          InBounds = true;
+        if (Flags & (1 << bitc::GEP_NUSW))
+          NUSW = true;
+        if (Flags & (1 << bitc::GEP_NUW))
+          NUW = true;
+
         TyID = Record[OpNum++];
         Ty = getTypeByID(TyID);
       } else {
@@ -5095,7 +5102,8 @@ Error BitcodeReader::parseFunctionBody(Function *F) {
         GEPIdx.push_back(Op);
       }
 
-      I = GetElementPtrInst::Create(Ty, BasePtr, GEPIdx);
+      auto *GEP = GetElementPtrInst::Create(Ty, BasePtr, GEPIdx);
+      I = GEP;
 
       ResTypeID = TyID;
       if (cast<GEPOperator>(I)->getNumIndices() != 0) {
@@ -5122,7 +5130,11 @@ Error BitcodeReader::parseFunctionBody(Function *F) {
 
       InstructionList.push_back(I);
       if (InBounds)
-        cast<GetElementPtrInst>(I)->setIsInBounds(true);
+        GEP->setIsInBounds(true);
+      if (NUSW)
+        GEP->setHasNoUnsignedSignedWrap(true);
+      if (NUW)
+        GEP->setHasNoUnsignedWrap(true);
       break;
     }
 
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 1aaf160e91ca18..2627ce02286105 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -2961,7 +2961,14 @@ void ModuleBitcodeWriter::writeInstruction(const Instruction &I,
     Code = bitc::FUNC_CODE_INST_GEP;
     AbbrevToUse = FUNCTION_INST_GEP_ABBREV;
     auto &GEPInst = cast<GetElementPtrInst>(I);
-    Vals.push_back(GEPInst.isInBounds());
+    uint64_t Flags = 0;
+    if (GEPInst.isInBounds())
+      Flags |= 1 << bitc::GEP_INBOUNDS;
+    if (GEPInst.hasNoUnsignedSignedWrap())
+      Flags |= 1 << bitc::GEP_NUSW;
+    if (GEPInst.hasNoUnsignedWrap())
+      Flags |= 1 << bitc::GEP_NUW;
+    Vals.push_back(Flags);
     Vals.push_back(VE.getTypeID(GEPInst.getSourceElementType()));
     for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i)
       pushValueAndType(I.getOperand(i), InstID, Vals);
@@ -3859,7 +3866,7 @@ void ModuleBitcodeWriter::writeBlockInfo() {
   {
     auto Abbv = std::make_shared<BitCodeAbbrev>();
     Abbv->Add(BitCodeAbbrevOp(bitc::FUNC_CODE_INST_GEP));
-    Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 1));
+    Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 3));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, // dest ty
                               Log2_32_Ceil(VE.getTypes().size() + 1)));
     Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Array));
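The GEP record field that used to hold a single inbounds bool now holds a flag word, so the fixed abbreviation width grows from 1 to 3 bits. Bit 0 is still inbounds, so a record written by an older producer (value `1`) still decodes as inbounds. The following standalone sketch mirrors the writer's packing and the reader's unpacking. The enum values are copied from `GetElementPtrOptionalFlags` in this patch. `DecodedGEPFlags`, `encodeGEPFlags` and `decodeGEPFlags` are hypothetical helpers.

```cpp
#include <cassert>
#include <cstdint>

// Bit positions copied from GetElementPtrOptionalFlags in LLVMBitCodes.h.
enum GetElementPtrOptionalFlags { GEP_INBOUNDS = 0, GEP_NUSW = 1, GEP_NUW = 2 };

// Hypothetical decoded form of the flags field of a FUNC_CODE_INST_GEP record.
struct DecodedGEPFlags {
  bool InBounds = false;
  bool NUSW = false;
  bool NUW = false;
};

// Same packing as the writer: one bit per flag.
uint64_t encodeGEPFlags(const DecodedGEPFlags &F) {
  uint64_t Flags = 0;
  if (F.InBounds)
    Flags |= 1 << GEP_INBOUNDS;
  if (F.NUSW)
    Flags |= 1 << GEP_NUSW;
  if (F.NUW)
    Flags |= 1 << GEP_NUW;
  return Flags;
}

// Same unpacking as the reader. A pre-patch record stored only the
// inbounds bool in bit 0, so it decodes to inbounds alone.
DecodedGEPFlags decodeGEPFlags(uint64_t Flags) {
  DecodedGEPFlags F;
  F.InBounds = Flags & (1 << GEP_INBOUNDS);
  F.NUSW = Flags & (1 << GEP_NUSW);
  F.NUW = Flags & (1 << GEP_NUW);
  return F;
}
```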
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 941f6a7a7d8232..ced5d78f994ab5 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1417,6 +1417,10 @@ static void WriteOptimizationInfo(raw_ostream &Out, const User *U) {
   } else if (const GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
     if (GEP->isInBounds())
       Out << " inbounds";
+    else if (GEP->hasNoUnsignedSignedWrap())
+      Out << " nusw";
+    if (GEP->hasNoUnsignedWrap())
+      Out << " nuw";
     if (auto InRange = GEP->getInRange()) {
       Out << " inrange(" << InRange->getLower() << ", " << InRange->getUpper()
           << ")";
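When printing, `nusw` is emitted only if `inbounds` is absent, because `inbounds` already implies it. `nuw` is printed independently. This is why `getelementptr inbounds nusw nuw` in `flags.ll` prints back as `inbounds nuw`. A minimal sketch of just that ordering rule follows (the `inrange` part is omitted). `printGEPFlags` is a hypothetical helper, not an LLVM function.

```cpp
#include <cassert>
#include <string>

// Hypothetical mirror of the GEP branch of WriteOptimizationInfo
// (without inrange): nusw is suppressed when inbounds is printed.
std::string printGEPFlags(bool InBounds, bool NUSW, bool NUW) {
  std::string Out;
  if (InBounds)
    Out += " inbounds";
  else if (NUSW)
    Out += " nusw";
  if (NUW)
    Out += " nuw";
  return Out;
}
```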
diff --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
index 678edc58ad848d..0fc2b093cf78ac 100644
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -442,6 +442,8 @@ void Instruction::dropPoisonGeneratingFlags() {
 
   case Instruction::GetElementPtr:
     cast<GetElementPtrInst>(this)->setIsInBounds(false);
+    cast<GetElementPtrInst>(this)->setHasNoUnsignedSignedWrap(false);
+    cast<GetElementPtrInst>(this)->setHasNoUnsignedWrap(false);
     break;
 
   case Instruction::UIToFP:
@@ -658,9 +660,15 @@ void Instruction::copyIRFlags(const Value *V, bool IncludeWrapFlags) {
     if (isa<FPMathOperator>(this))
       copyFastMathFlags(FP->getFastMathFlags());
 
-  if (auto *SrcGEP = dyn_cast<GetElementPtrInst>(V))
-    if (auto *DestGEP = dyn_cast<GetElementPtrInst>(this))
+  if (auto *SrcGEP = dyn_cast<GetElementPtrInst>(V)) {
+    if (auto *DestGEP = dyn_cast<GetElementPtrInst>(this)) {
       DestGEP->setIsInBounds(SrcGEP->isInBounds() || DestGEP->isInBounds());
+      DestGEP->setHasNoUnsignedSignedWrap(SrcGEP->hasNoUnsignedSignedWrap() ||
+                                          DestGEP->hasNoUnsignedSignedWrap());
+      DestGEP->setHasNoUnsignedWrap(SrcGEP->hasNoUnsignedWrap() ||
+                                    DestGEP->hasNoUnsignedWrap());
+    }
+  }
 
   if (auto *NNI = dyn_cast<PossiblyNonNegInst>(V))
     if (isa<PossiblyNonNegInst>(this))
@@ -698,9 +706,15 @@ void Instruction::andIRFlags(const Value *V) {
     }
   }
 
-  if (auto *SrcGEP = dyn_cast<GetElementPtrInst>(V))
-    if (auto *DestGEP = dyn_cast<GetElementPtrInst>(this))
+  if (auto *SrcGEP = dyn_cast<GetElementPtrInst>(V)) {
+    if (auto *DestGEP = dyn_cast<GetElementPtrInst>(this)) {
       DestGEP->setIsInBounds(SrcGEP->isInBounds() && DestGEP->isInBounds());
+      DestGEP->setHasNoUnsignedSignedWrap(SrcGEP->hasNoUnsignedSignedWrap() &&
+                                          DestGEP->hasNoUnsignedSignedWrap());
+      DestGEP->setHasNoUnsignedWrap(SrcGEP->hasNoUnsignedWrap() &&
+                                    DestGEP->hasNoUnsignedWrap());
+    }
+  }
 
   if (auto *NNI = dyn_cast<PossiblyNonNegInst>(V))
     if (isa<PossiblyNonNegInst>(this))
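`andIRFlags` intersects the flags of two GEPs, for example when SimplifyCFG hoists identical instructions. The result relies on the setter semantics from `Operator.h`, applied in the same order as above. The following standalone sketch reproduces the two `hoist_gep_flags_intersect*` expectations in `HoistCode.ll`. `GEPFlagSet` and `andGEPFlags` are hypothetical names.

```cpp
#include <cassert>

// Hypothetical stand-in for the GEP flags, with the same setter semantics
// as GEPOperator in this patch.
struct GEPFlagSet {
  bool InBounds = false, NUSW = false, NUW = false;
  // Setting inbounds implies nusw; clearing inbounds keeps nusw.
  void setIsInBounds(bool B) {
    InBounds = B;
    if (B)
      NUSW = true;
  }
  // Clearing nusw also clears inbounds.
  void setNUSW(bool B) {
    NUSW = B;
    if (!B)
      InBounds = false;
  }
  void setNUW(bool B) { NUW = B; }
};

// Same update order as the patched Instruction::andIRFlags.
void andGEPFlags(GEPFlagSet &Dest, const GEPFlagSet &Src) {
  Dest.setIsInBounds(Src.InBounds && Dest.InBounds);
  Dest.setNUSW(Src.NUSW && Dest.NUSW);
  Dest.setNUW(Src.NUW && Dest.NUW);
}
```

Intersecting `inbounds nuw` with `nusw` yields `nusw`. Intersecting `inbounds` with `nuw` yields no flags. These match the CHECK lines of the two tests.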
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 7ad1ad4cddb703..bb1be7c2758b1f 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -2034,10 +2034,26 @@ void GetElementPtrInst::setIsInBounds(bool B) {
   cast<GEPOperator>(this)->setIsInBounds(B);
 }
 
+void GetElementPtrInst::setHasNoUnsignedSignedWrap(bool B) {
+  cast<GEPOperator>(this)->setHasNoUnsignedSignedWrap(B);
+}
+
+void GetElementPtrInst::setHasNoUnsignedWrap(bool B) {
+  cast<GEPOperator>(this)->setHasNoUnsignedWrap(B);
+}
+
 bool GetElementPtrInst::isInBounds() const {
   return cast<GEPOperator>(this)->isInBounds();
 }
 
+bool GetElementPtrInst::hasNoUnsignedSignedWrap() const {
+  return cast<GEPOperator>(this)->hasNoUnsignedSignedWrap();
+}
+
+bool GetElementPtrInst::hasNoUnsignedWrap() const {
+  return cast<GEPOperator>(this)->hasNoUnsignedWrap();
+}
+
 bool GetElementPtrInst::accumulateConstantOffset(const DataLayout &DL,
                                                  APInt &Offset) const {
   // Delegate to the generic GEPOperator implementation.
diff --git a/llvm/lib/IR/Operator.cpp b/llvm/lib/IR/Operator.cpp
index 29620ef716f25f..3de0d06dccfbac 100644
--- a/llvm/lib/IR/Operator.cpp
+++ b/llvm/lib/IR/Operator.cpp
@@ -42,7 +42,8 @@ bool Operator::hasPoisonGeneratingFlags() const {
   case Instruction::GetElementPtr: {
     auto *GEP = cast<GEPOperator>(this);
     // Note: inrange exists on constexpr only
-    return GEP->isInBounds() || GEP->getInRange() != std::nullopt;
+    return GEP->isInBounds() || GEP->hasNoUnsignedSignedWrap() ||
+           GEP->hasNoUnsignedWrap() || GEP->getInRange() != std::nullopt;
   }
   case Instruction::UIToFP:
   case Instruction::ZExt:
diff --git a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
index c54a956fc7e243..502eb56c867eb3 100644
--- a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
+++ b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
@@ -1019,12 +1019,17 @@ bool SeparateConstOffsetFromGEP::reorderGEP(GetElementPtrInst *GEP,
   }
 
   // For trivial GEP chains, we can swap the indicies.
+  // TODO(gep_nowrap): Make nusw preservation independent of inbounds and
+  // preserve nuw.
   auto NewSrc = Builder.CreateGEP(PtrGEPType, PtrGEP->getPointerOperand(),
                                   SmallVector<Value *, 4>(GEP->indices()));
   cast<GetElementPtrInst>(NewSrc)->setIsInBounds(IsChainInBounds);
+  cast<GetElementPtrInst>(NewSrc)->setHasNoUnsignedSignedWrap(IsChainInBounds);
+  cast<GetElementPtrInst>(NewSrc)->setHasNoUnsignedWrap(false);
   auto NewGEP = Builder.CreateGEP(GEPType, NewSrc,
                                   SmallVector<Value *, 4>(PtrGEP->indices()));
   cast<GetElementPtrInst>(NewGEP)->setIsInBounds(IsChainInBounds);
+  cast<GetElementPtrInst>(NewGEP)->setHasNoUnsignedWrap(false);
   GEP->replaceAllUsesWith(NewGEP);
   RecursivelyDeleteTriviallyDeadInstructions(GEP);
   return true;
@@ -1121,6 +1126,9 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
   // possible. GEPs with inbounds are more friendly to alias analysis.
   bool GEPWasInBounds = GEP->isInBounds();
   GEP->setIsInBounds(false);
+  // TODO(gep_nowrap): Try to preserve these.
+  GEP->setHasNoUnsignedSignedWrap(false);
+  GEP->setHasNoUnsignedWrap(false);
 
   // Lowers a GEP to either GEPs with a single index or arithmetic operations.
   if (LowerGEP) {
@@ -1396,6 +1404,11 @@ void SeparateConstOffsetFromGEP::swapGEPOperand(GetElementPtrInst *First,
      Offset.ugt(ObjectSize)) {
     First->setIsInBounds(false);
     Second->setIsInBounds(false);
+    // TODO(gep_nowrap): Make flag preservation more precise.
+    First->setHasNoUnsignedSignedWrap(false);
+    Second->setHasNoUnsignedSignedWrap(false);
+    First->setHasNoUnsignedWrap(false);
+    Second->setHasNoUnsignedWrap(false);
   } else
     First->setIsInBounds(true);
 }
diff --git a/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/llvm/lib/Transforms/Utils/FunctionComparator.cpp
index d95248c84b8602..fa80246b20f81b 100644
--- a/llvm/lib/Transforms/Utils/FunctionComparator.cpp
+++ b/llvm/lib/Transforms/Utils/FunctionComparator.cpp
@@ -438,6 +438,12 @@ int FunctionComparator::cmpConstants(const Constant *L,
         return Res;
       if (int Res = cmpNumbers(GEPL->isInBounds(), GEPR->isInBounds()))
         return Res;
+      if (int Res = cmpNumbers(GEPL->hasNoUnsignedSignedWrap(),
+                               GEPR->hasNoUnsignedSignedWrap()))
+        return Res;
+      if (int Res =
+              cmpNumbers(GEPL->hasNoUnsignedWrap(), GEPR->hasNoUnsignedWrap()))
+        return Res;
 
       std::optional<ConstantRange> InRangeL = GEPL->getInRange();
       std::optional<ConstantRange> InRangeR = GEPR->getInRange();
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 71594be2b965aa..4ee99c5edb6f43 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1092,7 +1092,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
       I->setIsExact(ExactFlags.IsExact);
       break;
     case OperationType::GEPOp:
+      // TODO(gep_nowrap): Track nusw and nuw flags in VPlan. For now assume
+      // they need to be dropped.
       cast<GetElementPtrInst>(I)->setIsInBounds(GEPFlags.IsInBounds);
+      cast<GetElementPtrInst>(I)->setHasNoUnsignedSignedWrap(
+          GEPFlags.IsInBounds);
+      cast<GetElementPtrInst>(I)->setHasNoUnsignedWrap(false);
       break;
     case OperationType::FPMathOp:
       I->setHasAllowReassoc(FMFs.AllowReassoc);
diff --git a/llvm/test/Assembler/flags.ll b/llvm/test/Assembler/flags.ll
index e0ad8bf000be15..40af7e56b329b3 100644
--- a/llvm/test/Assembler/flags.ll
+++ b/llvm/test/Assembler/flags.ll
@@ -316,3 +316,46 @@ define <2 x i32> @test_trunc_both_reversed_vector(<2 x i64> %a) {
   %res = trunc nsw nuw <2 x i64> %a to <2 x i32>
   ret <2 x i32> %res
 }
+
+define ptr @gep_nuw(ptr %p, i64 %idx) {
+; CHECK: %gep = getelementptr nuw i8, ptr %p, i64 %idx
+  %gep = getelementptr nuw i8, ptr %p, i64 %idx
+  ret ptr %gep
+}
+
+define ptr @gep_inbounds_nuw(ptr %p, i64 %idx) {
+; CHECK: %gep = getelementptr inbounds nuw i8, ptr %p, i64 %idx
+  %gep = getelementptr inbounds nuw i8, ptr %p, i64 %idx
+  ret ptr %gep
+}
+
+define ptr @gep_nusw(ptr %p, i64 %idx) {
+; CHECK: %gep = getelementptr nusw i8, ptr %p, i64 %idx
+  %gep = getelementptr nusw i8, ptr %p, i64 %idx
+  ret ptr %gep
+}
+
+; inbounds implies nusw, so the flag is not printed back.
+define ptr @gep_inbounds_nusw(ptr %p, i64 %idx) {
+; CHECK: %gep = getelementptr inbounds i8, ptr %p, i64 %idx
+  %gep = getelementptr inbounds nusw i8, ptr %p, i64 %idx
+  ret ptr %gep
+}
+
+define ptr @gep_nusw_nuw(ptr %p, i64 %idx) {
+; CHECK: %gep = getelementptr nusw nuw i8, ptr %p, i64 %idx
+  %gep = getelementptr nusw nuw i8, ptr %p, i64 %idx
+  ret ptr %gep
+}
+
+define ptr @gep_inbounds_nusw_nuw(ptr %p, i64 %idx) {
+; CHECK: %gep = getelementptr inbounds nuw i8, ptr %p, i64 %idx
+  %gep = getelementptr inbounds nusw nuw i8, ptr %p, i64 %idx
+  ret ptr %gep
+}
+
+define ptr @gep_nuw_nusw_inbounds(ptr %p, i64 %idx) {
+; CHECK: %gep = getelementptr inbounds nuw i8, ptr %p, i64 %idx
+  %gep = getelementptr nuw nusw inbounds i8, ptr %p, i64 %idx
+  ret ptr %gep
+}
diff --git a/llvm/test/Transforms/InstCombine/freeze.ll b/llvm/test/Transforms/InstCombine/freeze.ll
index 391d626a795c7d..5fedb1f8575035 100644
--- a/llvm/test/Transforms/InstCombine/freeze.ll
+++ b/llvm/test/Transforms/InstCombine/freeze.ll
@@ -1160,6 +1160,28 @@ define i32 @propagate_drop_flags_trunc(i64 %arg) {
   ret i32 %v1.fr
 }
 
+define ptr @propagate_drop_flags_gep_nusw(ptr %p) {
+; CHECK-LABEL: @propagate_drop_flags_gep_nusw(
+; CHECK-NEXT:    [[P_FR:%.*]] = freeze ptr [[P:%.*]]
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr [[P_FR]], i64 1
+; CHECK-NEXT:    ret ptr [[GEP]]
+;
+  %gep = getelementptr nusw i8, ptr %p, i64 1
+  %gep.fr = freeze ptr %gep
+  ret ptr %gep.fr
+}
+
+define ptr @propagate_drop_flags_gep_nuw(ptr %p) {
+; CHECK-LABEL: @propagate_drop_flags_gep_nuw(
+; CHECK-NEXT:    [[P_FR:%.*]] = freeze ptr [[P:%.*]]
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr [[P_FR]], i64 1
+; CHECK-NEXT:    ret ptr [[GEP]]
+;
+  %gep = getelementptr nuw i8, ptr %p, i64 1
+  %gep.fr = freeze ptr %gep
+  ret ptr %gep.fr
+}
+
 declare i32 @llvm.umax.i32(i32 %a, i32 %b)
 
 define i32 @freeze_call_with_range_attr(i32 %a) {
diff --git a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
index 887d1820168181..e6a255a4b8f086 100644
--- a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
+++ b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
@@ -215,3 +215,63 @@ F:
   %z2 = trunc nsw nuw i32 %x to i16
   ret i16 %z2
 }
+
+define ptr @hoist_gep_flags_both_nuw(i1 %C, ptr %p) {
+; CHECK-LABEL: @hoist_gep_flags_both_nuw(
+; CHECK-NEXT:  common.ret:
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr nuw i8, ptr [[P:%.*]], i64 1
+; CHECK-NEXT:    ret ptr [[GEP1]]
+;
+  br i1 %C, label %T, label %F
+T:
+  %gep1 = getelementptr nuw i8, ptr %p, i64 1
+  ret ptr %gep1
+F:
+  %gep2 = getelementptr nuw i8, ptr %p, i64 1
+  ret ptr %gep2
+}
+
+define ptr @hoist_gep_flags_both_nusw(i1 %C, ptr %p) {
+; CHECK-LABEL: @hoist_gep_flags_both_nusw(
+; CHECK-NEXT:  common.ret:
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr nusw i8, ptr [[P:%.*]], i64 1
+; CHECK-NEXT:    ret ptr [[GEP1]]
+;
+  br i1 %C, label %T, label %F
+T:
+  %gep1 = getelementptr nusw i8, ptr %p, i64 1
+  ret ptr %gep1
+F:
+  %gep2 = getelementptr nusw i8, ptr %p, i64 1
+  ret ptr %gep2
+}
+
+define ptr @hoist_gep_flags_intersect1(i1 %C, ptr %p) {
+; CHECK-LABEL: @hoist_gep_flags_intersect1(
+; CHECK-NEXT:  common.ret:
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr nusw i8, ptr [[P:%.*]], i64 1
+; CHECK-NEXT:    ret ptr [[GEP1]]
+;
+  br i1 %C, label %T, label %F
+T:
+  %gep1 = getelementptr inbounds nuw i8, ptr %p, i64 1
+  ret ptr %gep1
+F:
+  %gep2 = getelementptr nusw i8, ptr %p, i64 1
+  ret ptr %gep2
+}
+
+define ptr @hoist_gep_flags_intersect2(i1 %C, ptr %p) {
+; CHECK-LABEL: @hoist_gep_flags_intersect2(
+; CHECK-NEXT:  common.ret:
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr i8, ptr [[P:%.*]], i64 1
+; CHECK-NEXT:    ret ptr [[GEP1]]
+;
+  br i1 %C, label %T, label %F
+T:
+  %gep1 = getelementptr inbounds i8, ptr %p, i64 1
+  ret ptr %gep1
+F:
+  %gep2 = getelementptr nuw i8, ptr %p, i64 1
+  ret ptr %gep2
+}
diff --git a/llvm/test/tools/llvm-reduce/reduce-flags.ll b/llvm/test/tools/llvm-reduce/reduce-flags.ll
index 5d6d1260ac50e0..293504e32f9108 100644
--- a/llvm/test/tools/llvm-reduce/reduce-flags.ll
+++ b/llvm/test/tools/llvm-reduce/reduce-flags.ll
@@ -57,18 +57,26 @@ define i32 @ashr_exact_keep(i32 %a, i32 %b) {
   ret i32 %op
 }
 
-; CHECK-LABEL: @getelementptr_inbounds_drop(
+; CHECK-LABEL: @getelementptr_inbounds_nuw_drop_both(
 ; INTERESTING: getelementptr
 ; RESULT: getelementptr i32, ptr %a, i64 %b
-define ptr @getelementptr_inbounds_drop(ptr %a, i64 %b) {
-  %op = getelementptr inbounds i32, ptr %a, i64 %b
+define ptr @getelementptr_inbounds_nuw_drop_both(ptr %a, i64 %b) {
+  %op = getelementptr inbounds nuw i32, ptr %a, i64 %b
   ret ptr %op
 }
 
-; CHECK-LABEL: @getelementptr_inbounds_keep(
+; CHECK-LABEL: @getelementptr_inbounds_keep_only_inbounds(
 ; INTERESTING: inbounds
 ; RESULT: getelementptr inbounds i32, ptr %a, i64 %b
-define ptr @getelementptr_inbounds_keep(ptr %a, i64 %b) {
+define ptr @getelementptr_inbounds_keep_only_inbounds(ptr %a, i64 %b) {
+  %op = getelementptr inbounds nuw i32, ptr %a, i64 %b
+  ret ptr %op
+}
+
+; CHECK-LABEL: @getelementptr_inbounds_relax_to_nusw(
+; INTERESTING: getelementptr {{inbounds|nusw}}
+; RESULT: getelementptr nusw i32, ptr %a, i64 %b
+define ptr @getelementptr_inbounds_relax_to_nusw(ptr %a, i64 %b) {
   %op = getelementptr inbounds i32, ptr %a, i64 %b
   ret ptr %op
 }
diff --git a/llvm/tools/llvm-reduce/deltas/ReduceInstructionFlags.cpp b/llvm/tools/llvm-reduce/deltas/ReduceInstructionFlags.cpp
index ad619a6c02a4d2..978a58e397c4c3 100644
--- a/llvm/tools/llvm-reduce/deltas/ReduceInstructionFlags.cpp
+++ b/llvm/tools/llvm-reduce/deltas/ReduceInstructionFlags.cpp
@@ -44,6 +44,10 @@ static void reduceFlagsInModule(Oracle &O, ReducerWorkItem &WorkItem) {
       } else if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
         if (GEP->isInBounds() && !O.shouldKeep())
           GEP->setIsInBounds(false);
+        if (GEP->hasNoUnsignedSignedWrap() && !O.shouldKeep())
+          GEP->setHasNoUnsignedSignedWrap(false);
+        if (GEP->hasNoUnsignedWrap() && !O.shouldKeep())
+          GEP->setHasNoUnsignedWrap(false);
       } else if (auto *FPOp = dyn_cast<FPMathOperator>(&I)) {
         FastMathFlags Flags = FPOp->getFastMathFlags();
 

From e6d87106fd677c98422efa97dffef600b053853c Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 2 May 2024 14:15:35 +0900
Subject: [PATCH 2/2] Support flags on GEP constant expressions as well

---
 clang/lib/CodeGen/CGVTT.cpp                   |  4 ++-
 clang/lib/CodeGen/ItaniumCXXABI.cpp           |  4 ++-
 llvm/docs/LangRef.rst                         |  4 +++
 llvm/include/llvm/Bitcode/LLVMBitCodes.h      |  3 +-
 llvm/include/llvm/IR/Constants.h              | 11 +++---
 llvm/lib/Analysis/ConstantFolding.cpp         | 12 +++++--
 llvm/lib/AsmParser/LLParser.cpp               | 18 +++++++---
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp     | 35 +++++++++++-------
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp     | 21 ++++++-----
 llvm/lib/IR/ConstantFold.cpp                  |  5 ++-
 llvm/lib/IR/Constants.cpp                     | 13 +++++--
 .../AMDGPU/AMDGPULowerBufferFatPointers.cpp   |  3 +-
 llvm/test/Assembler/flags.ll                  | 36 +++++++++++++++++++
 13 files changed, 127 insertions(+), 42 deletions(-)

diff --git a/clang/lib/CodeGen/CGVTT.cpp b/clang/lib/CodeGen/CGVTT.cpp
index d2376b14dd5826..8c72f3dccfd6e3 100644
--- a/clang/lib/CodeGen/CGVTT.cpp
+++ b/clang/lib/CodeGen/CGVTT.cpp
@@ -87,8 +87,10 @@ CodeGenVTables::EmitVTTDefinition(llvm::GlobalVariable *VTT,
      unsigned Offset = ComponentSize * AddressPoint.AddressPointIndex;
      llvm::ConstantRange InRange(llvm::APInt(32, -Offset, true),
                                  llvm::APInt(32, VTableSize - Offset, true));
+     // TODO(gep_nowrap): Set nuw as well.
      llvm::Constant *Init = llvm::ConstantExpr::getGetElementPtr(
-         VTable->getValueType(), VTable, Idxs, /*InBounds=*/true, InRange);
+         VTable->getValueType(), VTable, Idxs, /*InBounds=*/true, /*NUSW=*/true,
+         /*NUW=*/false, InRange);
 
      VTTComponents.push_back(Init);
   }
diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp
index 18acf7784f714b..0138915ad35996 100644
--- a/clang/lib/CodeGen/ItaniumCXXABI.cpp
+++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp
@@ -1901,8 +1901,10 @@ ItaniumCXXABI::getVTableAddressPoint(BaseSubobject Base,
   unsigned Offset = ComponentSize * AddressPoint.AddressPointIndex;
   llvm::ConstantRange InRange(llvm::APInt(32, -Offset, true),
                               llvm::APInt(32, VTableSize - Offset, true));
+  // TODO(gep_nowrap): Set nuw as well.
   return llvm::ConstantExpr::getGetElementPtr(
-      VTable->getValueType(), VTable, Indices, /*InBounds=*/true, InRange);
+      VTable->getValueType(), VTable, Indices, /*InBounds=*/true, /*NUSW=*/true,
+      /*NUW=*/false, InRange);
 }
 
 // Check whether all the non-inline virtual methods for the class have the
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 7aeed82ab84df7..a4340f060d6f07 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -11347,6 +11347,10 @@ These rules are based on the assumption that no allocated object may cross
 the unsigned address space boundary, and no allocated object may be larger
 than half the pointer index type space.
 
+If ``inbounds`` is present on a ``getelementptr`` instruction, the ``nusw``
+attribute is automatically set as well. For this reason, ``nusw`` is not
+printed in textual IR if ``inbounds`` is already present.
+
 If the ``inrange(Start, End)`` attribute is present, loading from or
 storing to any pointer derived from the ``getelementptr`` has undefined
 behavior if the load or store would access memory outside the half-open range
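The rule that `inbounds` implies `nusw` can be modelled in a few lines. This is a minimal sketch using a hypothetical `GEPFlags` struct, not an LLVM type. `makeGEPFlags` mirrors the way `ConstantExpr::getGetElementPtr` in this patch ORs in `HasNoUnsignedSignedWrap` whenever `InBounds` is set. `printGEPFlags` reproduces the keyword output that the `flags.ll` tests check, where `nusw` is omitted when `inbounds` is present.

```cpp
#include <cassert>
#include <string>

// Hypothetical model of the three GEP no-wrap flags.
struct GEPFlags {
  bool InBounds = false;
  bool NUSW = false;
  bool NUW = false;
};

// Build a flag set the way this patch's getGetElementPtr does:
// requesting inbounds also sets nusw.
GEPFlags makeGEPFlags(bool InBounds, bool NUSW, bool NUW) {
  GEPFlags F;
  F.InBounds = InBounds;
  F.NUSW = NUSW || InBounds; // inbounds implies nusw
  F.NUW = NUW;
  return F;
}

// Render the keywords in the order the flags.ll CHECK lines expect;
// nusw is not printed when inbounds already implies it.
std::string printGEPFlags(const GEPFlags &F) {
  assert(!F.InBounds || F.NUSW); // invariant established by makeGEPFlags
  std::string S;
  auto Append = [&S](const char *KW) {
    if (!S.empty())
      S += ' ';
    S += KW;
  };
  if (F.InBounds)
    Append("inbounds");
  else if (F.NUSW)
    Append("nusw");
  if (F.NUW)
    Append("nuw");
  return S;
}
```

For example, `inbounds nusw nuw` prints back as `inbounds nuw`, as in `@const_gep_inbounds_nusw_nuw`.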
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 1fce358a92e548..d3b9e96520f88a 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -385,7 +385,7 @@ enum ConstantsCodes {
   CST_CODE_CSTRING = 9,          // CSTRING:       [values]
   CST_CODE_CE_BINOP = 10,        // CE_BINOP:      [opcode, opval, opval]
   CST_CODE_CE_CAST = 11,         // CE_CAST:       [opcode, opty, opval]
-  CST_CODE_CE_GEP = 12,          // CE_GEP:        [n x operands]
+  CST_CODE_CE_GEP_OLD = 12,      // CE_GEP:        [n x operands]
   CST_CODE_CE_SELECT = 13,       // CE_SELECT:     [opval, opval, opval]
   CST_CODE_CE_EXTRACTELT = 14,   // CE_EXTRACTELT: [opty, opval, opval]
   CST_CODE_CE_INSERTELT = 15,    // CE_INSERTELT:  [opval, opval, opval]
@@ -412,6 +412,7 @@ enum ConstantsCodes {
                                       //                 asmdialect|unwind,
                                       //                 asmstr,conststr]
   CST_CODE_CE_GEP_WITH_INRANGE = 31,  // [opty, flags, range, n x operands]
+  CST_CODE_CE_GEP = 32,               // [opty, flags, n x operands]
 };
 
 /// CastOpcodes - These are values used in the bitcode files to encode which
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 9ec81903f09c96..28ee766a6843e5 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -1198,26 +1198,27 @@ class ConstantExpr : public Constant {
   /// \param OnlyIfReducedTy see \a getWithOperands() docs.
   static Constant *
   getGetElementPtr(Type *Ty, Constant *C, ArrayRef<Constant *> IdxList,
-                   bool InBounds = false,
+                   bool InBounds = false, bool NUSW = false, bool NUW = false,
                    std::optional<ConstantRange> InRange = std::nullopt,
                    Type *OnlyIfReducedTy = nullptr) {
     return getGetElementPtr(
         Ty, C, ArrayRef((Value *const *)IdxList.data(), IdxList.size()),
-        InBounds, InRange, OnlyIfReducedTy);
+        InBounds, NUSW, NUW, InRange, OnlyIfReducedTy);
   }
   static Constant *
   getGetElementPtr(Type *Ty, Constant *C, Constant *Idx, bool InBounds = false,
+                   bool NUSW = false, bool NUW = false,
                    std::optional<ConstantRange> InRange = std::nullopt,
                    Type *OnlyIfReducedTy = nullptr) {
     // This form of the function only exists to avoid ambiguous overload
     // warnings about whether to convert Idx to ArrayRef<Constant *> or
     // ArrayRef<Value *>.
-    return getGetElementPtr(Ty, C, cast<Value>(Idx), InBounds, InRange,
-                            OnlyIfReducedTy);
+    return getGetElementPtr(Ty, C, cast<Value>(Idx), InBounds, NUSW, NUW,
+                            InRange, OnlyIfReducedTy);
   }
   static Constant *
   getGetElementPtr(Type *Ty, Constant *C, ArrayRef<Value *> IdxList,
-                   bool InBounds = false,
+                   bool InBounds = false, bool NUSW = false, bool NUW = false,
                    std::optional<ConstantRange> InRange = std::nullopt,
                    Type *OnlyIfReducedTy = nullptr);
 
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 749374a3aa48af..1cbcb6868eeef9 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -856,8 +856,10 @@ Constant *CastGEPIndices(Type *SrcElemTy, ArrayRef<Constant *> Ops,
   if (!Any)
     return nullptr;
 
+  // TODO(gep_nowrap): Preserve NUSW/NUW here.
   Constant *C = ConstantExpr::getGetElementPtr(SrcElemTy, Ops[0], NewIdxs,
-                                               InBounds, InRange);
+                                               InBounds, /*NUSW=*/InBounds,
+                                               /*NUW=*/false, InRange);
   return ConstantFoldConstant(C, DL, TLI);
 }
 
@@ -980,7 +982,9 @@ Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP,
     NewIdxs.push_back(ConstantInt::get(
         Type::getIntNTy(Ptr->getContext(), Index.getBitWidth()), Index));
 
+  // TODO(gep_nowrap): Preserve NUSW/NUW.
   return ConstantExpr::getGetElementPtr(SrcElemTy, Ptr, NewIdxs, InBounds,
+                                        /*NUSW=*/InBounds, /*NUW=*/false,
                                         InRange);
 }
 
@@ -1028,8 +1032,10 @@ Constant *ConstantFoldInstOperandsImpl(const Value *InstOrCE, unsigned Opcode,
     if (Constant *C = SymbolicallyEvaluateGEP(GEP, Ops, DL, TLI))
       return C;
 
-    return ConstantExpr::getGetElementPtr(SrcElemTy, Ops[0], Ops.slice(1),
-                                          GEP->isInBounds(), GEP->getInRange());
+    return ConstantExpr::getGetElementPtr(
+        SrcElemTy, Ops[0], Ops.slice(1), GEP->isInBounds(),
+        GEP->hasNoUnsignedSignedWrap(), GEP->hasNoUnsignedWrap(),
+        GEP->getInRange());
   }
 
   if (auto *CE = dyn_cast<ConstantExpr>(InstOrCE)) {
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 976e19479396d9..fa4d87ca8d5ffe 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -4216,7 +4216,7 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
   case lltok::kw_extractelement: {
     unsigned Opc = Lex.getUIntVal();
     SmallVector<Constant*, 16> Elts;
-    bool InBounds = false;
+    bool InBounds = false, HasNUSW = false, HasNUW = false;
     bool HasInRange = false;
     APSInt InRangeStart;
     APSInt InRangeEnd;
@@ -4224,7 +4224,17 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
     Lex.Lex();
 
     if (Opc == Instruction::GetElementPtr) {
-      InBounds = EatIfPresent(lltok::kw_inbounds);
+      while (true) {
+        if (EatIfPresent(lltok::kw_inbounds))
+          InBounds = true;
+        else if (EatIfPresent(lltok::kw_nusw))
+          HasNUSW = true;
+        else if (EatIfPresent(lltok::kw_nuw))
+          HasNUW = true;
+        else
+          break;
+      }
+
       if (EatIfPresent(lltok::kw_inrange)) {
         if (parseToken(lltok::lparen, "expected '('"))
           return true;
@@ -4303,8 +4313,8 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
       if (!GetElementPtrInst::getIndexedType(Ty, Indices))
         return error(ID.Loc, "invalid getelementptr indices");
 
-      ID.ConstantVal = ConstantExpr::getGetElementPtr(Ty, Elts[0], Indices,
-                                                      InBounds, InRange);
+      ID.ConstantVal = ConstantExpr::getGetElementPtr(
+          Ty, Elts[0], Indices, InBounds, HasNUSW, HasNUW, InRange);
     } else if (Opc == Instruction::ShuffleVector) {
       if (Elts.size() != 3)
         return error(ID.Loc, "expected three operands to shufflevector");
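The new `while (true)` loop in `parseValID` accepts the three GEP keywords in any order and stops at the first token that is not one of them. The sketch below reproduces that loop over a plain token vector. `ParsedGEPFlags` and `parseGEPFlags` are hypothetical names, and a vector of strings stands in for the real `LLLexer`.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical result type for this sketch.
struct ParsedGEPFlags {
  bool InBounds = false;
  bool NUSW = false;
  bool NUW = false;
};

// Consume GEP flag keywords in any order, stopping at the first token that
// is not a flag (e.g. "inrange" or "("). Pos is advanced past the consumed
// keywords; like the EatIfPresent loop, repeated keywords are tolerated.
ParsedGEPFlags parseGEPFlags(const std::vector<std::string> &Toks,
                             std::size_t &Pos) {
  assert(Pos <= Toks.size() && "start position out of range");
  ParsedGEPFlags F;
  while (Pos < Toks.size()) {
    const std::string &T = Toks[Pos];
    if (T == "inbounds")
      F.InBounds = true;
    else if (T == "nusw")
      F.NUSW = true;
    else if (T == "nuw")
      F.NUW = true;
    else
      break;
    ++Pos;
  }
  return F;
}
```

For the token sequence `nuw nusw inbounds (`, all three flags are set and parsing stops at `(`. This corresponds to the input of the `@const_gep_nuw_nusw_inbounds` test.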
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 099f594d83b4da..278d2c6adae6a7 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -1613,9 +1613,11 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
           C = ConstantExpr::getCompare(BC->Flags, ConstOps[0], ConstOps[1]);
           break;
         case Instruction::GetElementPtr:
-          C = ConstantExpr::getGetElementPtr(BC->SrcElemTy, ConstOps[0],
-                                             ArrayRef(ConstOps).drop_front(),
-                                             BC->Flags, BC->getInRange());
+          C = ConstantExpr::getGetElementPtr(
+              BC->SrcElemTy, ConstOps[0], ArrayRef(ConstOps).drop_front(),
+              (BC->Flags & (1 << bitc::GEP_INBOUNDS)) != 0,
+              (BC->Flags & (1 << bitc::GEP_NUSW)) != 0,
+              (BC->Flags & (1 << bitc::GEP_NUW)) != 0, BC->getInRange());
           break;
         case Instruction::ExtractElement:
           C = ConstantExpr::getExtractElement(ConstOps[0], ConstOps[1]);
@@ -1699,8 +1701,12 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
         I = GetElementPtrInst::Create(BC->SrcElemTy, Ops[0],
                                       ArrayRef(Ops).drop_front(), "constexpr",
                                       InsertBB);
-        if (BC->Flags)
+        if (BC->Flags & (1 << bitc::GEP_INBOUNDS))
           cast<GetElementPtrInst>(I)->setIsInBounds();
+        if (BC->Flags & (1 << bitc::GEP_NUSW))
+          cast<GetElementPtrInst>(I)->setHasNoUnsignedSignedWrap();
+        if (BC->Flags & (1 << bitc::GEP_NUW))
+          cast<GetElementPtrInst>(I)->setHasNoUnsignedWrap();
         break;
       case Instruction::Select:
         I = SelectInst::Create(Ops[0], Ops[1], Ops[2], "constexpr", InsertBB);
@@ -3320,9 +3326,10 @@ Error BitcodeReader::parseConstants() {
       break;
     }
     case bitc::CST_CODE_CE_INBOUNDS_GEP: // [ty, n x operands]
-    case bitc::CST_CODE_CE_GEP: // [ty, n x operands]
+    case bitc::CST_CODE_CE_GEP_OLD:      // [ty, n x operands]
     case bitc::CST_CODE_CE_GEP_WITH_INRANGE_INDEX_OLD: // [ty, flags, n x
                                                        // operands]
+    case bitc::CST_CODE_CE_GEP:                // [ty, flags, n x operands]
     case bitc::CST_CODE_CE_GEP_WITH_INRANGE: { // [ty, flags, start, end, n x
                                                // operands]
       if (Record.size() < 2)
@@ -3330,27 +3337,30 @@ Error BitcodeReader::parseConstants() {
       unsigned OpNum = 0;
       Type *PointeeType = nullptr;
       if (BitCode == bitc::CST_CODE_CE_GEP_WITH_INRANGE_INDEX_OLD ||
-          BitCode == bitc::CST_CODE_CE_GEP_WITH_INRANGE || Record.size() % 2)
+          BitCode == bitc::CST_CODE_CE_GEP_WITH_INRANGE ||
+          BitCode == bitc::CST_CODE_CE_GEP || Record.size() % 2)
         PointeeType = getTypeByID(Record[OpNum++]);
 
-      bool InBounds = false;
+
+      uint64_t Flags = 0;
       std::optional<ConstantRange> InRange;
       if (BitCode == bitc::CST_CODE_CE_GEP_WITH_INRANGE_INDEX_OLD) {
         uint64_t Op = Record[OpNum++];
-        InBounds = Op & 1;
+        Flags = Op & 1; // inbounds
         unsigned InRangeIndex = Op >> 1;
         // "Upgrade" inrange by dropping it. The feature is too niche to
         // bother.
         (void)InRangeIndex;
       } else if (BitCode == bitc::CST_CODE_CE_GEP_WITH_INRANGE) {
-        uint64_t Op = Record[OpNum++];
-        InBounds = Op & 1;
+        Flags = Record[OpNum++];
         Expected<ConstantRange> MaybeInRange = readConstantRange(Record, OpNum);
         if (!MaybeInRange)
           return MaybeInRange.takeError();
         InRange = MaybeInRange.get();
+      } else if (BitCode == bitc::CST_CODE_CE_GEP) {
+        Flags = Record[OpNum++];
       } else if (BitCode == bitc::CST_CODE_CE_INBOUNDS_GEP)
-        InBounds = true;
+        Flags = (1 << bitc::GEP_INBOUNDS);
 
       SmallVector<unsigned, 16> Elts;
       unsigned BaseTypeID = Record[OpNum];
@@ -3383,7 +3393,8 @@ Error BitcodeReader::parseConstants() {
 
       V = BitcodeConstant::create(
           Alloc, CurTy,
-          {Instruction::GetElementPtr, InBounds, PointeeType, InRange}, Elts);
+          {Instruction::GetElementPtr, uint8_t(Flags), PointeeType, InRange},
+          Elts);
       break;
     }
     case bitc::CST_CODE_CE_SELECT: {  // CE_SELECT: [opval#, opval#, opval#]
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 2627ce02286105..d1a423d0d84dcc 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -1656,6 +1656,13 @@ static uint64_t getOptimizationFlags(const Value *V) {
       Flags |= 1 << bitc::TIO_NO_SIGNED_WRAP;
     if (TI->hasNoUnsignedWrap())
       Flags |= 1 << bitc::TIO_NO_UNSIGNED_WRAP;
+  } else if (const auto *GEP = dyn_cast<GEPOperator>(V)) {
+    if (GEP->isInBounds())
+      Flags |= 1 << bitc::GEP_INBOUNDS;
+    if (GEP->hasNoUnsignedSignedWrap())
+      Flags |= 1 << bitc::GEP_NUSW;
+    if (GEP->hasNoUnsignedWrap())
+      Flags |= 1 << bitc::GEP_NUW;
   }
 
   return Flags;
@@ -2767,12 +2774,11 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
         Code = bitc::CST_CODE_CE_GEP;
         const auto *GO = cast<GEPOperator>(C);
         Record.push_back(VE.getTypeID(GO->getSourceElementType()));
+        Record.push_back(getOptimizationFlags(GO));
         if (std::optional<ConstantRange> Range = GO->getInRange()) {
           Code = bitc::CST_CODE_CE_GEP_WITH_INRANGE;
-          Record.push_back(GO->isInBounds());
           emitConstantRange(Record, *Range);
-        } else if (GO->isInBounds())
-          Code = bitc::CST_CODE_CE_INBOUNDS_GEP;
+        }
         for (unsigned i = 0, e = CE->getNumOperands(); i != e; ++i) {
           Record.push_back(VE.getTypeID(C->getOperand(i)->getType()));
           Record.push_back(VE.getValueID(C->getOperand(i)));
@@ -2961,14 +2967,7 @@ void ModuleBitcodeWriter::writeInstruction(const Instruction &I,
     Code = bitc::FUNC_CODE_INST_GEP;
     AbbrevToUse = FUNCTION_INST_GEP_ABBREV;
     auto &GEPInst = cast<GetElementPtrInst>(I);
-    uint64_t Flags = 0;
-    if (GEPInst.isInBounds())
-      Flags |= 1 << bitc::GEP_INBOUNDS;
-    if (GEPInst.hasNoUnsignedSignedWrap())
-      Flags |= 1 << bitc::GEP_NUSW;
-    if (GEPInst.hasNoUnsignedWrap())
-      Flags |= 1 << bitc::GEP_NUW;
-    Vals.push_back(Flags);
+    Vals.push_back(getOptimizationFlags(&I));
     Vals.push_back(VE.getTypeID(GEPInst.getSourceElementType()));
     for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i)
       pushValueAndType(I.getOperand(i), InstID, Vals);
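Both the instruction and constant-expression paths now encode the GEP flags as a single bitmask record operand. The reader then tests each bit independently. Below is a minimal round-trip sketch. The bit positions (inbounds = 0, nusw = 1, nuw = 2) are an assumption meant to mirror `bitc::GEP_INBOUNDS`, `bitc::GEP_NUSW` and `bitc::GEP_NUW`. Of these, only inbounds being bit 0 is implied by the reader's `Flags = Op & 1` upgrade. `GEPFlagSet`, `encodeGEPFlags` and `decodeGEPFlags` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Assumed bit positions, mirroring bitc::GEP_INBOUNDS / GEP_NUSW / GEP_NUW.
enum GEPFlagBit : unsigned { kInBounds = 0, kNUSW = 1, kNUW = 2 };

struct GEPFlagSet {
  bool InBounds = false;
  bool NUSW = false;
  bool NUW = false;
};

// Writer side: pack the flags into one record operand, as
// getOptimizationFlags() now does for any GEPOperator.
std::uint64_t encodeGEPFlags(const GEPFlagSet &F) {
  std::uint64_t Flags = 0;
  if (F.InBounds)
    Flags |= std::uint64_t(1) << kInBounds;
  if (F.NUSW)
    Flags |= std::uint64_t(1) << kNUSW;
  if (F.NUW)
    Flags |= std::uint64_t(1) << kNUW;
  return Flags;
}

// Reader side: test each bit independently, as materializeValue() does.
GEPFlagSet decodeGEPFlags(std::uint64_t Flags) {
  GEPFlagSet F;
  F.InBounds = (Flags & (std::uint64_t(1) << kInBounds)) != 0;
  F.NUSW = (Flags & (std::uint64_t(1) << kNUSW)) != 0;
  F.NUW = (Flags & (std::uint64_t(1) << kNUW)) != 0;
  return F;
}
```

Decoding the value `1`, which is what the upgrade paths produce for an old inbounds GEP, yields `inbounds` without `nusw`. This is harmless because `ConstantExpr::getGetElementPtr` sets `nusw` whenever `InBounds` is true.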
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index a766b1fe601823..941c2987aa6d58 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -1711,6 +1711,7 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C,
     for (unsigned i = 0, e = Idxs.size(); i != e; ++i)
       if (!NewIdxs[i]) NewIdxs[i] = cast<Constant>(Idxs[i]);
     return ConstantExpr::getGetElementPtr(PointeeTy, C, NewIdxs, InBounds,
+                                          /*NUSW=*/InBounds, /*NUW=*/false,
                                           InRange);
   }
 
@@ -1720,8 +1721,10 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C,
     if (auto *GV = dyn_cast<GlobalVariable>(C))
       if (!GV->hasExternalWeakLinkage() && GV->getValueType() == PointeeTy &&
           isInBoundsIndices(Idxs))
+        // TODO(gep_nowrap): Can also set NUW here.
         return ConstantExpr::getGetElementPtr(PointeeTy, C, Idxs,
-                                              /*InBounds=*/true, InRange);
+                                              /*InBounds=*/true, /*NUSW=*/true,
+                                              /*NUW=*/false, InRange);
 
   return nullptr;
 }
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 5268eccf701442..eee69838bcc2ea 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -1568,7 +1568,8 @@ Constant *ConstantExpr::getWithOperands(ArrayRef<Constant *> Ops, Type *Ty,
     assert(SrcTy || (Ops[0]->getType() == getOperand(0)->getType()));
     return ConstantExpr::getGetElementPtr(
         SrcTy ? SrcTy : GEPO->getSourceElementType(), Ops[0], Ops.slice(1),
-        GEPO->isInBounds(), GEPO->getInRange(), OnlyIfReducedTy);
+        GEPO->isInBounds(), GEPO->hasNoUnsignedSignedWrap(),
+        GEPO->hasNoUnsignedWrap(), GEPO->getInRange(), OnlyIfReducedTy);
   }
   case Instruction::ICmp:
   case Instruction::FCmp:
@@ -2349,6 +2350,7 @@ Constant *ConstantExpr::getCompare(unsigned short Predicate, Constant *C1,
 
 Constant *ConstantExpr::getGetElementPtr(Type *Ty, Constant *C,
                                          ArrayRef<Value *> Idxs, bool InBounds,
+                                         bool NUSW, bool NUW,
                                          std::optional<ConstantRange> InRange,
                                          Type *OnlyIfReducedTy) {
   assert(Ty && "Must specify element type");
@@ -2390,7 +2392,14 @@ Constant *ConstantExpr::getGetElementPtr(Type *Ty, Constant *C,
     ArgVec.push_back(Idx);
   }
 
-  unsigned SubClassOptionalData = InBounds ? GEPOperator::IsInBounds : 0;
+  unsigned SubClassOptionalData = 0;
+  if (InBounds)
+    SubClassOptionalData |=
+        GEPOperator::IsInBounds | GEPOperator::HasNoUnsignedSignedWrap;
+  if (NUSW)
+    SubClassOptionalData |= GEPOperator::HasNoUnsignedSignedWrap;
+  if (NUW)
+    SubClassOptionalData |= GEPOperator::HasNoUnsignedWrap;
   const ConstantExprKeyType Key(Instruction::GetElementPtr, ArgVec, 0,
                                 SubClassOptionalData, std::nullopt, Ty,
                                 InRange);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp b/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp
index 1114a8c40114e4..73e7ac407b327b 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp
@@ -801,7 +801,8 @@ Value *FatPtrConstMaterializer::materialize(Value *V) {
         Ops.push_back(cast<Constant>(U.get()));
       auto *NewGEP = ConstantExpr::getGetElementPtr(
           NewSrcTy, Ops[0], ArrayRef<Constant *>(Ops).slice(1),
-          GEPO->isInBounds(), GEPO->getInRange());
+          GEPO->isInBounds(), GEPO->hasNoUnsignedSignedWrap(),
+          GEPO->hasNoUnsignedWrap(), GEPO->getInRange());
       LLVM_DEBUG(dbgs() << "p7-getting GEP: " << *GEPO << " becomes " << *NewGEP
                         << "\n");
       Value *FurtherMap = materialize(NewGEP);
diff --git a/llvm/test/Assembler/flags.ll b/llvm/test/Assembler/flags.ll
index 40af7e56b329b3..231d173b8d7d7e 100644
--- a/llvm/test/Assembler/flags.ll
+++ b/llvm/test/Assembler/flags.ll
@@ -359,3 +359,39 @@ define ptr @gep_nuw_nusw_inbounds(ptr %p, i64 %idx) {
   %gep = getelementptr nuw nusw inbounds i8, ptr %p, i64 %idx
   ret ptr %gep
 }
+
+define ptr @const_gep_nuw(ptr %p, i64 %idx) {
+; CHECK: ret ptr getelementptr nuw (i8, ptr @addr, i64 100)
+  ret ptr getelementptr nuw (i8, ptr @addr, i64 100)
+}
+
+define ptr @const_gep_inbounds_nuw(ptr %p, i64 %idx) {
+; CHECK: ret ptr getelementptr inbounds nuw (i8, ptr @addr, i64 100)
+  ret ptr getelementptr inbounds nuw (i8, ptr @addr, i64 100)
+}
+
+define ptr @const_gep_nusw(ptr %p, i64 %idx) {
+; CHECK: ret ptr getelementptr nusw (i8, ptr @addr, i64 100)
+  ret ptr getelementptr nusw (i8, ptr @addr, i64 100)
+}
+
+; inbounds implies nusw, so the flag is not printed back.
+define ptr @const_gep_inbounds_nusw(ptr %p, i64 %idx) {
+; CHECK: ret ptr getelementptr inbounds (i8, ptr @addr, i64 100)
+  ret ptr getelementptr inbounds nusw (i8, ptr @addr, i64 100)
+}
+
+define ptr @const_gep_nusw_nuw(ptr %p, i64 %idx) {
+; CHECK: ret ptr getelementptr nusw nuw (i8, ptr @addr, i64 100)
+  ret ptr getelementptr nusw nuw (i8, ptr @addr, i64 100)
+}
+
+define ptr @const_gep_inbounds_nusw_nuw(ptr %p, i64 %idx) {
+; CHECK: ret ptr getelementptr inbounds nuw (i8, ptr @addr, i64 100)
+  ret ptr getelementptr inbounds nusw nuw (i8, ptr @addr, i64 100)
+}
+
+define ptr @const_gep_nuw_nusw_inbounds(ptr %p, i64 %idx) {
+; CHECK: ret ptr getelementptr inbounds nuw (i8, ptr @addr, i64 100)
+  ret ptr getelementptr nuw nusw inbounds (i8, ptr @addr, i64 100)
+}


