[llvm] [CodeGen] Expansion of scalable vector reductions (PR #129214)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 6 01:27:23 PST 2025


https://github.com/iamlouk updated https://github.com/llvm/llvm-project/pull/129214

From 6cfeeeb9b4c13de92813c0e24de9c6132e57fa74 Mon Sep 17 00:00:00 2001
From: Lou Knauer <lou.knauer at sipearl.com>
Date: Fri, 28 Feb 2025 12:30:59 +0100
Subject: [PATCH 1/2] [CodeGen] Expansion of scalable vector reductions

Add support for the expansion of reductions of scalable vectors in the
ExpandReductionsPass. This is motivated, for example, by the fact that
SVE does not have product/multiply reductions.

Two expansion techniques are implemented: a parallel, tree-like
reduction, used when re-association is allowed (for now, only if VScale
is a known power of two), and a sequential one. If vscale is a compile-
time constant (`vscale_range` function attribute), no loop is generated.

Note that the loop-vectorizer will not generate product reductions for
scalable vectors even with this patch, as the TTI still disallows this
and returns invalid costs. A follow-up PR could then allow product
reductions and return a high, but not invalid, cost.
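
For illustration, one halving step of the tree expansion of an fmul
reduction of a `<vscale x 4 x float>` vector looks like this (a minimal
IR sketch; the value names are made up, the exact generated code is in
the added test): the vector is inserted into a double-sized poison
vector, deinterleaved into two half-size vectors, and the two halves
are combined with the reduction binop.

  %wide = call <vscale x 8 x float> @llvm.vector.insert.nxv8f32.nxv4f32(<vscale x 8 x float> poison, <vscale x 4 x float> %vec, i64 0)
  %pair = call { <vscale x 4 x float>, <vscale x 4 x float> } @llvm.vector.deinterleave2.nxv8f32(<vscale x 8 x float> %wide)
  %lo = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } %pair, 0
  %hi = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } %pair, 1
  %rdx = fmul fast <vscale x 4 x float> %lo, %hi

After log2(vscale * 4) such steps (computed at runtime via cttz), lane 0
of %rdx holds the fully reduced value; when vscale is not a compile-time
constant, this step forms the body of the generated loop. The sequential
fallback instead extracts one lane per iteration and applies the scalar
binop.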
---
 llvm/lib/CodeGen/ExpandReductions.cpp         | 240 +++++++++++++++++-
 .../AArch64/AArch64TargetTransformInfo.h      |  11 +-
 llvm/lib/Transforms/Utils/LoopUtils.cpp       |  15 +-
 .../AArch64/expand-scalable-reductions.ll     | 138 ++++++++++
 4 files changed, 392 insertions(+), 12 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/expand-scalable-reductions.ll

diff --git a/llvm/lib/CodeGen/ExpandReductions.cpp b/llvm/lib/CodeGen/ExpandReductions.cpp
index d6778ec666cbe..797c83aea3309 100644
--- a/llvm/lib/CodeGen/ExpandReductions.cpp
+++ b/llvm/lib/CodeGen/ExpandReductions.cpp
@@ -12,22 +12,190 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/CodeGen/ExpandReductions.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
+#include <optional>
 
 using namespace llvm;
 
 namespace {
 
-bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
-  bool Changed = false;
+void updateDomTreeForScalableExpansion(DominatorTree *DT, BasicBlock *Preheader,
+                                       BasicBlock *Loop, BasicBlock *Exit) {
+  DT->addNewBlock(Loop, Preheader);
+  DT->changeImmediateDominator(Exit, Loop);
+  assert(DT->verify(DominatorTree::VerificationLevel::Fast));
+}
+
+/// Expand a reduction on a scalable vector into a loop
+/// that iterates over the elements one after the other.
+Value *expandScalableReduction(IRBuilderBase &Builder, IntrinsicInst *II,
+                               Value *Acc, Value *Vec,
+                               Instruction::BinaryOps BinOp,
+                               DominatorTree *DT) {
+  ScalableVectorType *VecTy = cast<ScalableVectorType>(Vec->getType());
+
+  // Split the original BB in two and create a new BB between them,
+  // which will be a loop.
+  BasicBlock *BeforeBB = II->getParent();
+  BasicBlock *AfterBB = SplitBlock(BeforeBB, II, DT);
+  BasicBlock *LoopBB = BasicBlock::Create(Builder.getContext(), "rdx.loop",
+                                          BeforeBB->getParent(), AfterBB);
+  BeforeBB->getTerminator()->setSuccessor(0, LoopBB);
+
+  // Calculate the number of elements in the vector:
+  Builder.SetInsertPoint(BeforeBB->getTerminator());
+  Value *NumElts =
+      Builder.CreateVScale(Builder.getInt64(VecTy->getMinNumElements()));
+
+  // Create two PHIs, one for the index of the current lane and one for
+  // the reduction.
+  Builder.SetInsertPoint(LoopBB);
+  PHINode *IV = Builder.CreatePHI(Builder.getInt64Ty(), 2, "index");
+  IV->addIncoming(Builder.getInt64(0), BeforeBB);
+  PHINode *RdxPhi = Builder.CreatePHI(VecTy->getScalarType(), 2, "rdx.phi");
+  RdxPhi->addIncoming(Acc, BeforeBB);
+
+  Value *IVInc =
+      Builder.CreateAdd(IV, Builder.getInt64(1), "index.next", true, true);
+  IV->addIncoming(IVInc, LoopBB);
+
+  // Extract the value at the current lane from the vector and perform
+  // the scalar reduction binop:
+  Value *Lane = Builder.CreateExtractElement(Vec, IV, "elm");
+  Value *Rdx = Builder.CreateBinOp(BinOp, RdxPhi, Lane, "rdx");
+  RdxPhi->addIncoming(Rdx, LoopBB);
+
+  // Exit when all lanes have been processed (assuming the vector has at
+  // least one element):
+  Value *Done = Builder.CreateCmp(CmpInst::ICMP_EQ, IVInc, NumElts, "exitcond");
+  Builder.CreateCondBr(Done, AfterBB, LoopBB);
+
+  if (DT)
+    updateDomTreeForScalableExpansion(DT, BeforeBB, LoopBB, AfterBB);
+
+  return Rdx;
+}
+
+/// Expand a reduction on a scalable vector in a parallel, tree-like
+/// manner, halving the number of elements to process in every
+/// iteration.
+Value *expandScalableTreeReduction(
+    IRBuilderBase &Builder, IntrinsicInst *II, std::optional<Value *> Acc,
+    Value *Vec, Instruction::BinaryOps BinOp,
+    function_ref<bool(Constant *)> IsNeutralElement, DominatorTree *DT,
+    std::optional<unsigned> FixedVScale) {
+  ScalableVectorType *VecTy = cast<ScalableVectorType>(Vec->getType());
+  ScalableVectorType *VecTyX2 = ScalableVectorType::get(
+      VecTy->getScalarType(), VecTy->getMinNumElements() * 2);
+
+  // If the VScale is fixed, do not generate a loop, and instead do
+  // something similar to llvm::getShuffleReduction(). That function
+  // cannot be used directly because it uses shuffle masks, which
+  // are not available for scalable vectors (even if vscale is fixed).
+  // The approach is effectively the same.
+  if (FixedVScale.has_value()) {
+    unsigned VF = VecTy->getMinNumElements() * FixedVScale.value();
+    assert(isPowerOf2_64(VF));
+    for (unsigned I = VF; I != 1; I >>= 1) {
+      Value *Extended = Builder.CreateInsertVector(
+          VecTyX2, PoisonValue::get(VecTyX2), Vec, Builder.getInt64(0));
+      Value *Pair = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
+                                            {VecTyX2}, {Extended});
+
+      Value *Vec1 = Builder.CreateExtractValue(Pair, {0});
+      Value *Vec2 = Builder.CreateExtractValue(Pair, {1});
+      Vec = Builder.CreateBinOp(BinOp, Vec1, Vec2, "rdx");
+    }
+    Value *FinalVal = Builder.CreateExtractElement(Vec, uint64_t(0));
+    if (Acc)
+      if (auto *C = dyn_cast<Constant>(*Acc); !C || !IsNeutralElement(C))
+        FinalVal = Builder.CreateBinOp(BinOp, *Acc, FinalVal, "rdx.final");
+    return FinalVal;
+  }
+
+  // Split the original BB in two and create a new BB between them,
+  // which will be a loop.
+  BasicBlock *BeforeBB = II->getParent();
+  BasicBlock *AfterBB = SplitBlock(BeforeBB, II, DT);
+  BasicBlock *LoopBB = BasicBlock::Create(Builder.getContext(), "rdx.loop",
+                                          BeforeBB->getParent(), AfterBB);
+  BeforeBB->getTerminator()->setSuccessor(0, LoopBB);
+
+  // This tree reduction only needs log2(N) iterations.
+  // Note: Calculating log2(N) using count-trailing-zeros (cttz) only works if
+  // `vscale` (and thus the vector size) is a power of two.
+  Builder.SetInsertPoint(BeforeBB->getTerminator());
+  Value *NumElts =
+      Builder.CreateVScale(Builder.getInt64(VecTy->getMinNumElements()));
+  Value *NumIters = Builder.CreateIntrinsic(NumElts->getType(), Intrinsic::cttz,
+                                            {NumElts, Builder.getTrue()});
+
+  // Create two PHIs, one for the IV and one for the reduction.
+  Builder.SetInsertPoint(LoopBB);
+  PHINode *IV = Builder.CreatePHI(Builder.getInt64Ty(), 2, "iter");
+  IV->addIncoming(Builder.getInt64(0), BeforeBB);
+  PHINode *VecPhi = Builder.CreatePHI(VecTy, 2, "rdx.phi");
+  VecPhi->addIncoming(Vec, BeforeBB);
+
+  Value *IVInc =
+      Builder.CreateAdd(IV, Builder.getInt64(1), "iter.next", true, true);
+  IV->addIncoming(IVInc, LoopBB);
+
+  // The deinterleave intrinsic takes a vector of, for example, type
+  // <vscale x 8 x float> and produces a pair of vectors with half the size,
+  // so 2 x <vscale x 4 x float>. An insert vector operation is used to
+  // create a double-sized vector where the upper half is poison, because
+  // that upper half is never used anyway.
+  Value *Extended = Builder.CreateInsertVector(
+      VecTyX2, PoisonValue::get(VecTyX2), VecPhi, Builder.getInt64(0));
+  Value *Pair = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
+                                        {VecTyX2}, {Extended});
+  Value *Vec1 = Builder.CreateExtractValue(Pair, {0});
+  Value *Vec2 = Builder.CreateExtractValue(Pair, {1});
+  Value *Rdx = Builder.CreateBinOp(BinOp, Vec1, Vec2, "rdx");
+  VecPhi->addIncoming(Rdx, LoopBB);
+
+  // Reduction-loop exit condition:
+  Value *Done =
+      Builder.CreateCmp(CmpInst::ICMP_EQ, IVInc, NumIters, "exitcond");
+  Builder.CreateCondBr(Done, AfterBB, LoopBB);
+  Builder.SetInsertPoint(AfterBB, AfterBB->getFirstInsertionPt());
+  Value *FinalVal = Builder.CreateExtractElement(Rdx, uint64_t(0));
+
+  // If the Acc value is not the neutral element of the reduction operation,
+  // then we need to do the binop one last time with the end result of the
+  // tree reduction.
+  if (Acc)
+    if (auto *C = dyn_cast<Constant>(*Acc); !C || !IsNeutralElement(C))
+      FinalVal = Builder.CreateBinOp(BinOp, *Acc, FinalVal, "rdx.final");
+
+  if (DT)
+    updateDomTreeForScalableExpansion(DT, BeforeBB, LoopBB, AfterBB);
+
+  return FinalVal;
+}
+
+std::pair<bool, bool> expandReductions(Function &F,
+                                       const TargetTransformInfo *TTI,
+                                       DominatorTree *DT) {
+  bool Changed = false, CFGChanged = false;
   SmallVector<IntrinsicInst *, 4> Worklist;
   for (auto &I : instructions(F)) {
     if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
@@ -54,6 +222,12 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
     }
   }
 
+  const auto &Attrs = F.getAttributes().getFnAttrs();
+  unsigned MinVScale = Attrs.getVScaleRangeMin();
+  std::optional<unsigned> FixedVScale = Attrs.getVScaleRangeMax();
+  if (FixedVScale != MinVScale)
+    FixedVScale = std::nullopt;
+
   for (auto *II : Worklist) {
     FastMathFlags FMF =
         isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
@@ -74,7 +248,34 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
       // and it can't be handled by generating a shuffle sequence.
       Value *Acc = II->getArgOperand(0);
       Value *Vec = II->getArgOperand(1);
-      unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
+      auto RdxOpcode =
+          Instruction::BinaryOps(getArithmeticReductionInstruction(ID));
+
+      bool ScalableTy = Vec->getType()->isScalableTy();
+      if (ScalableTy && (!FixedVScale || FMF.allowReassoc())) {
+        CFGChanged |= !FixedVScale;
+        assert(TTI->isVScaleKnownToBeAPowerOfTwo() &&
+               "Scalable tree reduction unimplemented for targets with a "
+               "VScale not known to be a power of 2.");
+        if (FMF.allowReassoc())
+          Rdx = expandScalableTreeReduction(
+              Builder, II, Acc, Vec, RdxOpcode,
+              [&](Constant *C) {
+                switch (ID) {
+                case Intrinsic::vector_reduce_fadd:
+                  return C->isZeroValue();
+                case Intrinsic::vector_reduce_fmul:
+                  return C->isOneValue();
+                default:
+                  llvm_unreachable("Binop not handled");
+                }
+              },
+              DT, FixedVScale);
+        else
+          Rdx = expandScalableReduction(Builder, II, Acc, Vec, RdxOpcode, DT);
+        break;
+      }
+
       if (!FMF.allowReassoc())
         Rdx = getOrderedReduction(Builder, Acc, Vec, RdxOpcode, RK);
       else {
@@ -125,10 +326,22 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
     case Intrinsic::vector_reduce_umax:
     case Intrinsic::vector_reduce_umin: {
       Value *Vec = II->getArgOperand(0);
+      unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
+      if (Vec->getType()->isScalableTy()) {
+        CFGChanged |= !FixedVScale;
+        assert(TTI->isVScaleKnownToBeAPowerOfTwo() &&
+               "Scalable tree reduction unimplemented for targets with a "
+               "VScale not known to be a power of 2.");
+        Rdx = expandScalableTreeReduction(
+            Builder, II, std::nullopt, Vec, Instruction::BinaryOps(RdxOpcode),
+            [](Constant *C) -> bool { llvm_unreachable("No accumulator!"); },
+            DT, FixedVScale);
+        break;
+      }
+
       if (!isPowerOf2_32(
               cast<FixedVectorType>(Vec->getType())->getNumElements()))
         continue;
-      unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
       Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
       break;
     }
@@ -150,7 +363,7 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
     II->eraseFromParent();
     Changed = true;
   }
-  return Changed;
+  return {CFGChanged, Changed};
 }
 
 class ExpandReductions : public FunctionPass {
@@ -161,13 +374,15 @@ class ExpandReductions : public FunctionPass {
   }
 
   bool runOnFunction(Function &F) override {
-    const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
-    return expandReductions(F, TTI);
+    const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+    auto *DTA = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
+    return expandReductions(F, TTI, DTA ? &DTA->getDomTree() : nullptr).second;
   }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<TargetTransformInfoWrapperPass>();
-    AU.setPreservesCFG();
+    AU.addUsedIfAvailable<DominatorTreeWrapperPass>();
+    AU.addPreserved<DominatorTreeWrapperPass>();
   }
 };
 }
@@ -186,9 +401,14 @@ FunctionPass *llvm::createExpandReductionsPass() {
 PreservedAnalyses ExpandReductionsPass::run(Function &F,
                                             FunctionAnalysisManager &AM) {
   const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
-  if (!expandReductions(F, &TTI))
+  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
+  auto [CFGChanged, Changed] = expandReductions(F, &TTI, DT);
+  if (!Changed)
     return PreservedAnalyses::all();
   PreservedAnalyses PA;
-  PA.preserveSet<CFGAnalyses>();
+  if (!CFGChanged)
+    PA.preserveSet<CFGAnalyses>();
+  else
+    PA.preserve<DominatorTreeAnalysis>();
   return PA;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 8a3fd11705640..503b93d0824ef 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -382,7 +382,16 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
   shouldConsiderAddressTypePromotion(const Instruction &I,
                                      bool &AllowPromotionWithoutCommonHeader);
 
-  bool shouldExpandReduction(const IntrinsicInst *II) const { return false; }
+  bool shouldExpandReduction(const IntrinsicInst *II) const {
+    switch (II->getIntrinsicID()) {
+    case Intrinsic::vector_reduce_mul:
+      return II->getOperand(0)->getType()->isScalableTy();
+    case Intrinsic::vector_reduce_fmul:
+      return II->getOperand(1)->getType()->isScalableTy();
+    default:
+      return false;
+    }
+  }
 
   unsigned getGISelRematGlobalCost() const {
     return 2;
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 42c70d2c163b5..73d1f9fa5642b 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1114,10 +1114,23 @@ Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
   return Select;
 }
 
+static unsigned getFixedVF(Function *F, Type *Ty) {
+  if (auto *Fixed = dyn_cast<FixedVectorType>(Ty))
+    return Fixed->getNumElements();
+
+  auto *ScalableTy = cast<ScalableVectorType>(Ty);
+  unsigned VScaleMin = F->getAttributes().getFnAttrs().getVScaleRangeMin();
+  assert(F->getAttributes().getFnAttrs().getVScaleRangeMax() == VScaleMin &&
+         "Expected a compile-time known VScale");
+
+  return ScalableTy->getMinNumElements() * VScaleMin;
+}
+
 // Helper to generate an ordered reduction.
 Value *llvm::getOrderedReduction(IRBuilderBase &Builder, Value *Acc, Value *Src,
                                  unsigned Op, RecurKind RdxKind) {
-  unsigned VF = cast<FixedVectorType>(Src->getType())->getNumElements();
+  unsigned VF =
+      getFixedVF(Builder.GetInsertBlock()->getParent(), Src->getType());
 
   // Extract and apply reduction ops in ascending order:
   // e.g. ((((Acc + Scl[0]) + Scl[1]) + Scl[2]) + ) ... + Scl[VF-1]
diff --git a/llvm/test/CodeGen/AArch64/expand-scalable-reductions.ll b/llvm/test/CodeGen/AArch64/expand-scalable-reductions.ll
new file mode 100644
index 0000000000000..9c732dc74d542
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/expand-scalable-reductions.ll
@@ -0,0 +1,138 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -expand-reductions -S | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+; The reduction has the reassoc fast-math flag, so it can be done in log2(VF) iterations.
+define float @test_reduce_fmul_tree_expansion(<vscale x 4 x float> %vec) #0 {
+; CHECK-LABEL: define float @test_reduce_fmul_tree_expansion(
+; CHECK-SAME: <vscale x 4 x float> [[VEC:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 [[TMP1]], 4
+; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.cttz.i64(i64 [[TMP2]], i1 true)
+; CHECK-NEXT:    br label %[[RDX_LOOP:.*]]
+; CHECK:       [[RDX_LOOP]]:
+; CHECK-NEXT:    [[ITER:%.*]] = phi i64 [ 0, [[TMP0:%.*]] ], [ [[ITER_NEXT:%.*]], %[[RDX_LOOP]] ]
+; CHECK-NEXT:    [[RDX_PHI:%.*]] = phi fast <vscale x 4 x float> [ [[VEC]], [[TMP0]] ], [ [[RDX:%.*]], %[[RDX_LOOP]] ]
+; CHECK-NEXT:    [[ITER_NEXT]] = add nuw nsw i64 [[ITER]], 1
+; CHECK-NEXT:    [[TMP4:%.*]] = call fast <vscale x 8 x float> @llvm.vector.insert.nxv8f32.nxv4f32(<vscale x 8 x float> poison, <vscale x 4 x float> [[RDX_PHI]], i64 0)
+; CHECK-NEXT:    [[TMP5:%.*]] = call fast { <vscale x 4 x float>, <vscale x 4 x float> } @llvm.vector.deinterleave2.nxv8f32(<vscale x 8 x float> [[TMP4]])
+; CHECK-NEXT:    [[TMP6:%.*]] = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } [[TMP5]], 0
+; CHECK-NEXT:    [[TMP7:%.*]] = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } [[TMP5]], 1
+; CHECK-NEXT:    [[RDX]] = fmul fast <vscale x 4 x float> [[TMP6]], [[TMP7]]
+; CHECK-NEXT:    [[EXITCOND:%.*]] = icmp eq i64 [[ITER_NEXT]], [[TMP3]]
+; CHECK-NEXT:    br i1 [[EXITCOND]], [[DOTSPLIT:label %.*]], label %[[RDX_LOOP]]
+; CHECK:       [[_SPLIT:.*:]]
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <vscale x 4 x float> [[RDX]], i64 0
+; CHECK-NEXT:    [[RDX_FINAL:%.*]] = fmul fast float 3.000000e+00, [[TMP8]]
+; CHECK-NEXT:    ret float [[RDX_FINAL]]
+;
+  %res = call fast float @llvm.vector.reduce.fmul.vnx4f32(float 3.0, <vscale x 4 x float> %vec)
+  ret float %res
+}
+
+; The reduction does not have the reassoc fast-math flag, so a sequential traversal is needed.
+define float @test_reduce_fmul_seq_expansion(<vscale x 4 x float> %vec) #0 {
+; CHECK-LABEL: define float @test_reduce_fmul_seq_expansion(
+; CHECK-SAME: <vscale x 4 x float> [[VEC:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 [[TMP1]], 4
+; CHECK-NEXT:    br label %[[RDX_LOOP:.*]]
+; CHECK:       [[RDX_LOOP]]:
+; CHECK-NEXT:    [[ITER:%.*]] = phi i64 [ 0, [[TMP0:%.*]] ], [ [[ITER_NEXT:%.*]], %[[RDX_LOOP]] ]
+; CHECK-NEXT:    [[RDX_PHI:%.*]] = phi float [ 3.000000e+00, [[TMP0]] ], [ [[RDX_FINAL:%.*]], %[[RDX_LOOP]] ]
+; CHECK-NEXT:    [[ITER_NEXT]] = add nuw nsw i64 [[ITER]], 1
+; CHECK-NEXT:    [[ELM:%.*]] = extractelement <vscale x 4 x float> [[VEC]], i64 [[ITER]]
+; CHECK-NEXT:    [[RDX_FINAL]] = fmul float [[RDX_PHI]], [[ELM]]
+; CHECK-NEXT:    [[EXITCOND:%.*]] = icmp eq i64 [[ITER_NEXT]], [[TMP2]]
+; CHECK-NEXT:    br i1 [[EXITCOND]], [[DOTSPLIT:label %.*]], label %[[RDX_LOOP]]
+; CHECK:       [[_SPLIT:.*:]]
+; CHECK-NEXT:    ret float [[RDX_FINAL]]
+;
+  %res = call float @llvm.vector.reduce.fmul.vnx4f32(float 3.0, <vscale x 4 x float> %vec)
+  ret float %res
+}
+
+; Similar to the first test, but for integers instead of floats,
+; which makes a difference because there is no accumulator argument.
+define i32 @test_reduce_int_mul_expansion(<vscale x 4 x i32> %vec) #0 {
+; CHECK-LABEL: define i32 @test_reduce_int_mul_expansion(
+; CHECK-SAME: <vscale x 4 x i32> [[VEC:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 [[TMP1]], 4
+; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.cttz.i64(i64 [[TMP2]], i1 true)
+; CHECK-NEXT:    br label %[[RDX_LOOP:.*]]
+; CHECK:       [[RDX_LOOP]]:
+; CHECK-NEXT:    [[ITER:%.*]] = phi i64 [ 0, [[TMP0:%.*]] ], [ [[ITER_NEXT:%.*]], %[[RDX_LOOP]] ]
+; CHECK-NEXT:    [[RDX_PHI:%.*]] = phi <vscale x 4 x i32> [ [[VEC]], [[TMP0]] ], [ [[RDX:%.*]], %[[RDX_LOOP]] ]
+; CHECK-NEXT:    [[ITER_NEXT]] = add nuw nsw i64 [[ITER]], 1
+; CHECK-NEXT:    [[TMP4:%.*]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> poison, <vscale x 4 x i32> [[RDX_PHI]], i64 0)
+; CHECK-NEXT:    [[TMP5:%.*]] = call { <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.vector.deinterleave2.nxv8i32(<vscale x 8 x i32> [[TMP4]])
+; CHECK-NEXT:    [[TMP6:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } [[TMP5]], 0
+; CHECK-NEXT:    [[TMP7:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } [[TMP5]], 1
+; CHECK-NEXT:    [[RDX]] = mul <vscale x 4 x i32> [[TMP6]], [[TMP7]]
+; CHECK-NEXT:    [[EXITCOND:%.*]] = icmp eq i64 [[ITER_NEXT]], [[TMP3]]
+; CHECK-NEXT:    br i1 [[EXITCOND]], [[DOTSPLIT:label %.*]], label %[[RDX_LOOP]]
+; CHECK:       [[_SPLIT:.*:]]
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <vscale x 4 x i32> [[RDX]], i64 0
+; CHECK-NEXT:    ret i32 [[TMP8]]
+;
+  %res = call i32 @llvm.vector.reduce.mul.vnx4i32(<vscale x 4 x i32> %vec)
+  ret i32 %res
+}
+
+; This function has the attribute `vscale_range(2,2)`, which means it can be
+; expanded just like a reduction on a fixed-size vector, without any loop.
+define float @test_fixed_vscale(<vscale x 4 x float> %vec) #1 {
+; CHECK-LABEL: define float @test_fixed_vscale(
+; CHECK-SAME: <vscale x 4 x float> [[VEC:%.*]]) #[[ATTR1:[0-9]+]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <vscale x 4 x float> [[VEC]], i32 0
+; CHECK-NEXT:    [[BIN_RDX:%.*]] = fmul float 3.000000e+00, [[TMP1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <vscale x 4 x float> [[VEC]], i32 1
+; CHECK-NEXT:    [[BIN_RDX1:%.*]] = fmul float [[BIN_RDX]], [[TMP2]]
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <vscale x 4 x float> [[VEC]], i32 2
+; CHECK-NEXT:    [[BIN_RDX2:%.*]] = fmul float [[BIN_RDX1]], [[TMP3]]
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <vscale x 4 x float> [[VEC]], i32 3
+; CHECK-NEXT:    [[BIN_RDX3:%.*]] = fmul float [[BIN_RDX2]], [[TMP4]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <vscale x 4 x float> [[VEC]], i32 4
+; CHECK-NEXT:    [[BIN_RDX4:%.*]] = fmul float [[BIN_RDX3]], [[TMP5]]
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <vscale x 4 x float> [[VEC]], i32 5
+; CHECK-NEXT:    [[BIN_RDX5:%.*]] = fmul float [[BIN_RDX4]], [[TMP6]]
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <vscale x 4 x float> [[VEC]], i32 6
+; CHECK-NEXT:    [[BIN_RDX6:%.*]] = fmul float [[BIN_RDX5]], [[TMP7]]
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <vscale x 4 x float> [[VEC]], i32 7
+; CHECK-NEXT:    [[BIN_RDX15:%.*]] = fmul float [[BIN_RDX6]], [[TMP8]]
+; CHECK-NEXT:    ret float [[BIN_RDX15]]
+;
+  %res = call float @llvm.vector.reduce.fmul.vnx4f32(float 3.0, <vscale x 4 x float> %vec)
+  ret float %res
+}
+
+define float @test_fixed_vscale_log2_reduction(<vscale x 4 x float> %vec) #1 {
+; CHECK-LABEL: define float @test_fixed_vscale_log2_reduction(
+; CHECK-SAME: <vscale x 4 x float> [[VEC:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 8 x float> @llvm.vector.insert.nxv8f32.nxv4f32(<vscale x 8 x float> poison, <vscale x 4 x float> [[VEC]], i64 0)
+; CHECK-NEXT:    [[TMP2:%.*]] = call fast { <vscale x 4 x float>, <vscale x 4 x float> } @llvm.vector.deinterleave2.nxv8f32(<vscale x 8 x float> [[TMP1]])
+; CHECK-NEXT:    [[TMP3:%.*]] = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } [[TMP2]], 0
+; CHECK-NEXT:    [[TMP4:%.*]] = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } [[TMP2]], 1
+; CHECK-NEXT:    [[RDX:%.*]] = fmul fast <vscale x 4 x float> [[TMP3]], [[TMP4]]
+; CHECK-NEXT:    [[TMP5:%.*]] = call fast <vscale x 8 x float> @llvm.vector.insert.nxv8f32.nxv4f32(<vscale x 8 x float> poison, <vscale x 4 x float> [[RDX]], i64 0)
+; CHECK-NEXT:    [[TMP6:%.*]] = call fast { <vscale x 4 x float>, <vscale x 4 x float> } @llvm.vector.deinterleave2.nxv8f32(<vscale x 8 x float> [[TMP5]])
+; CHECK-NEXT:    [[TMP7:%.*]] = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } [[TMP6]], 0
+; CHECK-NEXT:    [[TMP8:%.*]] = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } [[TMP6]], 1
+; CHECK-NEXT:    [[RDX1:%.*]] = fmul fast <vscale x 4 x float> [[TMP7]], [[TMP8]]
+; CHECK-NEXT:    [[TMP9:%.*]] = call fast <vscale x 8 x float> @llvm.vector.insert.nxv8f32.nxv4f32(<vscale x 8 x float> poison, <vscale x 4 x float> [[RDX1]], i64 0)
+; CHECK-NEXT:    [[TMP10:%.*]] = call fast { <vscale x 4 x float>, <vscale x 4 x float> } @llvm.vector.deinterleave2.nxv8f32(<vscale x 8 x float> [[TMP9]])
+; CHECK-NEXT:    [[TMP11:%.*]] = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } [[TMP10]], 0
+; CHECK-NEXT:    [[TMP12:%.*]] = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } [[TMP10]], 1
+; CHECK-NEXT:    [[RDX2:%.*]] = fmul fast <vscale x 4 x float> [[TMP11]], [[TMP12]]
+; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <vscale x 4 x float> [[RDX2]], i64 0
+; CHECK-NEXT:    [[RDX_FINAL:%.*]] = fmul fast float 3.000000e+00, [[TMP13]]
+; CHECK-NEXT:    ret float [[RDX_FINAL]]
+;
+  %res = call fast float @llvm.vector.reduce.fmul.vnx4f32(float 3.0, <vscale x 4 x float> %vec)
+  ret float %res
+}
+
+attributes #0 = { vscale_range(1,16) "target-features"="+sve" }
+attributes #1 = { vscale_range(2,2) "target-features"="+sve" }

From a5a1223ef0b8e2962aec57de175754a576861168 Mon Sep 17 00:00:00 2001
From: Lou Knauer <lou.knauer at sipearl.com>
Date: Thu, 6 Mar 2025 12:27:06 +0100
Subject: [PATCH 2/2] Use DomTreeUpdater instead of updating DT directly

---
 llvm/lib/CodeGen/ExpandReductions.cpp | 46 ++++++++++++++-------------
 1 file changed, 24 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/CodeGen/ExpandReductions.cpp b/llvm/lib/CodeGen/ExpandReductions.cpp
index 797c83aea3309..84750f7f81212 100644
--- a/llvm/lib/CodeGen/ExpandReductions.cpp
+++ b/llvm/lib/CodeGen/ExpandReductions.cpp
@@ -36,25 +36,18 @@ using namespace llvm;
 
 namespace {
 
-void updateDomTreeForScalableExpansion(DominatorTree *DT, BasicBlock *Preheader,
-                                       BasicBlock *Loop, BasicBlock *Exit) {
-  DT->addNewBlock(Loop, Preheader);
-  DT->changeImmediateDominator(Exit, Loop);
-  assert(DT->verify(DominatorTree::VerificationLevel::Fast));
-}
-
 /// Expand a reduction on a scalable vector into a loop
 /// that iterates over the elements one after the other.
 Value *expandScalableReduction(IRBuilderBase &Builder, IntrinsicInst *II,
                                Value *Acc, Value *Vec,
                                Instruction::BinaryOps BinOp,
-                               DominatorTree *DT) {
+                               DomTreeUpdater &DTU) {
   ScalableVectorType *VecTy = cast<ScalableVectorType>(Vec->getType());
 
   // Split the original BB in two and create a new BB between them,
   // which will be a loop.
   BasicBlock *BeforeBB = II->getParent();
-  BasicBlock *AfterBB = SplitBlock(BeforeBB, II, DT);
+  BasicBlock *AfterBB = SplitBlock(BeforeBB, II, &DTU);
   BasicBlock *LoopBB = BasicBlock::Create(Builder.getContext(), "rdx.loop",
                                           BeforeBB->getParent(), AfterBB);
   BeforeBB->getTerminator()->setSuccessor(0, LoopBB);
@@ -87,9 +80,9 @@ Value *expandScalableReduction(IRBuilderBase &Builder, IntrinsicInst *II,
   Value *Done = Builder.CreateCmp(CmpInst::ICMP_EQ, IVInc, NumElts, "exitcond");
   Builder.CreateCondBr(Done, AfterBB, LoopBB);
 
-  if (DT)
-    updateDomTreeForScalableExpansion(DT, BeforeBB, LoopBB, AfterBB);
-
+  DTU.applyUpdates({{DominatorTree::Insert, BeforeBB, LoopBB},
+                    {DominatorTree::Insert, LoopBB, AfterBB},
+                    {DominatorTree::Delete, BeforeBB, AfterBB}});
   return Rdx;
 }
 
@@ -99,7 +92,7 @@ Value *expandScalableReduction(IRBuilderBase &Builder, IntrinsicInst *II,
 Value *expandScalableTreeReduction(
     IRBuilderBase &Builder, IntrinsicInst *II, std::optional<Value *> Acc,
     Value *Vec, Instruction::BinaryOps BinOp,
-    function_ref<bool(Constant *)> IsNeutralElement, DominatorTree *DT,
+    function_ref<bool(Constant *)> IsNeutralElement, DomTreeUpdater &DTU,
     std::optional<unsigned> FixedVScale) {
   ScalableVectorType *VecTy = cast<ScalableVectorType>(Vec->getType());
   ScalableVectorType *VecTyX2 = ScalableVectorType::get(
@@ -133,7 +126,7 @@ Value *expandScalableTreeReduction(
   // Split the original BB in two and create a new BB between them,
   // which will be a loop.
   BasicBlock *BeforeBB = II->getParent();
-  BasicBlock *AfterBB = SplitBlock(BeforeBB, II, DT);
+  BasicBlock *AfterBB = SplitBlock(BeforeBB, II, &DTU);
   BasicBlock *LoopBB = BasicBlock::Create(Builder.getContext(), "rdx.loop",
                                           BeforeBB->getParent(), AfterBB);
   BeforeBB->getTerminator()->setSuccessor(0, LoopBB);
@@ -186,15 +179,16 @@ Value *expandScalableTreeReduction(
     if (auto *C = dyn_cast<Constant>(*Acc); !C || !IsNeutralElement(C))
       FinalVal = Builder.CreateBinOp(BinOp, *Acc, FinalVal, "rdx.final");
 
-  if (DT)
-    updateDomTreeForScalableExpansion(DT, BeforeBB, LoopBB, AfterBB);
+  DTU.applyUpdates({{DominatorTree::Insert, BeforeBB, LoopBB},
+                    {DominatorTree::Insert, LoopBB, AfterBB},
+                    {DominatorTree::Delete, BeforeBB, AfterBB}});
 
   return FinalVal;
 }
 
 std::pair<bool, bool> expandReductions(Function &F,
                                        const TargetTransformInfo *TTI,
-                                       DominatorTree *DT) {
+                                       DomTreeUpdater &DTU) {
   bool Changed = false, CFGChanged = false;
   SmallVector<IntrinsicInst *, 4> Worklist;
   for (auto &I : instructions(F)) {
@@ -270,9 +264,9 @@ std::pair<bool, bool> expandReductions(Function &F,
                   llvm_unreachable("Binop not handled");
                 }
               },
-              DT, FixedVScale);
+              DTU, FixedVScale);
         else
-          Rdx = expandScalableReduction(Builder, II, Acc, Vec, RdxOpcode, DT);
+          Rdx = expandScalableReduction(Builder, II, Acc, Vec, RdxOpcode, DTU);
         break;
       }
 
@@ -335,7 +329,7 @@ std::pair<bool, bool> expandReductions(Function &F,
         Rdx = expandScalableTreeReduction(
             Builder, II, std::nullopt, Vec, Instruction::BinaryOps(RdxOpcode),
             [](Constant *C) -> bool { llvm_unreachable("No accumulator!"); },
-            DT, FixedVScale);
+            DTU, FixedVScale);
         break;
       }
 
@@ -363,6 +357,11 @@ std::pair<bool, bool> expandReductions(Function &F,
     II->eraseFromParent();
     Changed = true;
   }
+
+  if (DTU.hasDomTree() && DTU.hasPendingUpdates()) {
+    DTU.flush();
+    assert(DTU.getDomTree().verify(DominatorTree::VerificationLevel::Fast));
+  }
   return {CFGChanged, Changed};
 }
 
@@ -376,7 +375,9 @@ class ExpandReductions : public FunctionPass {
   bool runOnFunction(Function &F) override {
     const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
     auto *DTA = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
-    return expandReductions(F, TTI, DTA ? &DTA->getDomTree() : nullptr).second;
+    DomTreeUpdater DTU(DTA ? &DTA->getDomTree() : nullptr,
+                       DomTreeUpdater::UpdateStrategy::Lazy);
+    return expandReductions(F, TTI, DTU).second;
   }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -402,7 +403,8 @@ PreservedAnalyses ExpandReductionsPass::run(Function &F,
                                             FunctionAnalysisManager &AM) {
   const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
   auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
-  auto [CFGChanged, Changed] = expandReductions(F, &TTI, DT);
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  auto [CFGChanged, Changed] = expandReductions(F, &TTI, DTU);
   if (!Changed)
     return PreservedAnalyses::all();
   PreservedAnalyses PA;


