[llvm] [Scalarizer] Ensure valid VectorSplits for each struct element in `visitExtractValueInst` (PR #128538)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 24 09:19:16 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-llvm-transforms

Author: Deric Cheung (Icohedron)

<details>
<summary>Changes</summary>

Fixes #127739

`visitExtractValueInst` is missing a check that is already present in `splitCall` / `visitCallInst`.
The check ensures that every struct element has a VectorSplit, and that all of these VectorSplits pack the same number of elements per fragment (`NumPacked`).
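
The two bailout conditions can be illustrated with a small standalone model. This is only a sketch, not the Scalarizer's code: `VecField`, `ToySplit`, `toyVectorSplit`, and `canScalarizeStruct` are hypothetical names, and the packing rule (`MinBits / ElemBits`, with no split when a single fragment would cover the whole vector) is an assumed simplification of what `getVectorSplit` does.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical stand-in for one fixed-vector struct field, e.g. <3 x i32>.
struct VecField {
  unsigned NumElems;
  unsigned ElemBits;
};

// Toy counterpart of VectorSplit: only the packing factor is modeled.
struct ToySplit {
  unsigned NumPacked;
};

// Assumed packing rule (a simplification, not the pass's real logic):
// pack as many elements as fit in MinBits, and report "no split" when one
// fragment would already hold the whole vector.
std::optional<ToySplit> toyVectorSplit(VecField F, unsigned MinBits) {
  assert(F.ElemBits > 0 && F.NumElems > 0);
  unsigned NumPacked = MinBits / F.ElemBits;
  if (NumPacked <= 1)
    return ToySplit{1};
  if (NumPacked >= F.NumElems)
    return std::nullopt;
  return ToySplit{NumPacked};
}

// Mirrors the shape of the check: field 0 must have a split, and every other
// field must have a split with the same NumPacked.
bool canScalarizeStruct(const std::vector<VecField> &Fields,
                        unsigned MinBits) {
  if (Fields.empty())
    return false;
  std::optional<ToySplit> VS = toyVectorSplit(Fields[0], MinBits);
  if (!VS)
    return false;
  for (std::size_t I = 1; I < Fields.size(); ++I) {
    std::optional<ToySplit> CurrVS = toyVectorSplit(Fields[I], MinBits);
    if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
      return false;
  }
  return true;
}
```

Under this model, a `{ <3 x i32>, <3 x i1> }` struct passes the check when `MinBits` is 0, but fails with `MinBits` set to 32 because the `<3 x i1>` field gets no split. A struct whose fields have different element widths can also fail through the `NumPacked` mismatch.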



---
Full diff: https://github.com/llvm/llvm-project/pull/128538.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Scalar/Scalarizer.cpp (+18-7) 
- (modified) llvm/test/Transforms/Scalarizer/min-bits.ll (+11) 


``````````diff
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 2b27150112ad8..820c8e12d2449 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -719,13 +719,12 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
     for (unsigned I = 1; I < CallType->getNumContainedTypes(); I++) {
       std::optional<VectorSplit> CurrVS =
           getVectorSplit(cast<FixedVectorType>(CallType->getContainedType(I)));
-      // This case does not seem to happen, but it is possible for
-      // VectorSplit.NumPacked >= NumElems. If that happens a VectorSplit
-      // is not returned and we will bailout of handling this call.
-      // The secondary bailout case is if NumPacked does not match.
-      // This can happen if ScalarizeMinBits is not set to the default.
-      // This means with certain ScalarizeMinBits intrinsics like frexp
-      // will only scalarize when the struct elements have the same bitness.
+      // It is possible for VectorSplit.NumPacked >= NumElems. If that happens a
+      // VectorSplit is not returned and we will bailout of handling this call.
+      // The secondary bailout case is if NumPacked does not match. This can
+      // happen if ScalarizeMinBits is not set to the default. This means with
+      // certain ScalarizeMinBits intrinsics like frexp will only scalarize when
+      // the struct elements have the same bitness.
       if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
         return false;
       if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I, TTI))
@@ -1083,6 +1082,18 @@ bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
   std::optional<VectorSplit> VS = getVectorSplit(VecType);
   if (!VS)
     return false;
+  for (unsigned I = 1; I < OpTy->getNumContainedTypes(); I++) {
+    std::optional<VectorSplit> CurrVS =
+        getVectorSplit(cast<FixedVectorType>(OpTy->getContainedType(I)));
+    // It is possible for VectorSplit.NumPacked >= NumElems. If that happens a
+    // VectorSplit is not returned and we will bailout of handling this call.
+    // The secondary bailout case is if NumPacked does not match. This can
+    // happen if ScalarizeMinBits is not set to the default. This means with
+    // certain ScalarizeMinBits intrinsics like frexp will only scalarize when
+    // the struct elements have the same bitness.
+    if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
+      return false;
+  }
   IRBuilder<> Builder(&EVI);
   Scatterer Op0 = scatter(&EVI, Op, *VS);
   assert(!EVI.getIndices().empty() && "Make sure an index exists");
diff --git a/llvm/test/Transforms/Scalarizer/min-bits.ll b/llvm/test/Transforms/Scalarizer/min-bits.ll
index 97cc71626e208..377893ad7e6fd 100644
--- a/llvm/test/Transforms/Scalarizer/min-bits.ll
+++ b/llvm/test/Transforms/Scalarizer/min-bits.ll
@@ -1081,6 +1081,17 @@ define <4 x half> @call_v4f16(<4 x half> %a, <4 x half> %b) {
   ret <4 x half> %r
 }
 
+define <3 x i32> @call_v3i32(<3 x i32> %a, <3 x i32> %b) {
+; CHECK-LABEL: @call_v3i32(
+; CHECK-NEXT:    [[T:%.*]] = call { <3 x i32>, <3 x i1> } @llvm.uadd.with.overflow.v3i32(<3 x i32> [[A:%.*]], <3 x i32> [[B:%.*]])
+; CHECK-NEXT:    [[R:%.*]] = extractvalue { <3 x i32>, <3 x i1> } [[T]], 0
+; CHECK-NEXT:    ret <3 x i32> [[R]]
+;
+  %t = call { <3 x i32>, <3 x i1> } @llvm.uadd.with.overflow.v3i32(<3 x i32> %a, <3 x i32> %b)
+  %r = extractvalue { <3 x i32>, <3 x i1> } %t, 0
+  ret <3 x i32> %r
+}
+
 declare <2 x half> @llvm.minnum.v2f16(<2 x half>, <2 x half>)
 declare <3 x half> @llvm.minnum.v3f16(<3 x half>, <3 x half>)
 declare <4 x half> @llvm.minnum.v4f16(<4 x half>, <4 x half>)

``````````

</details>


https://github.com/llvm/llvm-project/pull/128538

