[llvm] [Scalarizer][DirectX] support struct return types (PR #111569)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 8 10:42:58 PDT 2024


https://github.com/farzonl created https://github.com/llvm/llvm-project/pull/111569

Based on this RFC: https://discourse.llvm.org/t/rfc-allow-the-scalarizer-pass-to-scalarize-vectors-returned-in-structs/82306

LLVM intrinsics do not support out parameters. To work around this limitation, implementers make intrinsics return a struct that holds both the return value and the out parameter. This implementation detail should not affect scalarization, since these intrinsics are still elementwise operations.

## Three changes are needed
- The `CallInst` visitor needs to be updated to handle structs.
- A new visitor is needed for `ExtractValue` instructions.
- `finish` needs to be updated to handle structs so that the values rebuilt from the scalarized fragments (via `insertelement`/`insertvalue`) are properly propagated.

## Testing changes
- Add support for `llvm.frexp`
- Add support for `llvm.dx.splitdouble`

>From 3a72753c437c350aacfc5f8b5f6be7d0a7672ae0 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 30 Sep 2024 10:11:38 -0400
Subject: [PATCH 1/3] [Scalarizer] A change to let the scalarizer pass be able
 to scalarize structs

---
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  3 +
 .../DirectX/DirectXTargetTransformInfo.cpp    |  1 +
 llvm/lib/Transforms/Scalar/Scalarizer.cpp     | 69 ++++++++++++++++++-
 llvm/test/CodeGen/DirectX/split-double.ll     | 10 +++
 4 files changed, 81 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/split-double.ll

diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index f2b9e286ebb476..ae9f0aea904f4a 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -86,5 +86,8 @@ def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]
 def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
 def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
 def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>], 
+    [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>;
+
 def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 }
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index be714b5c87895a..4ddf39a4337df6 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -28,6 +28,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
   switch (ID) {
   case Intrinsic::dx_frac:
   case Intrinsic::dx_rsqrt:
+  case Intrinsic::dx_splitdouble:
     return true;
   default:
     return false;
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 72728c0f839e5d..7505654c23a70d 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -197,6 +197,11 @@ struct VectorLayout {
   uint64_t SplitSize = 0;
 };
 
+static bool isStructOfVectors(Type *Ty) {
+  return isa<StructType>(Ty) && Ty->getNumContainedTypes() > 0 &&
+         isa<FixedVectorType>(Ty->getContainedType(0));
+}
+
 /// Concatenate the given fragments to a single vector value of the type
 /// described in @p VS.
 static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
@@ -276,6 +281,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
   bool visitBitCastInst(BitCastInst &BCI);
   bool visitInsertElementInst(InsertElementInst &IEI);
   bool visitExtractElementInst(ExtractElementInst &EEI);
+  bool visitExtractValueInst(ExtractValueInst &EVI);
   bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
   bool visitPHINode(PHINode &PHI);
   bool visitLoadInst(LoadInst &LI);
@@ -552,7 +558,10 @@ void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
 // Determine how Ty is split, if at all.
 std::optional<VectorSplit> ScalarizerVisitor::getVectorSplit(Type *Ty) {
   VectorSplit Split;
-  Split.VecTy = dyn_cast<FixedVectorType>(Ty);
+  if (isStructOfVectors(Ty))
+    Split.VecTy = cast<FixedVectorType>(Ty->getContainedType(0));
+  else
+    Split.VecTy = dyn_cast<FixedVectorType>(Ty);
   if (!Split.VecTy)
     return {};
 
@@ -1029,6 +1038,33 @@ bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
   return true;
 }
 
+bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
+  Value *Op = EVI.getOperand(0);
+  Type *OpTy = Op->getType();
+  ValueVector Res;
+  if (!isStructOfVectors(OpTy))
+    return false;
+  // Note: isStructOfVectors is also used in getVectorSplit.
+  // The intent is to bail on this visit if it isn't a struct
+  // of vectors. Downside is that when it is true we do two
+  // isStructOfVectors calls.
+  std::optional<VectorSplit> VS = getVectorSplit(OpTy);
+  if (!VS)
+    return false;
+  Scatterer Op0 = scatter(&EVI, Op, *VS);
+  assert(!EVI.getIndices().empty() && "Make sure an index exists");
+  // Note for our use case we only care about the top level index.
+  unsigned Index = EVI.getIndices()[0];
+  for (unsigned OpIdx = 0; OpIdx < Op0.size(); ++OpIdx) {
+    Value *ResElem = Builder.CreateExtractValue(
+        Op0[OpIdx], Index, EVI.getName() + ".elem" + std::to_string(Index));
+    Res.push_back(ResElem);
+  }
+  // replaceUses(&EVI, Res);
+  gather(&EVI, Res, *VS);
+  return true;
+}
+
 bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
   std::optional<VectorSplit> VS = getVectorSplit(EEI.getOperand(0)->getType());
   if (!VS)
@@ -1195,7 +1231,7 @@ bool ScalarizerVisitor::finish() {
     if (!Op->use_empty()) {
       // The value is still needed, so recreate it using a series of
       // insertelements and/or shufflevectors.
-      Value *Res;
+      Value *Res = nullptr;
       if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType())) {
         BasicBlock *BB = Op->getParent();
         IRBuilder<> Builder(Op);
@@ -1208,6 +1244,35 @@ bool ScalarizerVisitor::finish() {
         Res = concatenate(Builder, CV, VS, Op->getName());
 
         Res->takeName(Op);
+      } else if (auto *Ty = dyn_cast<StructType>(Op->getType())) {
+        BasicBlock *BB = Op->getParent();
+        IRBuilder<> Builder(Op);
+        if (isa<PHINode>(Op))
+          Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
+
+        // Iterate over each element in the struct
+        uint NumOfStructElements = Ty->getNumElements();
+        SmallVector<ValueVector, 4> ElemCV(NumOfStructElements);
+        for (unsigned I = 0; I < NumOfStructElements; ++I) {
+          for (auto *CVelem : CV) {
+            Value *Elem = Builder.CreateExtractValue(
+                CVelem, I, Op->getName() + ".elem" + std::to_string(I));
+            ElemCV[I].push_back(Elem);
+          }
+        }
+        Res = PoisonValue::get(Ty);
+        for (unsigned I = 0; I < NumOfStructElements; ++I) {
+          Type *ElemTy = Ty->getElementType(I);
+          assert(isa<FixedVectorType>(ElemTy) &&
+                 "Only Structs of all FixedVectorType supported");
+          VectorSplit VS = *getVectorSplit(ElemTy);
+          assert(VS.NumFragments == CV.size());
+
+          Value *ConcatenatedVector =
+              concatenate(Builder, ElemCV[I], VS, Op->getName());
+          Res = Builder.CreateInsertValue(Res, ConcatenatedVector, I,
+                                          Op->getName() + ".insert");
+        }
       } else {
         assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
         Res = CV[0];
diff --git a/llvm/test/CodeGen/DirectX/split-double.ll b/llvm/test/CodeGen/DirectX/split-double.ll
new file mode 100644
index 00000000000000..7d3c28efbc63c1
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/split-double.ll
@@ -0,0 +1,10 @@
+
+; RUN: opt -S -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %d) local_unnamed_addr {
+    %hlsl.asuint = call { <3 x i32>, <3 x i32> }  @llvm.dx.splitdouble.v3i32(<3 x double> %d)
+    %1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 0
+    %2 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 1
+    %3 = add <3 x i32> %1, %2
+    ret <3 x i32> %3
+}

>From e59ae739aa325814fe7a72da588a16cefba96078 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.org>
Date: Tue, 8 Oct 2024 00:59:15 -0400
Subject: [PATCH 2/3] Add support for frexp. Move vector look up to just
 callInst and extractValue instruction visits

---
 llvm/include/llvm/IR/IntrinsicsDirectX.td |  1 -
 llvm/lib/Transforms/Scalar/Scalarizer.cpp | 53 ++++++++++++------
 llvm/test/CodeGen/DirectX/split-double.ll | 36 ++++++++++---
 llvm/test/Transforms/Scalarizer/frexp.ll  | 66 +++++++++++++++++++++++
 4 files changed, 133 insertions(+), 23 deletions(-)
 create mode 100644 llvm/test/Transforms/Scalarizer/frexp.ll

diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index ae9f0aea904f4a..5f0f856df8e2b0 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -88,6 +88,5 @@ def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32
 def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>], 
     [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>;
-
 def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 }
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 7505654c23a70d..f6a7230a472de5 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -197,9 +197,15 @@ struct VectorLayout {
   uint64_t SplitSize = 0;
 };
 
-static bool isStructOfVectors(Type *Ty) {
-  return isa<StructType>(Ty) && Ty->getNumContainedTypes() > 0 &&
-         isa<FixedVectorType>(Ty->getContainedType(0));
+static bool isStructAllVectors(Type *Ty) {
+  if (!isa<StructType>(Ty))
+    return false;
+
+  for(unsigned I = 0; I < Ty->getNumContainedTypes(); I++)
+    if (!isa<FixedVectorType>(Ty->getContainedType(I)))
+      return false;
+
+  return true;
 }
 
 /// Concatenate the given fragments to a single vector value of the type
@@ -558,10 +564,7 @@ void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
 // Determine how Ty is split, if at all.
 std::optional<VectorSplit> ScalarizerVisitor::getVectorSplit(Type *Ty) {
   VectorSplit Split;
-  if (isStructOfVectors(Ty))
-    Split.VecTy = cast<FixedVectorType>(Ty->getContainedType(0));
-  else
-    Split.VecTy = dyn_cast<FixedVectorType>(Ty);
+  Split.VecTy = dyn_cast<FixedVectorType>(Ty);
   if (!Split.VecTy)
     return {};
 
@@ -676,6 +679,10 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
 bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
   if (isTriviallyVectorizable(ID))
     return true;
+  switch (ID) {
+    case Intrinsic::frexp:
+    return true;
+  }
   return Intrinsic::isTargetIntrinsic(ID) &&
          TTI->isTargetIntrinsicTriviallyScalarizable(ID);
 }
@@ -683,7 +690,13 @@ bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
 /// If a call to a vector typed intrinsic function, split into a scalar call per
 /// element if possible for the intrinsic.
 bool ScalarizerVisitor::splitCall(CallInst &CI) {
-  std::optional<VectorSplit> VS = getVectorSplit(CI.getType());
+  Type* CallType = CI.getType();
+  bool areAllVectors = isStructAllVectors(CallType);
+   std::optional<VectorSplit> VS;
+  if (areAllVectors)
+    VS = getVectorSplit(CallType->getContainedType(0));
+  else
+    VS = getVectorSplit(CallType);
   if (!VS)
     return false;
 
@@ -708,6 +721,18 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
   if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
     Tys.push_back(VS->SplitTy);
 
+  if(areAllVectors) {
+    Type* PrevType = CallType->getContainedType(0);
+    Type* CallType = CI.getType();
+    for(unsigned I = 1; I < CallType->getNumContainedTypes(); I++) {
+      Type* CurrType = cast<FixedVectorType>(CallType->getContainedType(I));
+      if(PrevType != CurrType) {
+        std::optional<VectorSplit> CurrVS = getVectorSplit(CurrType);
+        Tys.push_back(CurrVS->SplitTy);
+        PrevType = CurrType;
+      }
+    }
+  }
   // Assumes that any vector type has the same number of elements as the return
   // vector type, which is true for all current intrinsics.
   for (unsigned I = 0; I != NumArgs; ++I) {
@@ -1042,15 +1067,13 @@ bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
   Value *Op = EVI.getOperand(0);
   Type *OpTy = Op->getType();
   ValueVector Res;
-  if (!isStructOfVectors(OpTy))
+  if (!isStructAllVectors(OpTy))
     return false;
-  // Note: isStructOfVectors is also used in getVectorSplit.
-  // The intent is to bail on this visit if it isn't a struct
-  // of vectors. Downside is that when it is true we do two
-  // isStructOfVectors calls.
-  std::optional<VectorSplit> VS = getVectorSplit(OpTy);
+  Type* VecType = cast<FixedVectorType>(OpTy->getContainedType(0));
+  std::optional<VectorSplit> VS = getVectorSplit(VecType);
   if (!VS)
     return false;
+  IRBuilder<> Builder(&EVI);
   Scatterer Op0 = scatter(&EVI, Op, *VS);
   assert(!EVI.getIndices().empty() && "Make sure an index exists");
   // Note for our use case we only care about the top level index.
@@ -1251,7 +1274,7 @@ bool ScalarizerVisitor::finish() {
           Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
 
         // Iterate over each element in the struct
-        uint NumOfStructElements = Ty->getNumElements();
+        unsigned NumOfStructElements = Ty->getNumElements();
         SmallVector<ValueVector, 4> ElemCV(NumOfStructElements);
         for (unsigned I = 0; I < NumOfStructElements; ++I) {
           for (auto *CVelem : CV) {
diff --git a/llvm/test/CodeGen/DirectX/split-double.ll b/llvm/test/CodeGen/DirectX/split-double.ll
index 7d3c28efbc63c1..4fc5fdd1922a2c 100644
--- a/llvm/test/CodeGen/DirectX/split-double.ll
+++ b/llvm/test/CodeGen/DirectX/split-double.ll
@@ -1,10 +1,32 @@
+; RUN: opt -S -scalarizer  -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
-; RUN: opt -S -scalarizer -scalarize-load-store -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+define void @test_vector_double_split_void(<3 x double> noundef %d) {
+  %hlsl.asuint = call { <3 x i32>, <3 x i32> }  @llvm.dx.splitdouble.v3i32(<3 x double> %d)
+  ret void
+}
 
-define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %d) local_unnamed_addr {
-    %hlsl.asuint = call { <3 x i32>, <3 x i32> }  @llvm.dx.splitdouble.v3i32(<3 x double> %d)
-    %1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 0
-    %2 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 1
-    %3 = add <3 x i32> %1, %2
-    ret <3 x i32> %3
+define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %d) {
+  ; CHECK: [[ee0:%.*]] = extractelement <3 x double> %d, i64 0
+  ; CHECK: [[ie0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee0]])
+  ; CHECK: [[ee1:%.*]] = extractelement <3 x double> %d, i64 1
+  ; CHECK: [[ie1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee1]])
+  ; CHECK: [[ee2:%.*]] = extractelement <3 x double> %d, i64 2
+  ; CHECK: [[ie2:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee2]])
+  ; CHECK: [[ev00:%.*]] = extractvalue { i32, i32 } [[ie0]], 0
+  ; CHECK: [[ev01:%.*]] = extractvalue { i32, i32 } [[ie1]], 0
+  ; CHECK: [[ev02:%.*]] = extractvalue { i32, i32 } [[ie2]], 0
+  ; CHECK: [[ev10:%.*]] = extractvalue { i32, i32 } [[ie0]], 1
+  ; CHECK: [[ev11:%.*]] = extractvalue { i32, i32 } [[ie1]], 1
+  ; CHECK: [[ev12:%.*]] = extractvalue { i32, i32 } [[ie2]], 1
+  ; CHECK: [[add1:%.*]] = add i32 [[ev00]], [[ev10]]
+  ; CHECK: [[add2:%.*]] = add i32 [[ev01]], [[ev11]]
+  ; CHECK: [[add3:%.*]] = add i32 [[ev02]], [[ev12]]
+  ; CHECK: insertelement <3 x i32> poison, i32 [[add1]], i64 0
+  ; CHECK: insertelement <3 x i32> %{{.*}}, i32 [[add2]], i64 1
+  ; CHECK: insertelement <3 x i32> %{{.*}}, i32 [[add3]], i64 2
+  %hlsl.asuint = call { <3 x i32>, <3 x i32> }  @llvm.dx.splitdouble.v3i32(<3 x double> %d)
+  %1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 0
+  %2 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 1
+  %3 = add <3 x i32> %1, %2
+  ret <3 x i32> %3
 }
diff --git a/llvm/test/Transforms/Scalarizer/frexp.ll b/llvm/test/Transforms/Scalarizer/frexp.ll
new file mode 100644
index 00000000000000..454042e6887c3a
--- /dev/null
+++ b/llvm/test/Transforms/Scalarizer/frexp.ll
@@ -0,0 +1,66 @@
+; RUN: opt %s -passes='function(scalarizer<load-store>)' -S | FileCheck %s
+
+; CHECK-LABEL: @test_vector_half_frexp_half
+define noundef <2 x half> @test_vector_half_frexp_half(<2 x half> noundef %h) {
+  ; CHECK: [[ee0:%.*]] = extractelement <2 x half> %h, i64 0
+  ; CHECK-NEXT: [[ie0:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee0]])
+  ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x half> %h, i64 1
+  ; CHECK-NEXT: [[ie1:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee1]])
+  ; CHECK-NEXT: [[ev00:%.*]] = extractvalue { half, i32 } [[ie0]], 0
+  ; CHECK-NEXT: [[ev01:%.*]] = extractvalue { half, i32 } [[ie1]], 0
+  ; CHECK-NEXT: insertelement <2 x half> poison, half [[ev00]], i64 0
+  ; CHECK-NEXT: insertelement <2 x half> %{{.*}}, half [[ev01]], i64 1
+  %r =  call { <2 x half>, <2 x i32> } @llvm.frexp.v2f16.v2i32(<2 x half> %h)
+  %e0 = extractvalue { <2 x half>, <2 x i32> } %r, 0
+  ret <2 x half> %e0
+}
+
+; CHECK-LABEL: @test_vector_half_frexp_int
+define noundef <2 x i32> @test_vector_half_frexp_int(<2 x half> noundef %h) {
+  ; CHECK: [[ee0:%.*]] = extractelement <2 x half> %h, i64 0
+  ; CHECK-NEXT: [[ie0:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee0]])
+  ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x half> %h, i64 1
+  ; CHECK-NEXT: [[ie1:%.*]] = call { half, i32 } @llvm.frexp.f16.i32(half [[ee1]])
+  ; CHECK-NEXT: [[ev10:%.*]] = extractvalue { half, i32 } [[ie0]], 1
+  ; CHECK-NEXT: [[ev11:%.*]] = extractvalue { half, i32 } [[ie1]], 1
+  ; CHECK-NEXT: insertelement <2 x i32> poison, i32 [[ev10]], i64 0
+  ; CHECK-NEXT: insertelement <2 x i32> %{{.*}}, i32 [[ev11]], i64 1
+  %r =  call { <2 x half>, <2 x i32> } @llvm.frexp.v2f16.v2i32(<2 x half> %h)
+  %e1 = extractvalue { <2 x half>, <2 x i32> } %r, 1
+  ret <2 x i32> %e1
+}
+
+
+define noundef <2 x float> @test_vector_float_frexp_int(<2 x float> noundef %f) {
+  ; CHECK: [[ee0:%.*]] = extractelement <2 x float> %f, i64 0
+  ; CHECK-NEXT: [[ie0:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[ee0]])
+  ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x float> %f, i64 1
+  ; CHECK-NEXT: [[ie1:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[ee1]])
+  ; CHECK-NEXT: [[ev00:%.*]] = extractvalue { float, i32 } [[ie0]], 0
+  ; CHECK-NEXT: [[ev01:%.*]] = extractvalue { float, i32 } [[ie1]], 0
+  ; CHECK-NEXT: insertelement <2 x float> poison, float [[ev00]], i64 0
+  ; CHECK-NEXT: insertelement <2 x float> %{{.*}}, float [[ev01]], i64 1
+  ; CHECK-NEXT: extractvalue { float, i32 } [[ie0]], 1
+  ; CHECK-NEXT: extractvalue { float, i32 } [[ie1]], 1
+  %1 =  call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %f)
+  %2 = extractvalue { <2 x float>, <2 x i32> } %1, 0
+  %3 = extractvalue { <2 x float>, <2 x i32> } %1, 1
+  ret <2 x float> %2
+}
+
+define noundef <2 x double> @test_vector_double_frexp_int(<2 x double> noundef %d) {
+  ; CHECK: [[ee0:%.*]] = extractelement <2 x double> %d, i64 0
+  ; CHECK-NEXT: [[ie0:%.*]] = call { double, i32 } @llvm.frexp.f64.i32(double [[ee0]])
+  ; CHECK-NEXT: [[ee1:%.*]] = extractelement <2 x double> %d, i64 1
+  ; CHECK-NEXT: [[ie1:%.*]] = call { double, i32 } @llvm.frexp.f64.i32(double [[ee1]])
+  ; CHECK-NEXT: [[ev00:%.*]] = extractvalue { double, i32 } [[ie0]], 0
+  ; CHECK-NEXT: [[ev01:%.*]] = extractvalue { double, i32 } [[ie1]], 0
+  ; CHECK-NEXT: insertelement <2 x double> poison, double [[ev00]], i64 0
+  ; CHECK-NEXT: insertelement <2 x double> %{{.*}}, double [[ev01]], i64 1
+  ; CHECK-NEXT: extractvalue { double, i32 } [[ie0]], 1
+  ; CHECK-NEXT: extractvalue { double, i32 } [[ie1]], 1
+  %1 =  call { <2 x double>, <2 x i32> } @llvm.frexp.v2f64.v2i32(<2 x double> %d)
+  %2 = extractvalue { <2 x double>, <2 x i32> } %1, 0
+  %3 = extractvalue { <2 x double>, <2 x i32> } %1, 1
+  ret <2 x double> %2
+}

>From b15c403659b67773fd86174a4463f4fc17cf05f0 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.org>
Date: Tue, 8 Oct 2024 03:44:57 -0400
Subject: [PATCH 3/3] fix up in prep for PR.

---
 llvm/lib/Transforms/Scalar/Scalarizer.cpp | 41 +++++++++++++----------
 llvm/test/CodeGen/DirectX/split-double.ll | 14 ++++++--
 llvm/test/Transforms/Scalarizer/frexp.ll  |  3 +-
 3 files changed, 37 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index f6a7230a472de5..d8b052061c1ad5 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -200,11 +200,17 @@ struct VectorLayout {
 static bool isStructAllVectors(Type *Ty) {
   if (!isa<StructType>(Ty))
     return false;
-
-  for(unsigned I = 0; I < Ty->getNumContainedTypes(); I++)
-    if (!isa<FixedVectorType>(Ty->getContainedType(I)))
+  if (Ty->getNumContainedTypes() < 1)
+    return false;
+  FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(0));
+  if (!VecTy)
+    return false;
+  unsigned VecSize = VecTy->getNumElements();
+  for (unsigned I = 1; I < Ty->getNumContainedTypes(); I++) {
+    VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(I));
+    if (!VecTy || VecSize != VecTy->getNumElements())
       return false;
-
+  }
   return true;
 }
 
@@ -679,8 +685,9 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
 bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
   if (isTriviallyVectorizable(ID))
     return true;
+  // TODO: investigate vectorizable frexp
   switch (ID) {
-    case Intrinsic::frexp:
+  case Intrinsic::frexp:
     return true;
   }
   return Intrinsic::isTargetIntrinsic(ID) &&
@@ -690,10 +697,10 @@ bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
 /// If a call to a vector typed intrinsic function, split into a scalar call per
 /// element if possible for the intrinsic.
 bool ScalarizerVisitor::splitCall(CallInst &CI) {
-  Type* CallType = CI.getType();
-  bool areAllVectors = isStructAllVectors(CallType);
-   std::optional<VectorSplit> VS;
-  if (areAllVectors)
+  Type *CallType = CI.getType();
+  bool AreAllVectors = isStructAllVectors(CallType);
+  std::optional<VectorSplit> VS;
+  if (AreAllVectors)
     VS = getVectorSplit(CallType->getContainedType(0));
   else
     VS = getVectorSplit(CallType);
@@ -721,18 +728,23 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
   if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
     Tys.push_back(VS->SplitTy);
 
-  if(areAllVectors) {
-    Type* PrevType = CallType->getContainedType(0);
-    Type* CallType = CI.getType();
-    for(unsigned I = 1; I < CallType->getNumContainedTypes(); I++) {
-      Type* CurrType = cast<FixedVectorType>(CallType->getContainedType(I));
-      if(PrevType != CurrType) {
-        std::optional<VectorSplit> CurrVS = getVectorSplit(CurrType);
-        Tys.push_back(CurrVS->SplitTy);
-        PrevType = CurrType;
-      }
-    }
-  }
+  if (AreAllVectors) {
+    // Every remaining struct field must split exactly like field 0,
+    // otherwise the fields' fragments would not line up; in that case the
+    // call is left unscalarized.
+    for (unsigned I = 1; I < CallType->getNumContainedTypes(); I++) {
+      Type *CurrType = CallType->getContainedType(I);
+      std::optional<VectorSplit> CurrVS = getVectorSplit(CurrType);
+      if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
+        return false;
+      // A field whose type differs from the preceding field is treated as an
+      // additional overloaded type. This holds for llvm.frexp (mantissa and
+      // exponent types differ) and llvm.dx.splitdouble (both fields share one
+      // overload).
+      if (CurrType != CallType->getContainedType(I - 1))
+        Tys.push_back(CurrVS->SplitTy);
+    }
+  }
   // Assumes that any vector type has the same number of elements as the return
   // vector type, which is true for all current intrinsics.
   for (unsigned I = 0; I != NumArgs; ++I) {
@@ -1069,21 +1076,52 @@ bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
   ValueVector Res;
   if (!isStructAllVectors(OpTy))
     return false;
-  Type* VecType = cast<FixedVectorType>(OpTy->getContainedType(0));
+  // Only extracts from intrinsic calls that splitCall() can scalarize are
+  // rewritten; any other struct-typed operand (argument, load, phi, ...)
+  // has no per-fragment values to extract from.
+  auto *CI = dyn_cast<CallInst>(Op);
+  Function *F = CI ? CI->getCalledFunction() : nullptr;
+  if (!F)
+    return false;
+  Intrinsic::ID ID = F->getIntrinsicID();
+  if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID))
+    return false;
+  Type *VecType = cast<FixedVectorType>(OpTy->getContainedType(0));
   std::optional<VectorSplit> VS = getVectorSplit(VecType);
   if (!VS)
     return false;
+  // Only proceed when every struct field and every vector operand splits
+  // with the same NumPacked as field 0; otherwise the fragments produced
+  // here would not line up with the fragments of the scalarized call.
+  for (Type *FieldTy : cast<StructType>(OpTy)->elements()) {
+    std::optional<VectorSplit> FieldVS = getVectorSplit(FieldTy);
+    if (!FieldVS || FieldVS->NumPacked != VS->NumPacked)
+      return false;
+  }
+  for (Value *Arg : CI->args()) {
+    if (!isa<FixedVectorType>(Arg->getType()))
+      continue;
+    std::optional<VectorSplit> ArgVS = getVectorSplit(Arg->getType());
+    if (!ArgVS || ArgVS->NumPacked != VS->NumPacked)
+      return false;
+  }
+  // The extracted field can have a different vector type than field 0 (for
+  // llvm.frexp, <N x half> vs. <N x i32>), so gather the result using the
+  // split of its own type rather than *VS.
+  std::optional<VectorSplit> ResVS = getVectorSplit(EVI.getType());
+  if (!ResVS)
+    return false;
   IRBuilder<> Builder(&EVI);
   Scatterer Op0 = scatter(&EVI, Op, *VS);
   assert(!EVI.getIndices().empty() && "Make sure an index exists");
   // Note for our use case we only care about the top level index.
   unsigned Index = EVI.getIndices()[0];
   for (unsigned OpIdx = 0; OpIdx < Op0.size(); ++OpIdx) {
     Value *ResElem = Builder.CreateExtractValue(
         Op0[OpIdx], Index, EVI.getName() + ".elem" + std::to_string(Index));
     Res.push_back(ResElem);
   }
-  // replaceUses(&EVI, Res);
-  gather(&EVI, Res, *VS);
+
+  gather(&EVI, Res, *ResVS);
   return true;
 }
diff --git a/llvm/test/CodeGen/DirectX/split-double.ll b/llvm/test/CodeGen/DirectX/split-double.ll
index 4fc5fdd1922a2c..9b70e87ba4794e 100644
--- a/llvm/test/CodeGen/DirectX/split-double.ll
+++ b/llvm/test/CodeGen/DirectX/split-double.ll
@@ -1,10 +1,18 @@
-; RUN: opt -S -scalarizer  -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: opt -passes='function(scalarizer<load-store>)' -S -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
-define void @test_vector_double_split_void(<3 x double> noundef %d) {
-  %hlsl.asuint = call { <3 x i32>, <3 x i32> }  @llvm.dx.splitdouble.v3i32(<3 x double> %d)
+; CHECK-LABEL: @test_vector_double_split_void
+define void @test_vector_double_split_void(<2 x double> noundef %d) {
+  ; CHECK: [[ee0:%.*]] = extractelement <2 x double> %d, i64 0
+  ; CHECK: [[ie0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee0]])
+  ; CHECK: [[ee1:%.*]] = extractelement <2 x double> %d, i64 1
+  ; CHECK: [[ie1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee1]])
+  ; CHECK-NOT: extractvalue { i32, i32 } {{.*}}, 0
+  ; CHECK-NOT: insertelement <2 x i32> {{.*}}, i32 {{.*}}, i64 0
+  %hlsl.asuint = call { <2 x i32>, <2 x i32> }  @llvm.dx.splitdouble.v2i32(<2 x double> %d)
   ret void
 }
 
+; CHECK-LABEL: @test_vector_double_split
 define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %d) {
   ; CHECK: [[ee0:%.*]] = extractelement <3 x double> %d, i64 0
   ; CHECK: [[ie0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[ee0]])
diff --git a/llvm/test/Transforms/Scalarizer/frexp.ll b/llvm/test/Transforms/Scalarizer/frexp.ll
index 454042e6887c3a..48159b45c18960 100644
--- a/llvm/test/Transforms/Scalarizer/frexp.ll
+++ b/llvm/test/Transforms/Scalarizer/frexp.ll
@@ -30,7 +30,7 @@ define noundef <2 x i32> @test_vector_half_frexp_int(<2 x half> noundef %h) {
   ret <2 x i32> %e1
 }
 
-
+; CHECK-LABEL: @test_vector_float_frexp_int
 define noundef <2 x float> @test_vector_float_frexp_int(<2 x float> noundef %f) {
   ; CHECK: [[ee0:%.*]] = extractelement <2 x float> %f, i64 0
   ; CHECK-NEXT: [[ie0:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[ee0]])
@@ -48,6 +48,7 @@ define noundef <2 x float> @test_vector_float_frexp_int(<2 x float> noundef %f)
   ret <2 x float> %2
 }
 
+; CHECK-LABEL: @test_vector_double_frexp_int
 define noundef <2 x double> @test_vector_double_frexp_int(<2 x double> noundef %d) {
   ; CHECK: [[ee0:%.*]] = extractelement <2 x double> %d, i64 0
   ; CHECK-NEXT: [[ie0:%.*]] = call { double, i32 } @llvm.frexp.f64.i32(double [[ee0]])



More information about the llvm-commits mailing list