[llvm] [Scalarizer] Ensure valid VectorSplits for each struct element in `visitExtractValueInst` (PR #128538)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 24 09:19:16 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Deric Cheung (Icohedron)
<details>
<summary>Changes</summary>
Fixes #127739
`visitExtractValueInst` is missing a check that is present in `splitCall` / `visitCallInst`.
This check ensures that each struct element has a VectorSplit, and that each VectorSplit contains the same number of elements packed per fragment.
---
Full diff: https://github.com/llvm/llvm-project/pull/128538.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/Scalar/Scalarizer.cpp (+18-7)
- (modified) llvm/test/Transforms/Scalarizer/min-bits.ll (+11)
``````````diff
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 2b27150112ad8..820c8e12d2449 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -719,13 +719,12 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
for (unsigned I = 1; I < CallType->getNumContainedTypes(); I++) {
std::optional<VectorSplit> CurrVS =
getVectorSplit(cast<FixedVectorType>(CallType->getContainedType(I)));
- // This case does not seem to happen, but it is possible for
- // VectorSplit.NumPacked >= NumElems. If that happens a VectorSplit
- // is not returned and we will bailout of handling this call.
- // The secondary bailout case is if NumPacked does not match.
- // This can happen if ScalarizeMinBits is not set to the default.
- // This means with certain ScalarizeMinBits intrinsics like frexp
- // will only scalarize when the struct elements have the same bitness.
+ // It is possible for VectorSplit.NumPacked >= NumElems. If that happens a
+ // VectorSplit is not returned and we will bailout of handling this call.
+ // The secondary bailout case is if NumPacked does not match. This can
+ // happen if ScalarizeMinBits is not set to the default. This means with
+ // certain ScalarizeMinBits intrinsics like frexp will only scalarize when
+ // the struct elements have the same bitness.
if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
return false;
if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I, TTI))
@@ -1083,6 +1082,18 @@ bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
std::optional<VectorSplit> VS = getVectorSplit(VecType);
if (!VS)
return false;
+ for (unsigned I = 1; I < OpTy->getNumContainedTypes(); I++) {
+ std::optional<VectorSplit> CurrVS =
+ getVectorSplit(cast<FixedVectorType>(OpTy->getContainedType(I)));
+ // It is possible for VectorSplit.NumPacked >= NumElems. If that happens a
+ // VectorSplit is not returned and we will bailout of handling this call.
+ // The secondary bailout case is if NumPacked does not match. This can
+ // happen if ScalarizeMinBits is not set to the default. This means with
+ // certain ScalarizeMinBits intrinsics like frexp will only scalarize when
+ // the struct elements have the same bitness.
+ if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
+ return false;
+ }
IRBuilder<> Builder(&EVI);
Scatterer Op0 = scatter(&EVI, Op, *VS);
assert(!EVI.getIndices().empty() && "Make sure an index exists");
diff --git a/llvm/test/Transforms/Scalarizer/min-bits.ll b/llvm/test/Transforms/Scalarizer/min-bits.ll
index 97cc71626e208..377893ad7e6fd 100644
--- a/llvm/test/Transforms/Scalarizer/min-bits.ll
+++ b/llvm/test/Transforms/Scalarizer/min-bits.ll
@@ -1081,6 +1081,17 @@ define <4 x half> @call_v4f16(<4 x half> %a, <4 x half> %b) {
ret <4 x half> %r
}
+define <3 x i32> @call_v3i32(<3 x i32> %a, <3 x i32> %b) {
+; CHECK-LABEL: @call_v3i32(
+; CHECK-NEXT: [[T:%.*]] = call { <3 x i32>, <3 x i1> } @llvm.uadd.with.overflow.v3i32(<3 x i32> [[A:%.*]], <3 x i32> [[B:%.*]])
+; CHECK-NEXT: [[R:%.*]] = extractvalue { <3 x i32>, <3 x i1> } [[T]], 0
+; CHECK-NEXT: ret <3 x i32> [[R]]
+;
+ %t = call { <3 x i32>, <3 x i1> } @llvm.uadd.with.overflow.v3i32(<3 x i32> %a, <3 x i32> %b)
+ %r = extractvalue { <3 x i32>, <3 x i1> } %t, 0
+ ret <3 x i32> %r
+}
+
declare <2 x half> @llvm.minnum.v2f16(<2 x half>, <2 x half>)
declare <3 x half> @llvm.minnum.v3f16(<3 x half>, <3 x half>)
declare <4 x half> @llvm.minnum.v4f16(<4 x half>, <4 x half>)
``````````
</details>
https://github.com/llvm/llvm-project/pull/128538
More information about the llvm-commits mailing list