[llvm] [SeparateConstOffsetFromGEP] Decompose constant xor operand if possible (PR #150438)
Sumanth Gundapaneni via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 24 11:01:07 PDT 2025
https://github.com/sgundapa updated https://github.com/llvm/llvm-project/pull/150438
>From 72844323dc672f2d077f169c0f3856e8f2401d96 Mon Sep 17 00:00:00 2001
From: Sumanth Gundapaneni <sugundap at amd.com>
Date: Wed, 23 Jul 2025 11:58:12 -0500
Subject: [PATCH 1/2] [SeparateConstOffsetFromGEP] Decompose constant xor
operand if possible
Try to transform XOR(A, B+C) into XOR(A, C) + B, where XOR(A, C) becomes
part of the base address of memory operations. The transformation is valid
under the following conditions:
Check 1 - B and C are disjoint, so B+C == B|C and
          XOR(A, B+C) == XOR(XOR(A, C), B).
Check 2 - XOR(A, C) and B are disjoint, so XOR(XOR(A, C), B) == XOR(A, C) + B.
This lets these xors map to a better addressing mode and eventually be
decomposed into GEPs.
---
.../Scalar/SeparateConstOffsetFromGEP.cpp | 141 ++++++++++++++++--
.../AMDGPU/xor-idiom.ll | 66 ++++++++
2 files changed, 191 insertions(+), 16 deletions(-)
create mode 100644 llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll
diff --git a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
index 320b79203c0b3..203850c28787c 100644
--- a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
+++ b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
@@ -238,16 +238,17 @@ class ConstantOffsetExtractor {
/// \p PreservesNUW Outputs whether the extraction allows preserving the
/// GEP's nuw flag, if it has one.
static Value *Extract(Value *Idx, GetElementPtrInst *GEP,
- User *&UserChainTail, bool &PreservesNUW);
+ User *&UserChainTail, bool &PreservesNUW,
+ DominatorTree *DT);
/// Looks for a constant offset from the given GEP index without extracting
/// it. It returns the numeric value of the extracted constant offset (0 if
/// failed). The meaning of the arguments are the same as Extract.
- static int64_t Find(Value *Idx, GetElementPtrInst *GEP);
+ static int64_t Find(Value *Idx, GetElementPtrInst *GEP, DominatorTree *DT);
private:
- ConstantOffsetExtractor(BasicBlock::iterator InsertionPt)
- : IP(InsertionPt), DL(InsertionPt->getDataLayout()) {}
+ ConstantOffsetExtractor(BasicBlock::iterator InsertionPt, DominatorTree *DT)
+ : IP(InsertionPt), DT(DT), DL(InsertionPt->getDataLayout()) {}
/// Searches the expression that computes V for a non-zero constant C s.t.
/// V can be reassociated into the form V' + C. If the searching is
@@ -321,6 +322,20 @@ class ConstantOffsetExtractor {
bool CanTraceInto(bool SignExtended, bool ZeroExtended, BinaryOperator *BO,
bool NonNegative);
+ // Find the most dominating Xor with the same base operand.
+ BinaryOperator *findDominatingXor(Value *BaseOperand,
+ BinaryOperator *CurrentXor);
+
+ /// Check if Xor instruction should be considered for optimization.
+ bool shouldConsiderXor(BinaryOperator *XorInst);
+
+ /// Cache the information about Xor idiom.
+ struct XorRewriteInfo {
+ llvm::BinaryOperator *BaseXor = nullptr;
+ int64_t AdjustedOffset = 0;
+ };
+ std::optional<XorRewriteInfo> CachedXorInfo;
+
/// The path from the constant offset to the old GEP index. e.g., if the GEP
/// index is "a * b + (c + 5)". After running function find, UserChain[0] will
/// be the constant 5, UserChain[1] will be the subexpression "c + 5", and
@@ -336,6 +351,8 @@ class ConstantOffsetExtractor {
/// Insertion position of cloned instructions.
BasicBlock::iterator IP;
+ DominatorTree *DT;
+
const DataLayout &DL;
};
@@ -514,12 +531,14 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended,
bool ZeroExtended,
BinaryOperator *BO,
bool NonNegative) {
- // We only consider ADD, SUB and OR, because a non-zero constant found in
+ // We only consider ADD, SUB, OR and XOR, because a non-zero constant found in
// expressions composed of these operations can be easily hoisted as a
- // constant offset by reassociation.
+ // constant offset by reassociation. XOR is a special case and can be folded
+ // in to gep if the constant is proven to be disjoint.
if (BO->getOpcode() != Instruction::Add &&
BO->getOpcode() != Instruction::Sub &&
- BO->getOpcode() != Instruction::Or) {
+ BO->getOpcode() != Instruction::Or &&
+ BO->getOpcode() != Instruction::Xor) {
return false;
}
@@ -530,6 +549,10 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended,
!cast<PossiblyDisjointInst>(BO)->isDisjoint())
return false;
+ // Handle Xor idiom.
+ if (BO->getOpcode() == Instruction::Xor)
+ return shouldConsiderXor(BO);
+
// FIXME: We don't currently support constants from the RHS of subs,
// when we are zero-extended, because we need a way to zero-extended
// them before they are negated.
@@ -740,6 +763,10 @@ Value *ConstantOffsetExtractor::removeConstOffset(unsigned ChainIndex) {
"UserChain, so no one should be used more than "
"once");
+ // Special case for Xor idiom.
+ if (BO->getOpcode() == Instruction::Xor)
+ return CachedXorInfo->BaseXor;
+
unsigned OpNo = (BO->getOperand(0) == UserChain[ChainIndex - 1] ? 0 : 1);
assert(BO->getOperand(OpNo) == UserChain[ChainIndex - 1]);
Value *NextInChain = removeConstOffset(ChainIndex - 1);
@@ -780,6 +807,80 @@ Value *ConstantOffsetExtractor::removeConstOffset(unsigned ChainIndex) {
return NewBO;
}
+// Returns a user of BaseOperand of the form `xor BaseOperand, <ConstantInt>`,
+// other than CurrentXor, that dominates CurrentXor (nullptr if none). When
+// several qualify, a later one replaces the pick only if it dominates it.
+BinaryOperator *
+ConstantOffsetExtractor::findDominatingXor(Value *BaseOperand,
+                                           BinaryOperator *CurrentXor) {
+  BinaryOperator *Best = nullptr;
+  for (User *Usr : BaseOperand->users()) {
+    auto *Cand = dyn_cast<BinaryOperator>(Usr);
+    if (Cand == nullptr || Cand == CurrentXor)
+      continue;
+
+    // Only `xor BaseOperand, C` with a ConstantInt C is a candidate.
+    if (!match(Cand, m_Xor(m_Specific(BaseOperand), m_ConstantInt())))
+      continue;
+
+    // The candidate must dominate the xor being rewritten ...
+    if (!DT->dominates(Cand, CurrentXor))
+      continue;
+
+    // ... and must dominate the current pick in order to replace it.
+    if (Best != nullptr && !DT->dominates(Cand, Best))
+      continue;
+    Best = Cand;
+  }
+  return Best;
+}
+
+// Returns true if the GEP index XorInst = `xor X, C` can be treated as
+// `add (xor X, C0), (C - C0)`, where `xor X, C0` dominates XorInst. Example:
+//   %3 = xor i32 %2, 32
+//   %4 = xor i32 %2, 8224
+//   %7 = getelementptr half, ptr addrspace(3) %1, i32 %4
+// If 8192 (= 8224 - 32) is disjoint from 32 and from every possibly-set bit
+// of %3, then %4 == add i32 %3, 8192 and 8192 can become a GEP offset.
+// On success the rewrite is cached in CachedXorInfo.
+bool ConstantOffsetExtractor::shouldConsiderXor(BinaryOperator *XorInst) {
+  Value *BaseOperand = nullptr;
+  ConstantInt *CurrentConst = nullptr;
+  if (!match(XorInst,
+             m_Xor(m_Value(BaseOperand), m_ConstantInt(CurrentConst))))
+    return false;
+
+  // removeConstOffset swaps the xor for the dominating xor and Find() then
+  // reports +AdjustedOffset, which only holds when the xor is the GEP index
+  // itself: below a sub the offset would need negating, and below an ext the
+  // chain is cloned in a wider type. So require every user to be a GEP.
+  if (!all_of(XorInst->users(),
+              [](const User *U) { return isa<GetElementPtrInst>(U); }))
+    return false;
+
+  BinaryOperator *DominatingXor = findDominatingXor(BaseOperand, XorInst);
+  if (!DominatingXor)
+    return false;
+  ConstantInt *DominatingConst = nullptr;
+  if (!match(DominatingXor,
+             m_Xor(m_Specific(BaseOperand), m_ConstantInt(DominatingConst))))
+    return false;
+
+  // AdjustedOffset = C - C0 is cached as int64_t: it must be non-zero and fit.
+  const APInt &C0 = DominatingConst->getValue();
+  APInt AdjustedOffset = CurrentConst->getValue() - C0;
+  if (AdjustedOffset.isZero() || AdjustedOffset.getSignificantBits() > 64)
+    return false;
+  // 1. AdjustedOffset and C0 are disjoint, hence C == C0 | AdjustedOffset.
+  if (AdjustedOffset.intersects(C0))
+    return false;
+  // 2. xor(X, C0) and AdjustedOffset are disjoint, hence the xor is an add.
+  if (!MaskedValueIsZero(DominatingXor, AdjustedOffset, SimplifyQuery(DL), 0))
+    return false;
+  CachedXorInfo = XorRewriteInfo{DominatingXor, AdjustedOffset.getSExtValue()};
+  return true;
+}
+
/// A helper function to check if reassociating through an entry in the user
/// chain would invalidate the GEP's nuw flag.
static bool allowsPreservingNUW(const User *U) {
@@ -805,8 +906,8 @@ static bool allowsPreservingNUW(const User *U) {
Value *ConstantOffsetExtractor::Extract(Value *Idx, GetElementPtrInst *GEP,
User *&UserChainTail,
- bool &PreservesNUW) {
- ConstantOffsetExtractor Extractor(GEP->getIterator());
+ bool &PreservesNUW, DominatorTree *DT) {
+ ConstantOffsetExtractor Extractor(GEP->getIterator(), DT);
// Find a non-zero constant offset first.
APInt ConstantOffset =
Extractor.find(Idx, /* SignExtended */ false, /* ZeroExtended */ false,
@@ -825,12 +926,20 @@ Value *ConstantOffsetExtractor::Extract(Value *Idx, GetElementPtrInst *GEP,
return IdxWithoutConstOffset;
}
-int64_t ConstantOffsetExtractor::Find(Value *Idx, GetElementPtrInst *GEP) {
+int64_t ConstantOffsetExtractor::Find(Value *Idx, GetElementPtrInst *GEP,
+                                      DominatorTree *DT) {
// If Idx is an index of an inbound GEP, Idx is guaranteed to be non-negative.
- return ConstantOffsetExtractor(GEP->getIterator())
- .find(Idx, /* SignExtended */ false, /* ZeroExtended */ false,
- GEP->isInBounds())
- .getSExtValue();
+  ConstantOffsetExtractor Extractor(GEP->getIterator(), DT);
+  const APInt Found = Extractor.find(Idx, /* SignExtended */ false,
+                                     /* ZeroExtended */ false,
+                                     GEP->isInBounds());
+  const int64_t Offset = Found.getSExtValue();
+
+  // If the xor idiom matched, the offset that moves into the GEP is the
+  // distance to the dominating xor cached in CachedXorInfo, not the constant
+  // that find() located.
+  return Extractor.CachedXorInfo ? Extractor.CachedXorInfo->AdjustedOffset
+                                 : Offset;
}
bool SeparateConstOffsetFromGEP::canonicalizeArrayIndicesToIndexSize(
@@ -866,7 +975,7 @@ SeparateConstOffsetFromGEP::accumulateByteOffset(GetElementPtrInst *GEP,
// Tries to extract a constant offset from this GEP index.
int64_t ConstantOffset =
- ConstantOffsetExtractor::Find(GEP->getOperand(I), GEP);
+ ConstantOffsetExtractor::Find(GEP->getOperand(I), GEP, DT);
if (ConstantOffset != 0) {
NeedsExtraction = true;
// A GEP may have multiple indices. We accumulate the extracted
@@ -1106,7 +1215,7 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
User *UserChainTail;
bool PreservesNUW;
Value *NewIdx = ConstantOffsetExtractor::Extract(
- OldIdx, GEP, UserChainTail, PreservesNUW);
+ OldIdx, GEP, UserChainTail, PreservesNUW, DT);
if (NewIdx != nullptr) {
// Switches to the index with the constant offset removed.
GEP->setOperand(I, NewIdx);
diff --git a/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll b/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll
new file mode 100644
index 0000000000000..a0d0de070e735
--- /dev/null
+++ b/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll
@@ -0,0 +1,66 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -mtriple=amdgcn-amd-amdhsa -passes=separate-const-offset-from-gep \
+; RUN: -S < %s | FileCheck %s
+
+; Test the xor idiom, with %3 = xor %2, 32 as the dominating base xor.
+; Xors with disjoint constants 4128, 8224 and 12320 must fold into GEPs of %3.
+; Xors with non-disjoint constants 2336 and 8480 should not be optimized.
+define amdgpu_kernel void @test1(i1 %0, ptr addrspace(3) %1) {
+; CHECK-LABEL: define amdgpu_kernel void @test1(
+; CHECK-SAME: i1 [[TMP0:%.*]], ptr addrspace(3) [[TMP1:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP0]], i32 0, i32 288
+; CHECK-NEXT: [[TMP3:%.*]] = xor i32 [[TMP2]], 32
+; CHECK-NEXT: [[TMP14:%.*]] = xor i32 [[TMP2]], 2336
+; CHECK-NEXT: [[TMP5:%.*]] = xor i32 [[TMP2]], 8480
+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP3]]
+; CHECK-NEXT: [[TMP16:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP14]]
+; CHECK-NEXT: [[TMP20:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP3]]
+; CHECK-NEXT: [[TMP21:%.*]] = getelementptr i8, ptr addrspace(3) [[TMP20]], i32 8192
+; CHECK-NEXT: [[TMP7:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP3]]
+; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr addrspace(3) [[TMP7]], i32 16384
+; CHECK-NEXT: [[TMP15:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP5]]
+; CHECK-NEXT: [[TMP25:%.*]] = getelementptr half, ptr addrspace(3) [[TMP1]], i32 [[TMP3]]
+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr addrspace(3) [[TMP25]], i32 24576
+; CHECK-NEXT: [[TMP9:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP4]], align 16
+; CHECK-NEXT: [[TMP10:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP16]], align 16
+; CHECK-NEXT: [[TMP17:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP21]], align 16
+; CHECK-NEXT: [[TMP18:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP6]], align 16
+; CHECK-NEXT: [[TMP19:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP15]], align 16
+; CHECK-NEXT: [[TMP11:%.*]] = load <8 x half>, ptr addrspace(3) [[TMP8]], align 16
+; CHECK-NEXT: [[TMP12:%.*]] = fadd <8 x half> [[TMP9]], [[TMP10]]
+; CHECK-NEXT: [[TMP22:%.*]] = fadd <8 x half> [[TMP17]], [[TMP18]]
+; CHECK-NEXT: [[TMP23:%.*]] = fadd <8 x half> [[TMP19]], [[TMP11]]
+; CHECK-NEXT: [[TMP24:%.*]] = fadd <8 x half> [[TMP12]], [[TMP22]]
+; CHECK-NEXT: [[TMP13:%.*]] = fadd <8 x half> [[TMP23]], [[TMP24]]
+; CHECK-NEXT: store <8 x half> [[TMP13]], ptr addrspace(3) [[TMP1]], align 16
+; CHECK-NEXT: ret void
+;
+entry:
+ %2 = select i1 %0, i32 0, i32 288
+ %3 = xor i32 %2, 32 // Base
+ %4 = xor i32 %2, 2336 // Not disjoint
+ %5 = xor i32 %2, 4128 // Disjoint
+ %6 = xor i32 %2, 8224 // Disjoint
+ %7 = xor i32 %2, 8480 // Not disjoint
+ %8 = xor i32 %2, 12320 // Disjoint
+ %9 = getelementptr half, ptr addrspace(3) %1, i32 %3
+ %10 = getelementptr half, ptr addrspace(3) %1, i32 %4
+ %11 = getelementptr half, ptr addrspace(3) %1, i32 %5
+ %12 = getelementptr half, ptr addrspace(3) %1, i32 %6
+ %13 = getelementptr half, ptr addrspace(3) %1, i32 %7
+ %14 = getelementptr half, ptr addrspace(3) %1, i32 %8
+ %15 = load <8 x half>, ptr addrspace(3) %9, align 16
+ %16 = load <8 x half>, ptr addrspace(3) %10, align 16
+ %17 = load <8 x half>, ptr addrspace(3) %11, align 16
+ %18 = load <8 x half>, ptr addrspace(3) %12, align 16
+ %19 = load <8 x half>, ptr addrspace(3) %13, align 16
+ %20 = load <8 x half>, ptr addrspace(3) %14, align 16
+ %21 = fadd <8 x half> %15, %16
+ %22 = fadd <8 x half> %17, %18
+ %23 = fadd <8 x half> %19, %20
+ %24 = fadd <8 x half> %21, %22
+ %25 = fadd <8 x half> %23, %24
+ store <8 x half> %25, ptr addrspace(3) %1, align 16
+ ret void
+}
>From a56ac2f27523f540a5ca286ef7905343450169f7 Mon Sep 17 00:00:00 2001
From: Sumanth Gundapaneni <sugundap at amd.com>
Date: Thu, 24 Jul 2025 11:08:38 -0500
Subject: [PATCH 2/2] Update lit test with comments
---
.../SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll b/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll
index a0d0de070e735..2cbf2ead2107e 100644
--- a/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll
+++ b/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-idiom.ll
@@ -38,12 +38,12 @@ define amdgpu_kernel void @test1(i1 %0, ptr addrspace(3) %1) {
;
entry:
%2 = select i1 %0, i32 0, i32 288
- %3 = xor i32 %2, 32 // Base
- %4 = xor i32 %2, 2336 // Not disjoint
- %5 = xor i32 %2, 4128 // Disjoint
- %6 = xor i32 %2, 8224 // Disjoint
- %7 = xor i32 %2, 8480 // Not disjoint
- %8 = xor i32 %2, 12320 // Disjoint
+ %3 = xor i32 %2, 32 ; Base
+ %4 = xor i32 %2, 2336 ; Not disjoint
+ %5 = xor i32 %2, 4128 ; Disjoint
+ %6 = xor i32 %2, 8224 ; Disjoint
+ %7 = xor i32 %2, 8480 ; Not disjoint
+ %8 = xor i32 %2, 12320 ; Disjoint
%9 = getelementptr half, ptr addrspace(3) %1, i32 %3
%10 = getelementptr half, ptr addrspace(3) %1, i32 %4
%11 = getelementptr half, ptr addrspace(3) %1, i32 %5
More information about the llvm-commits
mailing list