[llvm] [SeparateConstOffsetFromGEP] Decompose constant xor operand if possible (PR #150438)

Sumanth Gundapaneni via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 24 11:01:07 PDT 2025


https://github.com/sgundapa updated https://github.com/llvm/llvm-project/pull/150438

>From 72844323dc672f2d077f169c0f3856e8f2401d96 Mon Sep 17 00:00:00 2001
From: Sumanth Gundapaneni <sugundap at amd.com>
Date: Wed, 23 Jul 2025 11:58:12 -0500
Subject: [PATCH 1/2] [SeparateConstOffsetFromGEP] Decompose constant xor
 operand if possible

Try to transform XOR(A, B+C) into XOR(A, C) + B, where XOR(A, C) is part
of the base address of memory operations. The transformation is valid
under the following conditions:
Check 1 - B and C are disjoint.
Check 2 - XOR(A, C) and B are disjoint.
Check 1 makes B+C equal to B^C, so XOR(A, B+C) == XOR(XOR(A, C), B); check 2
then makes XOR(XOR(A, C), B) equal to XOR(A, C) + B.

This transformation can map these Xors into a better addressing mode and
eventually decompose them into GEPs.
---
 .../Scalar/SeparateConstOffsetFromGEP.cpp     | 141 ++++++++++++++++--
 .../AMDGPU/xor-idiom.ll                       |  66 ++++++++
 2 files changed, 191 insertions(+), 16 deletions(-)
 create mode 100644 llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll

diff --git a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
index 320b79203c0b3..203850c28787c 100644
--- a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
+++ b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
@@ -238,16 +238,17 @@ class ConstantOffsetExtractor {
   /// \p PreservesNUW  Outputs whether the extraction allows preserving the
   ///                  GEP's nuw flag, if it has one.
   static Value *Extract(Value *Idx, GetElementPtrInst *GEP,
-                        User *&UserChainTail, bool &PreservesNUW);
+                        User *&UserChainTail, bool &PreservesNUW,
+                        DominatorTree *DT);
 
   /// Looks for a constant offset from the given GEP index without extracting
   /// it. It returns the numeric value of the extracted constant offset (0 if
   /// failed). The meaning of the arguments are the same as Extract.
-  static int64_t Find(Value *Idx, GetElementPtrInst *GEP);
+  static int64_t Find(Value *Idx, GetElementPtrInst *GEP, DominatorTree *DT);
 
 private:
-  ConstantOffsetExtractor(BasicBlock::iterator InsertionPt)
-      : IP(InsertionPt), DL(InsertionPt->getDataLayout()) {}
+  ConstantOffsetExtractor(BasicBlock::iterator InsertionPt, DominatorTree *DT)
+      : IP(InsertionPt), DT(DT), DL(InsertionPt->getDataLayout()) {}
 
   /// Searches the expression that computes V for a non-zero constant C s.t.
   /// V can be reassociated into the form V' + C. If the searching is
@@ -321,6 +322,20 @@ class ConstantOffsetExtractor {
   bool CanTraceInto(bool SignExtended, bool ZeroExtended, BinaryOperator *BO,
                     bool NonNegative);
 
+  // Find the most dominating Xor with the same base operand.
+  BinaryOperator *findDominatingXor(Value *BaseOperand,
+                                    BinaryOperator *CurrentXor);
+
+  /// Check if Xor instruction should be considered for optimization.
+  bool shouldConsiderXor(BinaryOperator *XorInst);
+
+  /// Cache the information about Xor idiom.
+  struct XorRewriteInfo {
+    llvm::BinaryOperator *BaseXor = nullptr;
+    int64_t AdjustedOffset = 0;
+  };
+  std::optional<XorRewriteInfo> CachedXorInfo;
+
   /// The path from the constant offset to the old GEP index. e.g., if the GEP
   /// index is "a * b + (c + 5)". After running function find, UserChain[0] will
   /// be the constant 5, UserChain[1] will be the subexpression "c + 5", and
@@ -336,6 +351,8 @@ class ConstantOffsetExtractor {
   /// Insertion position of cloned instructions.
   BasicBlock::iterator IP;
 
+  DominatorTree *DT;
+
   const DataLayout &DL;
 };
 
@@ -514,12 +531,14 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended,
                                             bool ZeroExtended,
                                             BinaryOperator *BO,
                                             bool NonNegative) {
-  // We only consider ADD, SUB and OR, because a non-zero constant found in
+  // We only consider ADD, SUB, OR and XOR, because a non-zero constant found in
   // expressions composed of these operations can be easily hoisted as a
-  // constant offset by reassociation.
+  // constant offset by reassociation. XOR is a special case: it is folded
+  // into the GEP only when shouldConsiderXor proves the offset disjoint.
   if (BO->getOpcode() != Instruction::Add &&
       BO->getOpcode() != Instruction::Sub &&
-      BO->getOpcode() != Instruction::Or) {
+      BO->getOpcode() != Instruction::Or &&
+      BO->getOpcode() != Instruction::Xor) {
     return false;
   }
 
@@ -530,6 +549,10 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended,
       !cast<PossiblyDisjointInst>(BO)->isDisjoint())
     return false;
 
+  // Handle Xor idiom.
+  if (BO->getOpcode() == Instruction::Xor)
+    return shouldConsiderXor(BO);
+
   // FIXME: We don't currently support constants from the RHS of subs,
   // when we are zero-extended, because we need a way to zero-extended
   // them before they are negated.
@@ -740,6 +763,10 @@ Value *ConstantOffsetExtractor::removeConstOffset(unsigned ChainIndex) {
          "UserChain, so no one should be used more than "
          "once");
 
+  // Special case for Xor idiom.
+  if (BO->getOpcode() == Instruction::Xor)
+    return CachedXorInfo->BaseXor;
+
   unsigned OpNo = (BO->getOperand(0) == UserChain[ChainIndex - 1] ? 0 : 1);
   assert(BO->getOperand(OpNo) == UserChain[ChainIndex - 1]);
   Value *NextInChain = removeConstOffset(ChainIndex - 1);
@@ -780,6 +807,80 @@ Value *ConstantOffsetExtractor::removeConstOffset(unsigned ChainIndex) {
   return NewBO;
 }
 
+// Find the xor of BaseOperand with a constant integer that dominates
+// CurrentXor; when several qualify, keep the most dominating one. Returns
+// nullptr if there is none.
+BinaryOperator *
+ConstantOffsetExtractor::findDominatingXor(Value *BaseOperand,
+                                           BinaryOperator *CurrentXor) {
+  BinaryOperator *Best = nullptr;
+  for (User *Usr : BaseOperand->users()) {
+    auto *Candidate = dyn_cast<BinaryOperator>(Usr);
+    // Ignore non-binary users and the xor that is being rewritten.
+    if (!Candidate || Candidate == CurrentXor)
+      continue;
+    // Only "xor BaseOperand, <ConstantInt>" can serve as a base.
+    if (!match(Candidate, m_Xor(m_Specific(BaseOperand), m_ConstantInt())))
+      continue;
+    // A base that does not dominate CurrentXor cannot stand in for it.
+    if (!DT->dominates(Candidate, CurrentXor))
+      continue;
+    // Take the first qualifying candidate, then switch only to a candidate
+    // that dominates the one kept so far.
+    bool DominatesBest = !Best || DT->dominates(Candidate, Best);
+    if (DominatesBest)
+      Best = Candidate;
+  }
+
+  return Best;
+}
+
+// Return true, recording the rewrite in CachedXorInfo, if XorInst
+// ("xor Base, C") equals "(xor Base, CDom) + (C - CDom)" for a dominating
+// "xor Base, CDom". E.g. %4 is %3 + 8192 when 8192 is disjoint from both 32
+// and every bit %3 may have set:
+//   %3 = xor i32 %2, 32
+//   %4 = xor i32 %2, 8224
+// NOTE(review): CanTraceInto returns this before its sext/zext checks, while
+// removeConstOffset yields the un-extended BaseXor; confirm no extension.
+bool ConstantOffsetExtractor::shouldConsiderXor(BinaryOperator *XorInst) {
+  // removeConstOffset substitutes BaseXor for the first xor in the user chain,
+  // so allow only one xor per chain: the inner xor of "xor (xor A, C1), C2"
+  // must not overwrite the rewrite cached for the outer one.
+  if (CachedXorInfo)
+    return false;
+
+  Value *BaseOperand = nullptr;
+  ConstantInt *CurrentConst = nullptr;
+  if (!match(XorInst,
+             m_Xor(m_Value(BaseOperand), m_ConstantInt(CurrentConst))))
+    return false;
+
+  // findDominatingXor only returns "xor BaseOperand, <ConstantInt>".
+  BinaryOperator *DominatingXor = findDominatingXor(BaseOperand, XorInst);
+  if (!DominatingXor)
+    return false;
+  const APInt &DominatingC =
+      cast<ConstantInt>(DominatingXor->getOperand(1))->getValue();
+
+  // CurrentConst - DominatingC; zero means both xors compute the same value
+  // and there is nothing to split off.
+  APInt AdjustedOffset = CurrentConst->getValue() - DominatingC;
+  if (AdjustedOffset.isZero())
+    return false;
+
+  // 1. AdjustedOffset and DominatingC are disjoint, so their sum carries
+  //    nothing and CurrentConst == DominatingC ^ AdjustedOffset.
+  if (AdjustedOffset.intersects(DominatingC))
+    return false;
+  // 2. DominatingXor and AdjustedOffset are disjoint, so
+  //    DominatingXor ^ AdjustedOffset == DominatingXor + AdjustedOffset.
+  if (!MaskedValueIsZero(DominatingXor, AdjustedOffset, SimplifyQuery(DL), 0))
+    return false;
+  CachedXorInfo = XorRewriteInfo{DominatingXor, AdjustedOffset.getSExtValue()};
+  return true;
+}
+
 /// A helper function to check if reassociating through an entry in the user
 /// chain would invalidate the GEP's nuw flag.
 static bool allowsPreservingNUW(const User *U) {
@@ -805,8 +906,8 @@ static bool allowsPreservingNUW(const User *U) {
 
 Value *ConstantOffsetExtractor::Extract(Value *Idx, GetElementPtrInst *GEP,
                                         User *&UserChainTail,
-                                        bool &PreservesNUW) {
-  ConstantOffsetExtractor Extractor(GEP->getIterator());
+                                        bool &PreservesNUW, DominatorTree *DT) {
+  ConstantOffsetExtractor Extractor(GEP->getIterator(), DT);
   // Find a non-zero constant offset first.
   APInt ConstantOffset =
       Extractor.find(Idx, /* SignExtended */ false, /* ZeroExtended */ false,
@@ -825,12 +926,20 @@ Value *ConstantOffsetExtractor::Extract(Value *Idx, GetElementPtrInst *GEP,
   return IdxWithoutConstOffset;
 }
 
-int64_t ConstantOffsetExtractor::Find(Value *Idx, GetElementPtrInst *GEP) {
+int64_t ConstantOffsetExtractor::Find(Value *Idx, GetElementPtrInst *GEP,
+                                      DominatorTree *DT) {
   // If Idx is an index of an inbound GEP, Idx is guaranteed to be non-negative.
-  return ConstantOffsetExtractor(GEP->getIterator())
-      .find(Idx, /* SignExtended */ false, /* ZeroExtended */ false,
-            GEP->isInBounds())
-      .getSExtValue();
+  ConstantOffsetExtractor Extractor(GEP->getIterator(), DT);
+  int64_t Offset = Extractor
+                       .find(Idx, /* SignExtended */ false,
+                             /* ZeroExtended */ false, GEP->isInBounds())
+                       .getSExtValue();
+  // For the xor idiom the index is rewritten to the dominating xor, so only
+  // the adjusted offset (not the raw xor constant) moves into the GEP. Trust
+  // the cached info only when find() actually extracted an offset.
+  if (Offset != 0 && Extractor.CachedXorInfo)
+    return Extractor.CachedXorInfo->AdjustedOffset;
+  return Offset;
 }
 
 bool SeparateConstOffsetFromGEP::canonicalizeArrayIndicesToIndexSize(
@@ -866,7 +975,7 @@ SeparateConstOffsetFromGEP::accumulateByteOffset(GetElementPtrInst *GEP,
 
       // Tries to extract a constant offset from this GEP index.
       int64_t ConstantOffset =
-          ConstantOffsetExtractor::Find(GEP->getOperand(I), GEP);
+          ConstantOffsetExtractor::Find(GEP->getOperand(I), GEP, DT);
       if (ConstantOffset != 0) {
         NeedsExtraction = true;
         // A GEP may have multiple indices.  We accumulate the extracted
@@ -1106,7 +1215,7 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
       User *UserChainTail;
       bool PreservesNUW;
       Value *NewIdx = ConstantOffsetExtractor::Extract(
-          OldIdx, GEP, UserChainTail, PreservesNUW);
+          OldIdx, GEP, UserChainTail, PreservesNUW, DT);
       if (NewIdx != nullptr) {
         // Switches to the index with the constant offset removed.
         GEP->setOperand(I, NewIdx);
diff --git a/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll b/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll
new file mode 100644
index 0000000000000..a0d0de070e735
--- /dev/null
+++ b/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll
@@ -0,0 +1,66 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -mtriple=amdgcn-amd-amdhsa -passes=separate-const-offset-from-gep \
+; RUN: -S < %s | FileCheck %s
+
+; Test the xor idiom: xors of %2 with 4128, 8224 and 12320 differ from the
+; base xor (%2 ^ 32) by disjoint constants and must be folded into GEP offsets.
+; Xors with non-disjoint constants 2336 and 8480 must not be optimized.
+define amdgpu_kernel void @test1(i1 %0, ptr addrspace(3) %1) {
+; CHECK-LABEL: define amdgpu_kernel void @test1(
+; CHECK-SAME: i1 [[TMP0:%.*]], ptr addrspace(3) [[TMP1:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[TMP0]], i32 0, i32 288
+; CHECK-NEXT:    [[TMP3:%.*]] = xor i32 [[TMP2]], 32
+; CHECK-NEXT:    [[TMP14:%.*]] = xor i32 [[TMP2]], 2336
+; CHECK-NEXT:    [[TMP5:%.*]] = xor i32 [[TMP2]], 8480
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP3]]
+; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP14]]
+; CHECK-NEXT:    [[TMP20:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP3]]
+; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr i8, ptr addrspace(3) [[TMP20]], i32 8192
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP3]]
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr addrspace(3) [[TMP7]], i32 16384
+; CHECK-NEXT:    [[TMP15:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP5]]
+; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP3]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr addrspace(3) [[TMP25]], i32 24576
+; CHECK-NEXT:    [[TMP9:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP4]], align 16
+; CHECK-NEXT:    [[TMP10:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP16]], align 16
+; CHECK-NEXT:    [[TMP17:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP21]], align 16
+; CHECK-NEXT:    [[TMP18:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP6]], align 16
+; CHECK-NEXT:    [[TMP19:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP15]], align 16
+; CHECK-NEXT:    [[TMP11:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP8]], align 16
+; CHECK-NEXT:    [[TMP12:%.*]] = fadd <8 x half> [[TMP9]], [[TMP10]]
+; CHECK-NEXT:    [[TMP22:%.*]] = fadd <8 x half> [[TMP17]], [[TMP18]]
+; CHECK-NEXT:    [[TMP23:%.*]] = fadd <8 x half> [[TMP19]], [[TMP11]]
+; CHECK-NEXT:    [[TMP24:%.*]] = fadd <8 x half> [[TMP12]], [[TMP22]]
+; CHECK-NEXT:    [[TMP13:%.*]] = fadd <8 x half> [[TMP23]], [[TMP24]]
+; CHECK-NEXT:    store <8 x half> [[TMP13]], ptr addrspace(3) [[TMP1]], align 16
+; CHECK-NEXT:    ret void
+;
+entry:
+  %2 = select i1 %0, i32 0, i32 288
+  %3 = xor i32 %2, 32 // Base
+  %4 = xor i32 %2, 2336 // Not disjoint
+  %5 = xor i32 %2, 4128 // Disjoint
+  %6 = xor i32 %2, 8224 // Disjoint
+  %7 = xor i32 %2, 8480 // Not disjoint
+  %8 = xor i32 %2, 12320 // Disjoint
+  %9 = getelementptr half, ptr addrspace(3) %1, i32 %3
+  %10 = getelementptr half, ptr addrspace(3) %1, i32 %4
+  %11 = getelementptr half, ptr addrspace(3) %1, i32 %5
+  %12 = getelementptr half, ptr addrspace(3) %1, i32 %6
+  %13 = getelementptr half, ptr addrspace(3) %1, i32 %7
+  %14 = getelementptr half, ptr addrspace(3) %1, i32 %8
+  %15 = load <8 x half>, ptr addrspace(3) %9, align 16
+  %16 = load <8 x half>, ptr addrspace(3) %10, align 16
+  %17 = load <8 x half>, ptr addrspace(3) %11, align 16
+  %18 = load <8 x half>, ptr addrspace(3) %12, align 16
+  %19 = load <8 x half>, ptr addrspace(3) %13, align 16
+  %20 = load <8 x half>, ptr addrspace(3) %14, align 16
+  %21 = fadd <8 x half> %15, %16
+  %22 = fadd <8 x half> %17, %18
+  %23 = fadd <8 x half> %19, %20
+  %24 = fadd <8 x half> %21, %22
+  %25 = fadd <8 x half> %23, %24
+  store <8 x half> %25, ptr addrspace(3) %1, align 16
+  ret void
+}

>From a56ac2f27523f540a5ca286ef7905343450169f7 Mon Sep 17 00:00:00 2001
From: Sumanth Gundapaneni <sugundap at amd.com>
Date: Thu, 24 Jul 2025 11:08:38 -0500
Subject: [PATCH 2/2] Use ';' IR comments in the xor-idiom lit test

---
 .../SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll   | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll b/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll
index a0d0de070e735..2cbf2ead2107e 100644
--- a/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll
+++ b/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll
@@ -38,12 +38,12 @@ define amdgpu_kernel void @test1(i1 %0, ptr addrspace(3) %1) {
 ;
 entry:
   %2 = select i1 %0, i32 0, i32 288
-  %3 = xor i32 %2, 32 // Base
-  %4 = xor i32 %2, 2336 // Not disjoint
-  %5 = xor i32 %2, 4128 // Disjoint
-  %6 = xor i32 %2, 8224 // Disjoint
-  %7 = xor i32 %2, 8480 // Not disjoint
-  %8 = xor i32 %2, 12320 // Disjoint
+  %3 = xor i32 %2, 32 ; Base
+  %4 = xor i32 %2, 2336 ; Not disjoint
+  %5 = xor i32 %2, 4128 ; Disjoint
+  %6 = xor i32 %2, 8224 ; Disjoint
+  %7 = xor i32 %2, 8480 ; Not disjoint
+  %8 = xor i32 %2, 12320 ; Disjoint
   %9 = getelementptr half, ptr addrspace(3) %1, i32 %3
   %10 = getelementptr half, ptr addrspace(3) %1, i32 %4
   %11 = getelementptr half, ptr addrspace(3) %1, i32 %5



More information about the llvm-commits mailing list