[llvm] Decompose gep of complex type struct to its element type (PR #107848)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Sep 22 18:20:09 PDT 2024
https://github.com/vfdff updated https://github.com/llvm/llvm-project/pull/107848
>From 43a6f2b3ffc3cd8a93dd225100711f1e189140c2 Mon Sep 17 00:00:00 2001
From: zhongyunde 00443407 <zhongyunde at huawei.com>
Date: Sat, 7 Sep 2024 06:16:07 -0400
Subject: [PATCH 1/2] Decompose gep of complex type struct to its element type
Similar to PR96606, but this PR addresses the case where all indices are
zero except the last one, which is not a constant value.
We usually widen GEP indices to the pointer width, so on AArch64, for
example, a getelementptr index is often extended to i64. The vectorizer
chooses the VL according to the data type, so it may be
<vscale x 4 x float> for float.
When the address comes from a struct.std::complex, as in the following
IR, the index must be multiplied by 2 in codegen. We therefore can't assume
that (<vscale x 4 x i64> %3 * 2) fits in <vscale x 4 x i32>, and the
llvm.masked.gather.nxv4f32.nxv4p0 node gets split.
```
> %4 = getelementptr inbounds [10000 x %"struct.std::complex"], ptr @mdlComplex, i64 0, <vscale x 4 x i64> %3
> %wide.masked.gather = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> %4, i32 4, <vscale x 4 x i1> %active.lane.mask, <vscale x 4 x float> poison)
```
This PR decomposes a GEP of a complex-type struct into a GEP of its element
type. The GEP index then no longer needs a multiply, so the
llvm.masked.gather.nxv4f32.nxv4p0 need not be split when the offset is
known to be extended from i32.
Fix https://github.com/llvm/llvm-project/issues/107825
---
llvm/include/llvm/IR/Instructions.h | 2 ++
llvm/lib/IR/Instructions.cpp | 11 +++++++
.../InstCombine/InstructionCombining.cpp | 32 +++++++++++++++++++
.../LoopVectorize/reduction-inloop.ll | 24 ++++++++------
4 files changed, 59 insertions(+), 10 deletions(-)
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index ab3321ee755717..c5a2e2cb442d29 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -1078,6 +1078,10 @@ class GetElementPtrInst : public Instruction {
   /// a constant offset between them.
   bool hasAllConstantIndices() const;
 
+  /// Return true if all indices of this GEP except the last one are the
+  /// constant zero.
+  bool hasAllZeroIndicesExceptLast() const;
+
   /// Set nowrap flags for GEP instruction.
   void setNoWrapFlags(GEPNoWrapFlags NW);
 
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 19da1f60d424d2..8363ff2c070d69 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -1555,6 +1555,17 @@ bool GetElementPtrInst::hasAllConstantIndices() const {
return true;
}
+/// Return true if all of the indices of this GEP
+/// are zero except the last indice.
+bool GetElementPtrInst::hasAllZeroIndicesExceptLast() const {
+ for (unsigned i = 1, e = getNumOperands() - 1; i != e; ++i) {
+ if (!isa<ConstantInt>(getOperand(i)) ||
+ !cast<ConstantInt>(getOperand(i))->isZero())
+ return false;
+ }
+ return true;
+}
+
void GetElementPtrInst::setNoWrapFlags(GEPNoWrapFlags NW) {
SubclassOptionalData = NW.getRaw();
}
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 8195e0539305cc..1de6d189ea7ae0 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2805,6 +2805,46 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
                                    GEP.getNoWrapFlags()));
   }
 
+  // Decompose a GEP into a two-element homogeneous struct such as
+  //   %"struct.std::complex" = type { { float, float } }
+  // whose indices are all zero except the last, a sign-extended value:
+  //   %idx = sext i32 %Off to i64
+  //   getelementptr inbounds [100 x %"struct.std::complex"], ptr @p,
+  //                          i64 0, i64 %idx
+  // into a GEP over the element type, scaling in the narrow type:
+  //   %idx.scale = shl nsw i32 %Off, 1
+  //   %1 = sext i32 %idx.scale to i64
+  //   getelementptr inbounds float, ptr @p, i64 %1
+  // so the offset stays a sign-extended narrow value (issue #107825).
+  // This is only sound when %Off * 2 cannot overflow the narrow type:
+  // otherwise sext(%Off * 2) != sext(%Off) * 2.
+  Type *GEPResElTy = GEP.getResultElementType();
+  Type *PairTy = GEPResElTy;
+  if (PairTy->isStructTy() && PairTy->getStructNumElements() == 1)
+    PairTy = PairTy->getStructElementType(0);
+  if (GEP.getNumIndices() != 0 && PairTy->isStructTy() &&
+      PairTy->getStructNumElements() == 2 &&
+      PairTy->getStructElementType(0) == PairTy->getStructElementType(1) &&
+      GEP.hasAllZeroIndicesExceptLast()) {
+    Type *EltTy = PairTy->getStructElementType(0);
+    auto *SExtI = dyn_cast<SExtInst>(GEP.getOperand(GEP.getNumOperands() - 1));
+    // The struct stride must be exactly two elements; a DataLayout that
+    // raises aggregate alignment ('a' spec) can pad it beyond that.
+    bool IsTwoEltStride =
+        DL.getTypeAllocSize(GEPResElTy) == DL.getTypeAllocSize(EltTy) * 2;
+    Value *Off = SExtI ? SExtI->getOperand(0) : nullptr;
+    // Two or more sign bits mean Off * 2 fits in Off's type, so the shl
+    // below is nsw and sext(Off << 1) == sext(Off) << 1.
+    if (Off && IsTwoEltStride && ComputeNumSignBits(Off, 0, &GEP) > 1) {
+      Value *Scaled = Builder.CreateShl(Off, 1, SExtI->getName() + ".scale",
+                                        /*HasNUW=*/false, /*HasNSW=*/true);
+      Value *Idx = Builder.CreateSExt(Scaled, SExtI->getType());
+      Value *NewGEP =
+          Builder.CreateGEP(EltTy, PtrOp, Idx, "", GEP.getNoWrapFlags());
+      return replaceInstUsesWith(GEP, NewGEP);
+    }
+  }
+
   // Canonicalize
   //  - scalable GEPs to an explicit offset using the llvm.vscale intrinsic.
   //    This has better support in BasicAA.
diff --git a/llvm/test/Transforms/InstCombine/gep-two-element-struct.ll b/llvm/test/Transforms/InstCombine/gep-two-element-struct.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/gep-two-element-struct.ll
@@ -0,0 +1,31 @@
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+%"struct.std::complex" = type { { float, float } }
+
+; sdiv by 3 adds a sign bit, so %off * 2 cannot overflow i32 and the
+; scale is done in i32 before the sext.
+define ptr @complex_narrow_scale(ptr %p, i32 %x) {
+; CHECK-LABEL: @complex_narrow_scale(
+; CHECK-NEXT:    [[OFF:%.*]] = sdiv i32 [[X:%.*]], 3
+; CHECK-NEXT:    [[IDX_SCALE:%.*]] = shl nsw i32 [[OFF]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sext i32 [[IDX_SCALE]] to i64
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds float, ptr [[P:%.*]], i64 [[TMP1]]
+; CHECK-NEXT:    ret ptr [[GEP]]
+;
+  %off = sdiv i32 %x, 3
+  %idx = sext i32 %off to i64
+  %gep = getelementptr inbounds [100 x %"struct.std::complex"], ptr %p, i64 0, i64 %idx
+  ret ptr %gep
+}
+
+; Nothing bounds %x, so %x * 2 may overflow i32: the GEP is left alone.
+define ptr @complex_unknown_range(ptr %p, i32 %x) {
+; CHECK-LABEL: @complex_unknown_range(
+; CHECK-NEXT:    [[IDX:%.*]] = sext i32 [[X:%.*]] to i64
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds [100 x %"struct.std::complex"], ptr [[P:%.*]], i64 0, i64 [[IDX]]
+; CHECK-NEXT:    ret ptr [[GEP]]
+;
+  %idx = sext i32 %x to i64
+  %gep = getelementptr inbounds [100 x %"struct.std::complex"], ptr %p, i64 0, i64 %idx
+  ret ptr %gep
+}
>From 52ef22bd3a0cee9f572b127e885f93b27fbaf273 Mon Sep 17 00:00:00 2001
From: zhongyunde 00443407 <zhongyunde at huawei.com>
Date: Wed, 18 Sep 2024 02:13:09 -0400
Subject: [PATCH 2/2] Use dyn_cast and a safe loop bound in
 hasAllZeroIndicesExceptLast
---
 llvm/lib/IR/Instructions.cpp | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 8363ff2c070d69..bebe99dbe9bfe9 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -1559,8 +1559,11 @@ bool GetElementPtrInst::hasAllConstantIndices() const {
-/// are zero except the last indice.
+/// are constant zero, ignoring the last one. Vacuously true for a GEP with
+/// at most one index.
 bool GetElementPtrInst::hasAllZeroIndicesExceptLast() const {
-  for (unsigned i = 1, e = getNumOperands() - 1; i != e; ++i) {
-    if (!isa<ConstantInt>(getOperand(i)) ||
-        !cast<ConstantInt>(getOperand(i))->isZero())
+  // Operand 0 is the pointer; the indices are operands [1, getNumOperands()).
+  // Use '<' so a GEP with no indices (e == 0) doesn't run past the operands.
+  for (unsigned i = 1, e = getNumOperands() - 1; i < e; ++i) {
+    auto *Val = dyn_cast<ConstantInt>(getOperand(i));
+    if (!Val || !Val->isZero())
       return false;
   }
   return true;
More information about the llvm-commits
mailing list