[llvm] Complex deinterleaving/single reductions build fix: Reapply "Add support for single reductions in ComplexDeinterleavingPass (#112875)" (PR #120441)

via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 18 07:29:55 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Nicholas Guy (NickGuy-Arm)

<details>
<summary>Changes</summary>

This reverts commit 76714be5fd4ace66dd9e19ce706c2e2149dd5716 (itself a revert), reapplying the original change together with a fix for the build failure that caused that revert.

---

The patch is 121.96 KiB and is truncated to 20.00 KiB below; the full version is available at: https://github.com/llvm/llvm-project/pull/120441.diff


4 Files Affected:

- (modified) llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h (+2) 
- (modified) llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp (+271-14) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+32-11) 
- (added) llvm/test/CodeGen/AArch64/complex-deinterleaving-cdot.ll (+1136) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
index 84a2673fecb5bf..4383249658e606 100644
--- a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
+++ b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
@@ -35,6 +35,7 @@ struct ComplexDeinterleavingPass
 enum class ComplexDeinterleavingOperation {
   CAdd,
   CMulPartial,
+  CDot,
   // The following 'operations' are used to represent internal states. Backends
   // are not expected to try and support these in any capacity.
   Deinterleave,
@@ -43,6 +44,7 @@ enum class ComplexDeinterleavingOperation {
   ReductionPHI,
   ReductionOperation,
   ReductionSelect,
+  ReductionSingle
 };
 
 enum class ComplexDeinterleavingRotation {
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index f3f7ea9407b46f..603782a1b2f937 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -108,6 +108,13 @@ static bool isNeg(Value *V);
 static Value *getNegOperand(Value *V);
 
 namespace {
+template <typename T, typename IterT>
+std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
+  auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); });
+  if (Common != A.end())
+    return std::make_optional(*Common);
+  return std::nullopt;
+}
 
 class ComplexDeinterleavingLegacyPass : public FunctionPass {
 public:
@@ -144,6 +151,7 @@ struct ComplexDeinterleavingCompositeNode {
   friend class ComplexDeinterleavingGraph;
   using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
   using RawNodePtr = ComplexDeinterleavingCompositeNode *;
+  bool OperandsValid = true;
 
 public:
   ComplexDeinterleavingOperation Operation;
@@ -160,7 +168,11 @@ struct ComplexDeinterleavingCompositeNode {
   SmallVector<RawNodePtr> Operands;
   Value *ReplacementNode = nullptr;
 
-  void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
+  void addOperand(NodePtr Node) {
+    if (!Node || !Node.get())
+      OperandsValid = false;
+    Operands.push_back(Node.get());
+  }
 
   void dump() { dump(dbgs()); }
   void dump(raw_ostream &OS) {
@@ -194,6 +206,8 @@ struct ComplexDeinterleavingCompositeNode {
       PrintNodeRef(Op);
     }
   }
+
+  bool areOperandsValid() { return OperandsValid; }
 };
 
 class ComplexDeinterleavingGraph {
@@ -293,7 +307,7 @@ class ComplexDeinterleavingGraph {
 
   NodePtr submitCompositeNode(NodePtr Node) {
     CompositeNodes.push_back(Node);
-    if (Node->Real && Node->Imag)
+    if (Node->Real)
       CachedResult[{Node->Real, Node->Imag}] = Node;
     return Node;
   }
@@ -327,6 +341,8 @@ class ComplexDeinterleavingGraph {
   ///      i: ai - br
   NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
   NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
+  NodePtr identifyPartialReduction(Value *R, Value *I);
+  NodePtr identifyDotProduct(Value *Inst);
 
   NodePtr identifyNode(Value *R, Value *I);
 
@@ -396,6 +412,7 @@ class ComplexDeinterleavingGraph {
   /// * Deinterleave the final value outside of the loop and repurpose original
   /// reduction users
   void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
+  void processReductionSingle(Value *OperationReplacement, RawNodePtr Node);
 
 public:
   void dump() { dump(dbgs()); }
@@ -891,17 +908,163 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
 }
 
 ComplexDeinterleavingGraph::NodePtr
-ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
-  LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
-  assert(R->getType() == I->getType() &&
-         "Real and imaginary parts should not have different types");
+ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
+
+  if (!TL->isComplexDeinterleavingOperationSupported(
+          ComplexDeinterleavingOperation::CDot, V->getType())) {
+    LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
+                         "operation CDot with the type "
+                      << *V->getType() << "\n");
+    return nullptr;
+  }
+
+  auto *Inst = cast<Instruction>(V);
+  auto *RealUser = cast<Instruction>(*Inst->user_begin());
+
+  NodePtr CN =
+      prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);
+
+  NodePtr ANode;
+
+  const Intrinsic::ID PartialReduceInt =
+      Intrinsic::experimental_vector_partial_reduce_add;
+
+  Value *AReal = nullptr;
+  Value *AImag = nullptr;
+  Value *BReal = nullptr;
+  Value *BImag = nullptr;
+  Value *Phi = nullptr;
+
+  auto UnwrapCast = [](Value *V) -> Value * {
+    if (auto *CI = dyn_cast<CastInst>(V))
+      return CI->getOperand(0);
+    return V;
+  };
+
+  auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
+      m_Intrinsic<PartialReduceInt>(m_Value(Phi),
+                                    m_Mul(m_Value(BReal), m_Value(AReal))),
+      m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));
+
+  auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
+      m_Intrinsic<PartialReduceInt>(
+          m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
+      m_Mul(m_Value(BImag), m_Value(AReal)));
+
+  if (match(Inst, PatternRot0)) {
+    CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
+  } else if (match(Inst, PatternRot270)) {
+    CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
+  } else {
+    Value *A0, *A1;
+    // The rotations 90 and 180 share the same operation pattern, so inspect the
+    // order of the operands, identifying where the real and imaginary
+    // components of A go, to discern between the aforementioned rotations.
+    auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
+        m_Intrinsic<PartialReduceInt>(m_Value(Phi),
+                                      m_Mul(m_Value(BReal), m_Value(A0))),
+        m_Mul(m_Value(BImag), m_Value(A1)));
+
+    if (!match(Inst, PatternRot90Rot180))
+      return nullptr;
+
+    A0 = UnwrapCast(A0);
+    A1 = UnwrapCast(A1);
+
+    // Test if A0 is real/A1 is imag
+    ANode = identifyNode(A0, A1);
+    if (!ANode) {
+      // Test if A0 is imag/A1 is real
+      ANode = identifyNode(A1, A0);
+      // Unable to identify operand components, thus unable to identify rotation
+      if (!ANode)
+        return nullptr;
+      CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
+      AReal = A1;
+      AImag = A0;
+    } else {
+      AReal = A0;
+      AImag = A1;
+      CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
+    }
+  }
+
+  AReal = UnwrapCast(AReal);
+  AImag = UnwrapCast(AImag);
+  BReal = UnwrapCast(BReal);
+  BImag = UnwrapCast(BImag);
+
+  VectorType *VTy = cast<VectorType>(V->getType());
+  Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
+  if (AReal->getType() != ExpectedOperandTy)
+    return nullptr;
+  if (AImag->getType() != ExpectedOperandTy)
+    return nullptr;
+  if (BReal->getType() != ExpectedOperandTy)
+    return nullptr;
+  if (BImag->getType() != ExpectedOperandTy)
+    return nullptr;
+
+  if (Phi->getType() != VTy && RealUser->getType() != VTy)
+    return nullptr;
+
+  NodePtr Node = identifyNode(AReal, AImag);
+
+  // In the case that a node was identified to figure out the rotation, ensure
+  // that trying to identify a node with AReal and AImag post-unwrap results in
+  // the same node
+  if (ANode && Node != ANode) {
+    LLVM_DEBUG(
+        dbgs()
+        << "Identified node is different from previously identified node. "
+           "Unable to confidently generate a complex operation node\n");
+    return nullptr;
+  }
+
+  CN->addOperand(Node);
+  CN->addOperand(identifyNode(BReal, BImag));
+  CN->addOperand(identifyNode(Phi, RealUser));
+
+  return submitCompositeNode(CN);
+}
+
+ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
+  // Partial reductions don't support non-vector types, so check these first
+  if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType()))
+    return nullptr;
+
+  auto CommonUser =
+      findCommonBetweenCollections<Value *>(R->users(), I->users());
+  if (!CommonUser)
+    return nullptr;
+
+  auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser);
+  if (!IInst || IInst->getIntrinsicID() !=
+                    Intrinsic::experimental_vector_partial_reduce_add)
+    return nullptr;
+
+  if (NodePtr CN = identifyDotProduct(IInst))
+    return CN;
+
+  return nullptr;
+}
 
+ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
   auto It = CachedResult.find({R, I});
   if (It != CachedResult.end()) {
     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
     return It->second;
   }
 
+  if (NodePtr CN = identifyPartialReduction(R, I))
+    return CN;
+
+  bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
+  if (!IsReduction && R->getType() != I->getType())
+    return nullptr;
+
   if (NodePtr CN = identifySplat(R, I))
     return CN;
 
@@ -1427,12 +1590,20 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
   if (It != RootToNode.end()) {
     auto RootNode = It->second;
     assert(RootNode->Operation ==
-           ComplexDeinterleavingOperation::ReductionOperation);
+               ComplexDeinterleavingOperation::ReductionOperation ||
+           RootNode->Operation ==
+               ComplexDeinterleavingOperation::ReductionSingle);
     // Find out which part, Real or Imag, comes later, and only if we come to
     // the latest part, add it to OrderedRoots.
     auto *R = cast<Instruction>(RootNode->Real);
-    auto *I = cast<Instruction>(RootNode->Imag);
-    auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
+    auto *I = RootNode->Imag ? cast<Instruction>(RootNode->Imag) : nullptr;
+
+    Instruction *ReplacementAnchor;
+    if (I)
+      ReplacementAnchor = R->comesBefore(I) ? I : R;
+    else
+      ReplacementAnchor = R;
+
     if (ReplacementAnchor != RootI)
       return false;
     OrderedRoots.push_back(RootI);
@@ -1523,7 +1694,6 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
     for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
       if (Processed[j])
         continue;
-
       auto *Real = OperationInstruction[i];
       auto *Imag = OperationInstruction[j];
       if (Real->getType() != Imag->getType())
@@ -1556,6 +1726,28 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
         break;
       }
     }
+
+    auto *Real = OperationInstruction[i];
+    // We want to check that we have 2 operands, but the function attributes
+    // being counted as operands bloats this value.
+    if (Real->getNumOperands() < 2)
+      continue;
+
+    RealPHI = ReductionInfo[Real].first;
+    ImagPHI = nullptr;
+    PHIsFound = false;
+    auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
+    if (Node && PHIsFound) {
+      LLVM_DEBUG(
+          dbgs() << "Identified single reduction starting from instruction: "
+                 << *Real << "/" << *ReductionInfo[Real].second << "\n");
+      Processed[i] = true;
+      auto RootNode = prepareCompositeNode(
+          ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
+      RootNode->addOperand(Node);
+      RootToNode[Real] = RootNode;
+      submitCompositeNode(RootNode);
+    }
   }
 
   RealPHI = nullptr;
@@ -1563,6 +1755,19 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
 }
 
 bool ComplexDeinterleavingGraph::checkNodes() {
+
+  bool FoundDeinterleaveNode = false;
+  for (NodePtr N : CompositeNodes) {
+    if (!N->areOperandsValid())
+      return false;
+    if (N->Operation == ComplexDeinterleavingOperation::Deinterleave)
+      FoundDeinterleaveNode = true;
+  }
+
+  // We need a deinterleave node in order to guarantee that we're working with complex numbers.
+  if (!FoundDeinterleaveNode)
+    return false;
+
   // Collect all instructions from roots to leaves
   SmallPtrSet<Instruction *, 16> AllInstructions;
   SmallVector<Instruction *, 8> Worklist;
@@ -1831,7 +2036,7 @@ ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
 ComplexDeinterleavingGraph::NodePtr
 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
                                             Instruction *Imag) {
-  if (Real != RealPHI || Imag != ImagPHI)
+  if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
     return nullptr;
 
   PHIsFound = true;
@@ -1926,6 +2131,16 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
 
   Value *ReplacementNode;
   switch (Node->Operation) {
+  case ComplexDeinterleavingOperation::CDot: {
+    Value *Input0 = ReplaceOperandIfExist(Node, 0);
+    Value *Input1 = ReplaceOperandIfExist(Node, 1);
+    Value *Accumulator = ReplaceOperandIfExist(Node, 2);
+    assert(!Input1 || (Input0->getType() == Input1->getType() &&
+                       "Node inputs need to be of the same type"));
+    ReplacementNode = TL->createComplexDeinterleavingIR(
+        Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
+    break;
+  }
   case ComplexDeinterleavingOperation::CAdd:
   case ComplexDeinterleavingOperation::CMulPartial:
   case ComplexDeinterleavingOperation::Symmetric: {
@@ -1969,13 +2184,18 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
   case ComplexDeinterleavingOperation::ReductionPHI: {
     // If Operation is ReductionPHI, a new empty PHINode is created.
     // It is filled later when the ReductionOperation is processed.
+    auto *OldPHI = cast<PHINode>(Node->Real);
     auto *VTy = cast<VectorType>(Node->Real->getType());
     auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
     auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
-    OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
+    OldToNewPHI[OldPHI] = NewPHI;
     ReplacementNode = NewPHI;
     break;
   }
+  case ComplexDeinterleavingOperation::ReductionSingle:
+    ReplacementNode = replaceNode(Builder, Node->Operands[0]);
+    processReductionSingle(ReplacementNode, Node);
+    break;
   case ComplexDeinterleavingOperation::ReductionOperation:
     ReplacementNode = replaceNode(Builder, Node->Operands[0]);
     processReductionOperation(ReplacementNode, Node);
@@ -2000,6 +2220,38 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
   return ReplacementNode;
 }
 
+void ComplexDeinterleavingGraph::processReductionSingle(
+    Value *OperationReplacement, RawNodePtr Node) {
+  auto *Real = cast<Instruction>(Node->Real);
+  auto *OldPHI = ReductionInfo[Real].first;
+  auto *NewPHI = OldToNewPHI[OldPHI];
+  auto *VTy = cast<VectorType>(Real->getType());
+  auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
+
+  Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
+
+  IRBuilder<> Builder(Incoming->getTerminator());
+
+  Value *NewInit = nullptr;
+  if (auto *C = dyn_cast<Constant>(Init)) {
+    if (C->isZeroValue())
+      NewInit = Constant::getNullValue(NewVTy);
+  }
+
+  if (!NewInit)
+    NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
+                                      {Init, Constant::getNullValue(VTy)});
+
+  NewPHI->addIncoming(NewInit, Incoming);
+  NewPHI->addIncoming(OperationReplacement, BackEdge);
+
+  auto *FinalReduction = ReductionInfo[Real].second;
+  Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
+
+  auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
+  FinalReduction->replaceAllUsesWith(AddReduce);
+}
+
 void ComplexDeinterleavingGraph::processReductionOperation(
     Value *OperationReplacement, RawNodePtr Node) {
   auto *Real = cast<Instruction>(Node->Real);
@@ -2059,8 +2311,13 @@ void ComplexDeinterleavingGraph::replaceNodes() {
       auto *RootImag = cast<Instruction>(RootNode->Imag);
       ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
       ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
-      DeadInstrRoots.push_back(cast<Instruction>(RootReal));
-      DeadInstrRoots.push_back(cast<Instruction>(RootImag));
+      DeadInstrRoots.push_back(RootReal);
+      DeadInstrRoots.push_back(RootImag);
+    } else if (RootNode->Operation ==
+               ComplexDeinterleavingOperation::ReductionSingle) {
+      auto *RootInst = cast<Instruction>(RootNode->Real);
+      ReductionInfo[RootInst].first->removeIncomingValue(BackEdge);
+      DeadInstrRoots.push_back(ReductionInfo[RootInst].second);
     } else {
       assert(R && "Unable to find replacement for RootInstruction");
       DeadInstrRoots.push_back(RootInstruction);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index cb6ba06bd4425c..d45c3cddd64de4 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -29542,9 +29542,16 @@ bool AArch64TargetLowering::isComplexDeinterleavingOperationSupported(
 
   if (ScalarTy->isIntegerTy() && Subtarget->hasSVE2() && VTy->isScalableTy()) {
     unsigned ScalarWidth = ScalarTy->getScalarSizeInBits();
+
+    if (Operation == ComplexDeinterleavingOperation::CDot)
+      return ScalarWidth == 32 || ScalarWidth == 64;
     return 8 <= ScalarWidth && ScalarWidth <= 64;
   }
 
+  // CDot is not supported outside of scalable/sve scopes
+  if (Operation == ComplexDeinterleavingOperation::CDot)
+    return false;
+
   return (ScalarTy->isHalfTy() && Subtarget->hasFullFP16()) ||
          ScalarTy->isFloatTy() || ScalarTy->isDoubleTy();
 }
@@ -29554,6 +29561,8 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
     ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB,
     Value *Accumulator) const {
   VectorType *Ty = cast<VectorType>(InputA->getType());
+  if (Accumulator == nullptr)
+    Accumulator = Constant::getNullValue(Ty);
   bool IsScalable = Ty->isScalableTy();
   bool IsInt = Ty->getElementType()->isIntegerTy();
 
@@ -29565,6 +29574,10 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
 
   if (TyWidth > 128) {
     int Stride = Ty->getElementCount().getKnownMinValue() / 2;
+    int AccStride = cast<VectorType>(Accumulator->getType())
+                        ->getElementCount()
+                        .getKnownMinValue() /
+                    2;
     auto *HalfTy = VectorType::getHalfElementsVectorType(Ty);
     auto *LowerSplitA = B.CreateExtractVector(HalfTy, InputA, B.getInt64(0));
     auto *LowerSplitB = B.CreateExtractVector(HalfTy, InputB, B.getInt64(0));
@@ -29574,25 +29587,26 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
         B.CreateExtractVector(HalfTy, InputB, B.getInt64(Stride));
     Value *LowerSplitAcc = nullptr;
     Value *UpperSplitAcc = nullptr;
-    if (Accumulator) {
-      LowerSplitAcc = B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(0));
-      UpperSplitAcc =
-          B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(Stride));
-    }
+    Type *FullTy = Ty;
+    FullTy = Accumulator->getType();
+    auto *HalfAccTy = VectorType::getHalfElementsVectorType(
+        cast<VectorType>(Accumulator->getType()));
+    LowerSplitAcc =
+        B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(0));
+    UpperSplitAcc =
+        B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(AccStride));
     auto *LowerSplitInt = createComplexDeinterleavingIR(
         B, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc);
     auto *UpperSplitInt = createComplexDeinterleavingIR(
         B, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc);
 
-    auto *Result = B.CreateInsertVector(Ty, PoisonValue::get(Ty), LowerSplitInt,
-                                        B.getInt64(0));
-    return B.CreateInsertVector(Ty, Result, UpperSplitInt, B.getInt64(Stride));
+    auto *Result = B.CreateInsertVector(FullTy, PoisonValue::get(FullTy),
+                                        LowerSplitInt, B.getInt64(0));
+    return B.CreateInsertVector(FullTy, Result, UpperSplitInt,
+                                B.getInt64(AccStride));
   }
 
   if (OperationType == ComplexDeinterleavingOperation::CMulPartial) {
-    if (Accumulator == nullptr)
-      Accumulator = Constant::getNullValue(Ty);
-
     if (IsScalable...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/120441


More information about the llvm-commits mailing list