[llvm] Decompose gep of complex type struct to its element type (PR #107848)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 23 18:52:08 PDT 2024
https://github.com/vfdff updated https://github.com/llvm/llvm-project/pull/107848
>From 43a6f2b3ffc3cd8a93dd225100711f1e189140c2 Mon Sep 17 00:00:00 2001
From: zhongyunde 00443407 <zhongyunde at huawei.com>
Date: Sat, 7 Sep 2024 06:16:07 -0400
Subject: [PATCH 1/3] Decompose gep of complex type struct to its element type
Similar to PR96606, but this PR addresses the scenario where all indices are
zero except the last index, which is not a constant value.
GEP indices are usually widened to the pointer width, so on AArch64, for
example, the index of a getelementptr is often extended to i64.
The vectorizer chooses the vector length according to the data type, so for
float it may be <vscale x 4 x float>.
When the address comes from a struct.std::complex, as in the following IR,
the index needs to be multiplied by 2 in codegen. Because codegen cannot
assume that (<vscale x 4 x i64> %3 * 2) fits in <vscale x 4 x i32>, it
splits the llvm.masked.gather.nxv4f32.nxv4p0 node.
```
> %4 = getelementptr inbounds [10000 x %"struct.std::complex"], ptr @mdlComplex, i64 0, <vscale x 4 x i64> %3
> %wide.masked.gather = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> %4, i32 4, <vscale x 4 x i1> %active.lane.mask, <vscale x 4 x float> poison)
```
This PR decomposes a GEP of a complex-type struct into a GEP of its element
type, so the getelementptr index no longer needs to be multiplied, and
llvm.masked.gather.nxv4f32.nxv4p0 no longer needs to be split when the
offset is known to be extended from i32.
Fix https://github.com/llvm/llvm-project/issues/107825
---
llvm/include/llvm/IR/Instructions.h | 2 ++
llvm/lib/IR/Instructions.cpp | 11 +++++++
.../InstCombine/InstructionCombining.cpp | 32 +++++++++++++++++++
.../LoopVectorize/reduction-inloop.ll | 24 ++++++++------
4 files changed, 59 insertions(+), 10 deletions(-)
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index ab3321ee755717..c5a2e2cb442d29 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -1078,6 +1078,8 @@ class GetElementPtrInst : public Instruction {
/// a constant offset between them.
bool hasAllConstantIndices() const;
+ bool hasAllZeroIndicesExceptLast() const;
+
/// Set nowrap flags for GEP instruction.
void setNoWrapFlags(GEPNoWrapFlags NW);
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 19da1f60d424d2..8363ff2c070d69 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -1555,6 +1555,17 @@ bool GetElementPtrInst::hasAllConstantIndices() const {
return true;
}
+// Operand 0 of a GEP is the base pointer; index operands start at operand 1,
+// and the last operand is the one index this predicate leaves unconstrained.
+// NOTE(review): for a GEP with no index operands (getNumOperands() == 1),
+// `e` below is 0 while `i` starts at 1, so `i != e` does not stop the loop
+// and getOperand(1) is out of range; a `i < e` bound avoids this. A GEP with
+// exactly one index returns true vacuously. Only ConstantInt zeros are
+// accepted, so a vector zero index (e.g. zeroinitializer) presumably makes
+// this return false — confirm that conservative result is intended.
+/// hasAllZeroIndicesExceptLast - Return true if all of the indices of this GEP
+/// are zero except the last indice.
+bool GetElementPtrInst::hasAllZeroIndicesExceptLast() const {
+ for (unsigned i = 1, e = getNumOperands() - 1; i != e; ++i) {
+ if (!isa<ConstantInt>(getOperand(i)) ||
+ !cast<ConstantInt>(getOperand(i))->isZero())
+ return false;
+ }
+ return true;
+}
+
void GetElementPtrInst::setNoWrapFlags(GEPNoWrapFlags NW) {
SubclassOptionalData = NW.getRaw();
}
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 8195e0539305cc..1de6d189ea7ae0 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2805,6 +2805,38 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
GEP.getNoWrapFlags()));
}
+ // Decompose a GEP whose result element type is a two-field struct with
+ // identical field types (optionally wrapped in a one-field struct), e.g.
+ // the complex type %"struct.std::complex" = type { { float, float } },
+ // into a GEP over the field type with the index scaled by 2.
+ // Canonicalize
+ // - %idxprom = sext i32 %Off to i64
+ // - getelementptr inbounds [100 x %"struct.std::complex"], ptr @p, i64 0, i64 %idxprom
+ // into
+ // - %idxprom.scale = shl nsw i32 %Off, 1
+ // - %1 = sext i32 %idxprom.scale to i64
+ // - getelementptr inbounds float, ptr @p, i64 %1
+ // Only fires when every index before the last is a constant zero and the
+ // last index is a sext instruction.
+ auto *GepResElTy = GEP.getResultElementType();
+ // Peel a single-field wrapper struct: { { T, T } } -> { T, T }.
+ if (GepResElTy->isStructTy() && GepResElTy->getStructNumElements() == 1)
+ GepResElTy = GepResElTy->getStructElementType(0);
+ if (GepResElTy->isStructTy() && GepResElTy->getStructNumElements() == 2 &&
+ GepResElTy->getStructElementType(0) ==
+ GepResElTy->getStructElementType(1) &&
+ GEP.hasAllZeroIndicesExceptLast()) {
+ unsigned LastIndice = GEP.getNumOperands() - 1;
+ Value *LastOp = GEP.getOperand(LastIndice);
+ if (auto *SExtI = dyn_cast<SExtInst>(LastOp)) {
+ GEPOperator *GEPOp = cast<GEPOperator>(&GEP);
+ // NOTE(review): inbounds/nusw and nuw on the original GEP constrain
+ // sext(X) * sizeof(struct) in the pointer index type, not X * 2 in X's
+ // own (narrower) type. For an i32 X >= 2^30 the narrow multiply below
+ // overflows (poison with nsw, a wrapped value otherwise), and the narrow
+ // product is then sign-extended as a GEP index, so the rewrite is not a
+ // refinement of a previously well-defined GEP. Multiplying the extended
+ // index, or first proving X * 2 cannot overflow, would avoid this.
+ bool NSW = GEPOp->hasNoUnsignedSignedWrap();
+ bool NUW = GEPOp->hasNoUnsignedWrap();
+ // We'll let instcombine(mul) convert this to a shl if possible.
+ auto IntTy = SExtI->getOperand(0)->getType();
+ // Scale the pre-extension value by 2 (two fields per struct element) in
+ // its own type; it is passed to the new GEP without re-extending it here.
+ Value *Offset =
+ Builder.CreateMul(SExtI->getOperand(0), ConstantInt::get(IntTy, 2),
+ SExtI->getName() + ".scale", NUW, NSW);
+ // Intended to address the same bytes as a GEP over the field type,
+ // keeping the original GEP's no-wrap flags.
+ return replaceInstUsesWith(
+ GEP, Builder.CreateGEP(GepResElTy->getStructElementType(0), PtrOp,
+ Offset, "", GEP.getNoWrapFlags()));
+ }
+ }
+
// Canonicalize
// - scalable GEPs to an explicit offset using the llvm.vscale intrinsic.
// This has better support in BasicAA.
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
index 7c5d6a1edf0b4b..e55a7931818939 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
@@ -1316,8 +1316,9 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
; CHECK-NEXT: [[TMP20:%.*]] = extractelement <4 x i1> [[TMP19]], i64 0
; CHECK-NEXT: br i1 [[TMP20]], label [[PRED_LOAD_IF:%.*]], label [[PRED_LOAD_CONTINUE:%.*]]
; CHECK: pred.load.if:
-; CHECK-NEXT: [[TMP21:%.*]] = sext i32 [[INDEX]] to i64
-; CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [0 x %struct.e], ptr [[B]], i64 0, i64 [[TMP21]]
+; CHECK-NEXT: [[DOTSCALE:%.*]] = shl nsw i32 [[INDEX]], 1
+; CHECK-NEXT: [[TMP21:%.*]] = sext i32 [[DOTSCALE]] to i64
+; CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP21]]
; CHECK-NEXT: [[TMP23:%.*]] = load i32, ptr [[TMP22]], align 4
; CHECK-NEXT: [[TMP24:%.*]] = insertelement <4 x i32> poison, i32 [[TMP23]], i64 0
; CHECK-NEXT: br label [[PRED_LOAD_CONTINUE]]
@@ -1326,8 +1327,9 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
; CHECK-NEXT: [[TMP26:%.*]] = extractelement <4 x i1> [[TMP19]], i64 1
; CHECK-NEXT: br i1 [[TMP26]], label [[PRED_LOAD_IF1:%.*]], label [[PRED_LOAD_CONTINUE2:%.*]]
; CHECK: pred.load.if1:
-; CHECK-NEXT: [[TMP27:%.*]] = sext i32 [[TMP0]] to i64
-; CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds [0 x %struct.e], ptr [[B]], i64 0, i64 [[TMP27]]
+; CHECK-NEXT: [[DOTSCALE7:%.*]] = shl nsw i32 [[TMP0]], 1
+; CHECK-NEXT: [[TMP27:%.*]] = sext i32 [[DOTSCALE7]] to i64
+; CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP27]]
; CHECK-NEXT: [[TMP29:%.*]] = load i32, ptr [[TMP28]], align 4
; CHECK-NEXT: [[TMP30:%.*]] = insertelement <4 x i32> [[TMP25]], i32 [[TMP29]], i64 1
; CHECK-NEXT: br label [[PRED_LOAD_CONTINUE2]]
@@ -1336,8 +1338,9 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
; CHECK-NEXT: [[TMP32:%.*]] = extractelement <4 x i1> [[TMP19]], i64 2
; CHECK-NEXT: br i1 [[TMP32]], label [[PRED_LOAD_IF3:%.*]], label [[PRED_LOAD_CONTINUE4:%.*]]
; CHECK: pred.load.if3:
-; CHECK-NEXT: [[TMP33:%.*]] = sext i32 [[TMP1]] to i64
-; CHECK-NEXT: [[TMP34:%.*]] = getelementptr inbounds [0 x %struct.e], ptr [[B]], i64 0, i64 [[TMP33]]
+; CHECK-NEXT: [[DOTSCALE8:%.*]] = shl nsw i32 [[TMP1]], 1
+; CHECK-NEXT: [[TMP33:%.*]] = sext i32 [[DOTSCALE8]] to i64
+; CHECK-NEXT: [[TMP34:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP33]]
; CHECK-NEXT: [[TMP35:%.*]] = load i32, ptr [[TMP34]], align 4
; CHECK-NEXT: [[TMP36:%.*]] = insertelement <4 x i32> [[TMP31]], i32 [[TMP35]], i64 2
; CHECK-NEXT: br label [[PRED_LOAD_CONTINUE4]]
@@ -1346,8 +1349,9 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
; CHECK-NEXT: [[TMP38:%.*]] = extractelement <4 x i1> [[TMP19]], i64 3
; CHECK-NEXT: br i1 [[TMP38]], label [[PRED_LOAD_IF5:%.*]], label [[PRED_LOAD_CONTINUE6]]
; CHECK: pred.load.if5:
-; CHECK-NEXT: [[TMP39:%.*]] = sext i32 [[TMP2]] to i64
-; CHECK-NEXT: [[TMP40:%.*]] = getelementptr inbounds [0 x %struct.e], ptr [[B]], i64 0, i64 [[TMP39]]
+; CHECK-NEXT: [[DOTSCALE9:%.*]] = shl nsw i32 [[TMP2]], 1
+; CHECK-NEXT: [[TMP39:%.*]] = sext i32 [[DOTSCALE9]] to i64
+; CHECK-NEXT: [[TMP40:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP39]]
; CHECK-NEXT: [[TMP41:%.*]] = load i32, ptr [[TMP40]], align 4
; CHECK-NEXT: [[TMP42:%.*]] = insertelement <4 x i32> [[TMP37]], i32 [[TMP41]], i64 3
; CHECK-NEXT: br label [[PRED_LOAD_CONTINUE6]]
@@ -1355,8 +1359,8 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
; CHECK-NEXT: [[TMP43:%.*]] = phi <4 x i32> [ [[TMP37]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP42]], [[PRED_LOAD_IF5]] ]
; CHECK-NEXT: [[TMP44:%.*]] = icmp ne <4 x i32> [[TMP43]], zeroinitializer
; CHECK-NEXT: [[NOT_:%.*]] = xor <4 x i1> [[TMP19]], <i1 true, i1 true, i1 true, i1 true>
-; CHECK-NEXT: [[DOTNOT7:%.*]] = select <4 x i1> [[NOT_]], <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i1> [[TMP44]]
-; CHECK-NEXT: [[TMP45:%.*]] = bitcast <4 x i1> [[DOTNOT7]] to i4
+; CHECK-NEXT: [[DOTNOT10:%.*]] = select <4 x i1> [[NOT_]], <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i1> [[TMP44]]
+; CHECK-NEXT: [[TMP45:%.*]] = bitcast <4 x i1> [[DOTNOT10]] to i4
; CHECK-NEXT: [[TMP46:%.*]] = call range(i4 0, 5) i4 @llvm.ctpop.i4(i4 [[TMP45]])
; CHECK-NEXT: [[TMP47:%.*]] = zext nneg i4 [[TMP46]] to i32
; CHECK-NEXT: [[TMP48]] = add i32 [[VEC_PHI]], [[TMP47]]
>From 52ef22bd3a0cee9f572b127e885f93b27fbaf273 Mon Sep 17 00:00:00 2001
From: zhongyunde 00443407 <zhongyunde at huawei.com>
Date: Wed, 18 Sep 2024 02:13:09 -0400
Subject: [PATCH 2/3] Fix comment
---
llvm/lib/IR/Instructions.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 8363ff2c070d69..bebe99dbe9bfe9 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -1559,8 +1559,8 @@ bool GetElementPtrInst::hasAllConstantIndices() const {
/// are zero except the last indice.
bool GetElementPtrInst::hasAllZeroIndicesExceptLast() const {
for (unsigned i = 1, e = getNumOperands() - 1; i != e; ++i) {
- if (!isa<ConstantInt>(getOperand(i)) ||
- !cast<ConstantInt>(getOperand(i))->isZero())
+ ConstantInt *Val = dyn_cast<ConstantInt>(getOperand(i));
+ if (!Val || !Val->isZero())
return false;
}
return true;
>From 02fa15255bc5c5d30d62343a9824fe5195b43cc2 Mon Sep 17 00:00:00 2001
From: zhongyunde 00443407 <zhongyunde at huawei.com>
Date: Mon, 23 Sep 2024 05:28:29 -0400
Subject: [PATCH 3/3] Fix comment, add test for vector
---
llvm/include/llvm/IR/Instructions.h | 2 ++
.../Transforms/InstCombine/gep-complex.ll | 36 +++++++++++++++++++
2 files changed, 38 insertions(+)
create mode 100644 llvm/test/Transforms/InstCombine/gep-complex.ll
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index c5a2e2cb442d29..f6a585c5b8abba 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -1078,6 +1078,8 @@ class GetElementPtrInst : public Instruction {
/// a constant offset between them.
bool hasAllConstantIndices() const;
+ /// Return true if all of the indices of this GEP are zero except the last
+ /// indice.
bool hasAllZeroIndicesExceptLast() const;
/// Set nowrap flags for GEP instruction.
diff --git a/llvm/test/Transforms/InstCombine/gep-complex.ll b/llvm/test/Transforms/InstCombine/gep-complex.ll
new file mode 100644
index 00000000000000..f1c639262148cc
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/gep-complex.ll
@@ -0,0 +1,36 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+; Two adjacent floats: the homogeneous two-field struct shape that the GEP
+; decomposition targets.
+%"class.std::__1::complex" = type { float, float }
+@mdlComplex = dso_local global [10000 x %"class.std::__1::complex"] zeroinitializer, align 4
+
+; A sext'd i32 index (after a leading zero index) into an array of
+; { float, float } is expected to be doubled (shl nsw ..., 1) before the sext,
+; with the GEP rewritten over float.
+; NOTE(review): this file only has positive tests; consider negative tests
+; (non-sext index, { float, i32 } struct, a GEP on a pointer argument of
+; unknown size where X * 2 could overflow i32).
+define float @decompose_complex_scalar(ptr %array) {
+; CHECK-LABEL: @decompose_complex_scalar(
+; CHECK-NEXT: [[VAL:%.*]] = load i32, ptr [[ARRAY:%.*]], align 4
+; CHECK-NEXT: [[SEXTVAL_SCALE:%.*]] = shl nsw i32 [[VAL]], 1
+; CHECK-NEXT: [[TMP1:%.*]] = sext i32 [[SEXTVAL_SCALE]] to i64
+; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds float, ptr @mdlComplex, i64 [[TMP1]]
+; CHECK-NEXT: [[RES:%.*]] = load float, ptr [[ARRAYIDX]], align 4
+; CHECK-NEXT: ret float [[RES]]
+;
+ %val = load i32, ptr %array, align 4
+ %sextVal = sext i32 %val to i64
+ %arrayidx = getelementptr inbounds [10000 x %"class.std::__1::complex"], ptr @mdlComplex, i32 0, i64 %sextVal
+ %res = load float, ptr %arrayidx, align 4
+ ret float %res
+}
+
+; Vector-index variant: the sext'd <4 x i32> index is expected to be scaled
+; lane-wise (shl nsw by a splat of 1) before the sext, with the vector GEP
+; rewritten over float.
+define <4 x ptr> @decompose_complex_vector(ptr %array) {
+; CHECK-LABEL: @decompose_complex_vector(
+; CHECK-NEXT: [[VAL:%.*]] = load <4 x i32>, ptr [[ARRAY:%.*]], align 4
+; CHECK-NEXT: [[SEXTVAL_SCALE:%.*]] = shl nsw <4 x i32> [[VAL]], <i32 1, i32 1, i32 1, i32 1>
+; CHECK-NEXT: [[TMP1:%.*]] = sext <4 x i32> [[SEXTVAL_SCALE]] to <4 x i64>
+; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds float, ptr @mdlComplex, <4 x i64> [[TMP1]]
+; CHECK-NEXT: ret <4 x ptr> [[ARRAYIDX]]
+;
+ %val = load <4 x i32>, ptr %array, align 4
+ %sextVal = sext <4 x i32> %val to <4 x i64>
+ %arrayidx = getelementptr inbounds [10000 x %"class.std::__1::complex"], ptr @mdlComplex, i32 0, <4 x i64> %sextVal
+ ret <4 x ptr> %arrayidx
+}
+
More information about the llvm-commits
mailing list