[llvm] [VectorCombine] Support nary operands and intrinsics in scalarizeOpOrCmp (PR #138406)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Tue May 27 10:30:34 PDT 2025


https://github.com/lukel97 updated https://github.com/llvm/llvm-project/pull/138406

>From 4d8ba5c67f01e488cde5a8c721fc7e58ee0b3dd9 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Fri, 23 May 2025 12:54:25 +0100
Subject: [PATCH 1/4] Precommit tests

---
 .../VectorCombine/intrinsic-scalarize.ll      | 56 +++++++++++++++++++
 .../VectorCombine/unary-op-scalarize.ll       | 26 +++++++++
 2 files changed, 82 insertions(+)
 create mode 100644 llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll

diff --git a/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll b/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
index e7683d72a052d..2a2e37e0ab54b 100644
--- a/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
+++ b/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
@@ -96,6 +96,62 @@ define <4 x i32> @non_trivially_vectorizable(i32 %x, i32 %y) {
   ret <4 x i32> %v
 }
 
+define <4 x float> @fabs_fixed(float %x) {
+; CHECK-LABEL: define <4 x float> @fabs_fixed(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[X_INSERT:%.*]] = insertelement <4 x float> poison, float [[X]], i32 0
+; CHECK-NEXT:    [[V:%.*]] = call <4 x float> @llvm.fabs.v4f32(<4 x float> [[X_INSERT]])
+; CHECK-NEXT:    ret <4 x float> [[V]]
+;
+  %x.insert = insertelement <4 x float> poison, float %x, i32 0
+  %v = call <4 x float> @llvm.fabs(<4 x float> %x.insert)
+  ret <4 x float> %v
+}
+
+define <vscale x 4 x float> @fabs_scalable(float %x) {
+; CHECK-LABEL: define <vscale x 4 x float> @fabs_scalable(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[X_INSERT:%.*]] = insertelement <vscale x 4 x float> poison, float [[X]], i32 0
+; CHECK-NEXT:    [[V:%.*]] = call <vscale x 4 x float> @llvm.fabs.nxv4f32(<vscale x 4 x float> [[X_INSERT]])
+; CHECK-NEXT:    ret <vscale x 4 x float> [[V]]
+;
+  %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0
+  %v = call <vscale x 4 x float> @llvm.fabs(<vscale x 4 x float> %x.insert)
+  ret <vscale x 4 x float> %v
+}
+
+define <4 x float> @fma_fixed(float %x, float %y, float %z) {
+; CHECK-LABEL: define <4 x float> @fma_fixed(
+; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]], float [[Z:%.*]]) {
+; CHECK-NEXT:    [[X_INSERT:%.*]] = insertelement <4 x float> poison, float [[X]], i32 0
+; CHECK-NEXT:    [[Y_INSERT:%.*]] = insertelement <4 x float> poison, float [[Y]], i32 0
+; CHECK-NEXT:    [[Z_INSERT:%.*]] = insertelement <4 x float> poison, float [[Z]], i32 0
+; CHECK-NEXT:    [[V:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[X_INSERT]], <4 x float> [[Y_INSERT]], <4 x float> [[Z_INSERT]])
+; CHECK-NEXT:    ret <4 x float> [[V]]
+;
+  %x.insert = insertelement <4 x float> poison, float %x, i32 0
+  %y.insert = insertelement <4 x float> poison, float %y, i32 0
+  %z.insert = insertelement <4 x float> poison, float %z, i32 0
+  %v = call <4 x float> @llvm.fma(<4 x float> %x.insert, <4 x float> %y.insert, <4 x float> %z.insert)
+  ret <4 x float> %v
+}
+
+define <vscale x 4 x float> @fma_scalable(float %x, float %y, float %z) {
+; CHECK-LABEL: define <vscale x 4 x float> @fma_scalable(
+; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]], float [[Z:%.*]]) {
+; CHECK-NEXT:    [[X_INSERT:%.*]] = insertelement <vscale x 4 x float> poison, float [[X]], i32 0
+; CHECK-NEXT:    [[Y_INSERT:%.*]] = insertelement <vscale x 4 x float> poison, float [[Y]], i32 0
+; CHECK-NEXT:    [[Z_INSERT:%.*]] = insertelement <vscale x 4 x float> poison, float [[Z]], i32 0
+; CHECK-NEXT:    [[V:%.*]] = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> [[X_INSERT]], <vscale x 4 x float> [[Y_INSERT]], <vscale x 4 x float> [[Z_INSERT]])
+; CHECK-NEXT:    ret <vscale x 4 x float> [[V]]
+;
+  %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0
+  %y.insert = insertelement <vscale x 4 x float> poison, float %y, i32 0
+  %z.insert = insertelement <vscale x 4 x float> poison, float %z, i32 0
+  %v = call <vscale x 4 x float> @llvm.fma(<vscale x 4 x float> %x.insert, <vscale x 4 x float> %y.insert, <vscale x 4 x float> %z.insert)
+  ret <vscale x 4 x float> %v
+}
+
 ; TODO: We should be able to scalarize this if we preserve the scalar argument.
 define <4 x float> @scalar_argument(float %x) {
 ; CHECK-LABEL: define <4 x float> @scalar_argument(
diff --git a/llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll b/llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll
new file mode 100644
index 0000000000000..fd40b15706afb
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll
@@ -0,0 +1,26 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -S -p vector-combine | FileCheck %s
+
+define <4 x float> @fneg_fixed(float %x) {
+; CHECK-LABEL: define <4 x float> @fneg_fixed(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[X_INSERT:%.*]] = insertelement <4 x float> poison, float [[X]], i32 0
+; CHECK-NEXT:    [[V:%.*]] = fneg <4 x float> [[X_INSERT]]
+; CHECK-NEXT:    ret <4 x float> [[V]]
+;
+  %x.insert = insertelement <4 x float> poison, float %x, i32 0
+  %v = fneg <4 x float> %x.insert
+  ret <4 x float> %v
+}
+
+define <vscale x 4 x float> @fneg_scalable(float %x) {
+; CHECK-LABEL: define <vscale x 4 x float> @fneg_scalable(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[X_INSERT:%.*]] = insertelement <vscale x 4 x float> poison, float [[X]], i32 0
+; CHECK-NEXT:    [[V:%.*]] = fneg <vscale x 4 x float> [[X_INSERT]]
+; CHECK-NEXT:    ret <vscale x 4 x float> [[V]]
+;
+  %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0
+  %v = fneg <vscale x 4 x float> %x.insert
+  ret <vscale x 4 x float> %v
+}

>From 7bf88886a8e65b66b09cf56bbe95c89b18fb2c33 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Fri, 23 May 2025 12:54:59 +0100
Subject: [PATCH 2/4] [VectorCombine] Scalarize nary ops and intrinsics

---
 .../Transforms/Vectorize/VectorCombine.cpp    | 190 +++++++++---------
 .../VectorCombine/intrinsic-scalarize.ll      |  24 +--
 .../VectorCombine/unary-op-scalarize.ll       |   8 +-
 3 files changed, 114 insertions(+), 108 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 4413284aa3c2a..bf33292544497 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -47,7 +47,7 @@ STATISTIC(NumVecCmp, "Number of vector compares formed");
 STATISTIC(NumVecBO, "Number of vector binops formed");
 STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
 STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
-STATISTIC(NumScalarBO, "Number of scalar binops formed");
+STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
 STATISTIC(NumScalarCmp, "Number of scalar compares formed");
 STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");
 
@@ -114,7 +114,7 @@ class VectorCombine {
   bool foldInsExtBinop(Instruction &I);
   bool foldInsExtVectorToShuffle(Instruction &I);
   bool foldBitcastShuffle(Instruction &I);
-  bool scalarizeBinopOrCmp(Instruction &I);
+  bool scalarizeOpOrCmp(Instruction &I);
   bool scalarizeVPIntrinsic(Instruction &I);
   bool foldExtractedCmps(Instruction &I);
   bool foldBinopOfReductions(Instruction &I);
@@ -1018,28 +1018,20 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
   return true;
 }
 
-/// Match a vector binop, compare or binop-like intrinsic with at least one
-/// inserted scalar operand and convert to scalar binop/cmp/intrinsic followed
+/// Match a vector op/compare/intrinsic with at least one
+/// inserted scalar operand and convert to scalar op/cmp/intrinsic followed
 /// by insertelement.
-bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
-  CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
-  Value *Ins0, *Ins1;
-  if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
-      !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1)))) {
-    // TODO: Allow unary and ternary intrinsics
-    // TODO: Allow intrinsics with different argument types
-    // TODO: Allow intrinsics with scalar arguments
-    if (auto *II = dyn_cast<IntrinsicInst>(&I);
-        II && II->arg_size() == 2 &&
-        isTriviallyVectorizable(II->getIntrinsicID()) &&
-        all_of(II->args(),
-               [&II](Value *Arg) { return Arg->getType() == II->getType(); })) {
-      Ins0 = II->getArgOperand(0);
-      Ins1 = II->getArgOperand(1);
-    } else {
+bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
+  if (!isa<UnaryOperator, BinaryOperator, CmpInst, IntrinsicInst>(I))
+    return false;
+
+  // TODO: Allow intrinsics with different argument types
+  // TODO: Allow intrinsics with scalar arguments
+  if (auto *II = dyn_cast<IntrinsicInst>(&I))
+    if (!isTriviallyVectorizable(II->getIntrinsicID()) ||
+        !all_of(II->args(),
+                [&II](Value *Arg) { return Arg->getType() == II->getType(); }))
       return false;
-    }
-  }
 
   // Do not convert the vector condition of a vector select into a scalar
   // condition. That may cause problems for codegen because of differences in
@@ -1050,42 +1042,47 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
       if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
         return false;
 
-  // Match against one or both scalar values being inserted into constant
-  // vectors:
-  // vec_op VecC0, (inselt VecC1, V1, Index)
-  // vec_op (inselt VecC0, V0, Index), VecC1
-  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
-  // TODO: Deal with mismatched index constants and variable indexes?
-  Constant *VecC0 = nullptr, *VecC1 = nullptr;
-  Value *V0 = nullptr, *V1 = nullptr;
-  uint64_t Index0 = 0, Index1 = 0;
-  if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
-                               m_ConstantInt(Index0))) &&
-      !match(Ins0, m_Constant(VecC0)))
-    return false;
-  if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
-                               m_ConstantInt(Index1))) &&
-      !match(Ins1, m_Constant(VecC1)))
-    return false;
-
-  bool IsConst0 = !V0;
-  bool IsConst1 = !V1;
-  if (IsConst0 && IsConst1)
-    return false;
-  if (!IsConst0 && !IsConst1 && Index0 != Index1)
-    return false;
+  // Match constant vectors or scalars being inserted into constant vectors:
+  // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
+  SmallVector<Constant *> VecCs;
+  SmallVector<Value *> ScalarOps;
+  std::optional<uint64_t> Index;
+
+  auto Ops = isa<IntrinsicInst>(I) ? cast<IntrinsicInst>(I).args()
+                                   : I.operand_values();
+  for (Value *Op : Ops) {
+    Constant *VecC;
+    Value *V;
+    uint64_t InsIdx = 0;
+    VectorType *OpTy = cast<VectorType>(Op->getType());
+    if (match(Op, m_InsertElt(m_Constant(VecC), m_Value(V),
+                              m_ConstantInt(InsIdx)))) {
+      // Bail if any inserts are out of bounds.
+      if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
+        return false;
+      // All inserts must have the same index.
+      // TODO: Deal with mismatched index constants and variable indexes?
+      if (!Index)
+        Index = InsIdx;
+      else if (InsIdx != *Index)
+        return false;
+      VecCs.push_back(VecC);
+      ScalarOps.push_back(V);
+    } else if (match(Op, m_Constant(VecC))) {
+      VecCs.push_back(VecC);
+      ScalarOps.push_back(nullptr);
+    } else {
+      return false;
+    }
+  }
 
-  auto *VecTy0 = cast<VectorType>(Ins0->getType());
-  auto *VecTy1 = cast<VectorType>(Ins1->getType());
-  if (VecTy0->getElementCount().getKnownMinValue() <= Index0 ||
-      VecTy1->getElementCount().getKnownMinValue() <= Index1)
+  // Bail if all operands are constant.
+  if (!Index.has_value())
     return false;
 
-  uint64_t Index = IsConst0 ? Index1 : Index0;
-  Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
-  Type *VecTy = I.getType();
+  VectorType *VecTy = cast<VectorType>(I.getType());
+  Type *ScalarTy = VecTy->getScalarType();
   assert(VecTy->isVectorTy() &&
-         (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
          (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
           ScalarTy->isPointerTy()) &&
          "Unexpected types for insert element into binop or cmp");
@@ -1098,7 +1095,7 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
         Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
     VectorOpCost = TTI.getCmpSelInstrCost(
         Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
-  } else if (isa<BinaryOperator>(I)) {
+  } else if (isa<UnaryOperator, BinaryOperator>(I)) {
     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
   } else {
@@ -1115,29 +1112,36 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
 
   // Fold the vector constants in the original vectors into a new base vector to
   // get more accurate cost modelling.
-  Value *NewVecC;
-  if (isa<CmpInst>(I))
-    NewVecC = ConstantFoldCompareInstOperands(Pred, VecC0, VecC1, *DL);
+  Value *NewVecC = nullptr;
+  if (auto *CI = dyn_cast<CmpInst>(&I))
+    NewVecC = ConstantFoldCompareInstOperands(CI->getPredicate(), VecCs[0],
+                                              VecCs[1], *DL);
+  else if (isa<UnaryOperator>(I))
+    NewVecC = ConstantFoldUnaryOpOperand((Instruction::UnaryOps)Opcode,
+                                         VecCs[0], *DL);
   else if (isa<BinaryOperator>(I))
     NewVecC = ConstantFoldBinaryOpOperands((Instruction::BinaryOps)Opcode,
-                                           VecC0, VecC1, *DL);
-  else
-    NewVecC = ConstantFoldBinaryIntrinsic(
-        cast<IntrinsicInst>(I).getIntrinsicID(), VecC0, VecC1, I.getType(), &I);
+                                           VecCs[0], VecCs[1], *DL);
+  else if (isa<IntrinsicInst>(I) && cast<IntrinsicInst>(I).arg_size() == 2)
+    NewVecC =
+        ConstantFoldBinaryIntrinsic(cast<IntrinsicInst>(I).getIntrinsicID(),
+                                    VecCs[0], VecCs[1], I.getType(), &I);
 
   // Get cost estimate for the insert element. This cost will factor into
   // both sequences.
-  InstructionCost InsertCostNewVecC = TTI.getVectorInstrCost(
-      Instruction::InsertElement, VecTy, CostKind, Index, NewVecC);
-  InstructionCost InsertCostV0 = TTI.getVectorInstrCost(
-      Instruction::InsertElement, VecTy, CostKind, Index, VecC0, V0);
-  InstructionCost InsertCostV1 = TTI.getVectorInstrCost(
-      Instruction::InsertElement, VecTy, CostKind, Index, VecC1, V1);
-  InstructionCost OldCost = (IsConst0 ? 0 : InsertCostV0) +
-                            (IsConst1 ? 0 : InsertCostV1) + VectorOpCost;
-  InstructionCost NewCost = ScalarOpCost + InsertCostNewVecC +
-                            (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCostV0) +
-                            (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCostV1);
+  InstructionCost OldCost = VectorOpCost;
+  InstructionCost NewCost =
+      ScalarOpCost + TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
+                                            CostKind, *Index, NewVecC);
+  for (auto [Op, VecC, Scalar] : zip(Ops, VecCs, ScalarOps)) {
+    if (!Scalar)
+      continue;
+    InstructionCost InsertCost = TTI.getVectorInstrCost(
+        Instruction::InsertElement, VecTy, CostKind, *Index, VecC, Scalar);
+    OldCost += InsertCost;
+    NewCost += !Op->hasOneUse() * InsertCost;
+  }
+
   // We want to scalarize unless the vector variant actually has lower cost.
   if (OldCost < NewCost || !NewCost.isValid())
     return false;
@@ -1146,25 +1150,25 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
   // inselt NewVecC, (scalar_op V0, V1), Index
   if (isa<CmpInst>(I))
     ++NumScalarCmp;
-  else if (isa<BinaryOperator>(I))
-    ++NumScalarBO;
+  else if (isa<UnaryOperator, BinaryOperator>(I))
+    ++NumScalarOps;
   else if (isa<IntrinsicInst>(I))
     ++NumScalarIntrinsic;
 
   // For constant cases, extract the scalar element, this should constant fold.
-  if (IsConst0)
-    V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
-  if (IsConst1)
-    V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));
+  for (auto [OpIdx, Scalar, VecC] : enumerate(ScalarOps, VecCs))
+    if (!Scalar)
+      ScalarOps[OpIdx] = ConstantExpr::getExtractElement(
+          cast<Constant>(VecC), Builder.getInt64(*Index));
 
   Value *Scalar;
-  if (isa<CmpInst>(I))
-    Scalar = Builder.CreateCmp(Pred, V0, V1);
-  else if (isa<BinaryOperator>(I))
-    Scalar = Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);
+  if (auto *CI = dyn_cast<CmpInst>(&I))
+    Scalar = Builder.CreateCmp(CI->getPredicate(), ScalarOps[0], ScalarOps[1]);
+  else if (isa<UnaryOperator, BinaryOperator>(I))
+    Scalar = Builder.CreateNAryOp(Opcode, ScalarOps);
   else
     Scalar = Builder.CreateIntrinsic(
-        ScalarTy, cast<IntrinsicInst>(I).getIntrinsicID(), {V0, V1});
+        ScalarTy, cast<IntrinsicInst>(I).getIntrinsicID(), ScalarOps);
 
   Scalar->setName(I.getName() + ".scalar");
 
@@ -1175,16 +1179,18 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
 
   // Create a new base vector if the constant folding failed.
   if (!NewVecC) {
-    if (isa<CmpInst>(I))
-      NewVecC = Builder.CreateCmp(Pred, VecC0, VecC1);
-    else if (isa<BinaryOperator>(I))
-      NewVecC =
-          Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1);
+    SmallVector<Value *> VecCValues;
+    VecCValues.reserve(VecCs.size());
+    append_range(VecCValues, VecCs);
+    if (auto *CI = dyn_cast<CmpInst>(&I))
+      NewVecC = Builder.CreateCmp(CI->getPredicate(), VecCs[0], VecCs[1]);
+    else if (isa<UnaryOperator, BinaryOperator>(I))
+      NewVecC = Builder.CreateNAryOp(Opcode, VecCValues);
     else
       NewVecC = Builder.CreateIntrinsic(
-          VecTy, cast<IntrinsicInst>(I).getIntrinsicID(), {VecC0, VecC1});
+          VecTy, cast<IntrinsicInst>(I).getIntrinsicID(), VecCValues);
   }
-  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
+  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, *Index);
   replaceValue(I, *Insert);
   return true;
 }
@@ -3570,7 +3576,7 @@ bool VectorCombine::run() {
     // This transform works with scalable and fixed vectors
     // TODO: Identify and allow other scalable transforms
     if (IsVectorType) {
-      MadeChange |= scalarizeBinopOrCmp(I);
+      MadeChange |= scalarizeOpOrCmp(I);
       MadeChange |= scalarizeLoadExtract(I);
       MadeChange |= scalarizeVPIntrinsic(I);
       MadeChange |= foldInterleaveIntrinsics(I);
diff --git a/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll b/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
index 2a2e37e0ab54b..58b7f8de004d0 100644
--- a/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
+++ b/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
@@ -99,8 +99,9 @@ define <4 x i32> @non_trivially_vectorizable(i32 %x, i32 %y) {
 define <4 x float> @fabs_fixed(float %x) {
 ; CHECK-LABEL: define <4 x float> @fabs_fixed(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[X_INSERT:%.*]] = insertelement <4 x float> poison, float [[X]], i32 0
-; CHECK-NEXT:    [[V:%.*]] = call <4 x float> @llvm.fabs.v4f32(<4 x float> [[X_INSERT]])
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = call float @llvm.fabs.f32(float [[X]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.fabs.v4f32(<4 x float> poison)
+; CHECK-NEXT:    [[V:%.*]] = insertelement <4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
 ; CHECK-NEXT:    ret <4 x float> [[V]]
 ;
   %x.insert = insertelement <4 x float> poison, float %x, i32 0
@@ -111,8 +112,9 @@ define <4 x float> @fabs_fixed(float %x) {
 define <vscale x 4 x float> @fabs_scalable(float %x) {
 ; CHECK-LABEL: define <vscale x 4 x float> @fabs_scalable(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[X_INSERT:%.*]] = insertelement <vscale x 4 x float> poison, float [[X]], i32 0
-; CHECK-NEXT:    [[V:%.*]] = call <vscale x 4 x float> @llvm.fabs.nxv4f32(<vscale x 4 x float> [[X_INSERT]])
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = call float @llvm.fabs.f32(float [[X]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 4 x float> @llvm.fabs.nxv4f32(<vscale x 4 x float> poison)
+; CHECK-NEXT:    [[V:%.*]] = insertelement <vscale x 4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
 ; CHECK-NEXT:    ret <vscale x 4 x float> [[V]]
 ;
   %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0
@@ -123,10 +125,9 @@ define <vscale x 4 x float> @fabs_scalable(float %x) {
 define <4 x float> @fma_fixed(float %x, float %y, float %z) {
 ; CHECK-LABEL: define <4 x float> @fma_fixed(
 ; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]], float [[Z:%.*]]) {
-; CHECK-NEXT:    [[X_INSERT:%.*]] = insertelement <4 x float> poison, float [[X]], i32 0
-; CHECK-NEXT:    [[Y_INSERT:%.*]] = insertelement <4 x float> poison, float [[Y]], i32 0
-; CHECK-NEXT:    [[Z_INSERT:%.*]] = insertelement <4 x float> poison, float [[Z]], i32 0
-; CHECK-NEXT:    [[V:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[X_INSERT]], <4 x float> [[Y_INSERT]], <4 x float> [[Z_INSERT]])
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = call float @llvm.fma.f32(float [[X]], float [[Y]], float [[Z]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> poison, <4 x float> poison, <4 x float> poison)
+; CHECK-NEXT:    [[V:%.*]] = insertelement <4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
 ; CHECK-NEXT:    ret <4 x float> [[V]]
 ;
   %x.insert = insertelement <4 x float> poison, float %x, i32 0
@@ -139,10 +140,9 @@ define <4 x float> @fma_fixed(float %x, float %y, float %z) {
 define <vscale x 4 x float> @fma_scalable(float %x, float %y, float %z) {
 ; CHECK-LABEL: define <vscale x 4 x float> @fma_scalable(
 ; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]], float [[Z:%.*]]) {
-; CHECK-NEXT:    [[X_INSERT:%.*]] = insertelement <vscale x 4 x float> poison, float [[X]], i32 0
-; CHECK-NEXT:    [[Y_INSERT:%.*]] = insertelement <vscale x 4 x float> poison, float [[Y]], i32 0
-; CHECK-NEXT:    [[Z_INSERT:%.*]] = insertelement <vscale x 4 x float> poison, float [[Z]], i32 0
-; CHECK-NEXT:    [[V:%.*]] = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> [[X_INSERT]], <vscale x 4 x float> [[Y_INSERT]], <vscale x 4 x float> [[Z_INSERT]])
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = call float @llvm.fma.f32(float [[X]], float [[Y]], float [[Z]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> poison, <vscale x 4 x float> poison, <vscale x 4 x float> poison)
+; CHECK-NEXT:    [[V:%.*]] = insertelement <vscale x 4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
 ; CHECK-NEXT:    ret <vscale x 4 x float> [[V]]
 ;
   %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0
diff --git a/llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll b/llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll
index fd40b15706afb..45d53c84c870d 100644
--- a/llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll
+++ b/llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll
@@ -4,8 +4,8 @@
 define <4 x float> @fneg_fixed(float %x) {
 ; CHECK-LABEL: define <4 x float> @fneg_fixed(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[X_INSERT:%.*]] = insertelement <4 x float> poison, float [[X]], i32 0
-; CHECK-NEXT:    [[V:%.*]] = fneg <4 x float> [[X_INSERT]]
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = fneg float [[X]]
+; CHECK-NEXT:    [[V:%.*]] = insertelement <4 x float> poison, float [[V_SCALAR]], i64 0
 ; CHECK-NEXT:    ret <4 x float> [[V]]
 ;
   %x.insert = insertelement <4 x float> poison, float %x, i32 0
@@ -16,8 +16,8 @@ define <4 x float> @fneg_fixed(float %x) {
 define <vscale x 4 x float> @fneg_scalable(float %x) {
 ; CHECK-LABEL: define <vscale x 4 x float> @fneg_scalable(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[X_INSERT:%.*]] = insertelement <vscale x 4 x float> poison, float [[X]], i32 0
-; CHECK-NEXT:    [[V:%.*]] = fneg <vscale x 4 x float> [[X_INSERT]]
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = fneg float [[X]]
+; CHECK-NEXT:    [[V:%.*]] = insertelement <vscale x 4 x float> poison, float [[V_SCALAR]], i64 0
 ; CHECK-NEXT:    ret <vscale x 4 x float> [[V]]
 ;
   %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0

>From 5f57fb62ed02e9ae349776f5200aede125475260 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Fri, 23 May 2025 19:10:03 +0100
Subject: [PATCH 3/4] Use dyn_cast

---
 llvm/lib/Transforms/Vectorize/VectorCombine.cpp | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index bf33292544497..1c13d95bf7ac0 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1122,10 +1122,11 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
   else if (isa<BinaryOperator>(I))
     NewVecC = ConstantFoldBinaryOpOperands((Instruction::BinaryOps)Opcode,
                                            VecCs[0], VecCs[1], *DL);
-  else if (isa<IntrinsicInst>(I) && cast<IntrinsicInst>(I).arg_size() == 2)
-    NewVecC =
-        ConstantFoldBinaryIntrinsic(cast<IntrinsicInst>(I).getIntrinsicID(),
-                                    VecCs[0], VecCs[1], I.getType(), &I);
+  else if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
+    if (II->arg_size() == 2)
+      NewVecC = ConstantFoldBinaryIntrinsic(II->getIntrinsicID(), VecCs[0],
+                                            VecCs[1], II->getType(), II);
+  }
 
   // Get cost estimate for the insert element. This cost will factor into
   // both sequences.

>From e99fed62d610cc149b9ebd69433a2e0ae2b48855 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Tue, 27 May 2025 18:18:46 +0100
Subject: [PATCH 4/4] Perform all dyn_casts at the start

---
 .../Transforms/Vectorize/VectorCombine.cpp    | 69 +++++++++----------
 1 file changed, 33 insertions(+), 36 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 1c13d95bf7ac0..7336de442f370 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1022,22 +1022,26 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
 /// inserted scalar operand and convert to scalar op/cmp/intrinsic followed
 /// by insertelement.
 bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
-  if (!isa<UnaryOperator, BinaryOperator, CmpInst, IntrinsicInst>(I))
+  auto *UO = dyn_cast<UnaryOperator>(&I);
+  auto *BO = dyn_cast<BinaryOperator>(&I);
+  auto *CI = dyn_cast<CmpInst>(&I);
+  auto *II = dyn_cast<IntrinsicInst>(&I);
+  if (!UO && !BO && !CI && !II)
     return false;
 
   // TODO: Allow intrinsics with different argument types
   // TODO: Allow intrinsics with scalar arguments
-  if (auto *II = dyn_cast<IntrinsicInst>(&I))
-    if (!isTriviallyVectorizable(II->getIntrinsicID()) ||
-        !all_of(II->args(),
-                [&II](Value *Arg) { return Arg->getType() == II->getType(); }))
-      return false;
+  if (II && (!isTriviallyVectorizable(II->getIntrinsicID()) ||
+             !all_of(II->args(), [&II](Value *Arg) {
+               return Arg->getType() == II->getType();
+             })))
+    return false;
 
   // Do not convert the vector condition of a vector select into a scalar
   // condition. That may cause problems for codegen because of differences in
   // boolean formats and register-file transfers.
   // TODO: Can we account for that in the cost model?
-  if (isa<CmpInst>(I))
+  if (CI)
     for (User *U : I.users())
       if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
         return false;
@@ -1048,8 +1052,7 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
   SmallVector<Value *> ScalarOps;
   std::optional<uint64_t> Index;
 
-  auto Ops = isa<IntrinsicInst>(I) ? cast<IntrinsicInst>(I).args()
-                                   : I.operand_values();
+  auto Ops = II ? II->args() : I.operand_values();
   for (Value *Op : Ops) {
     Constant *VecC;
     Value *V;
@@ -1089,17 +1092,16 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
 
   unsigned Opcode = I.getOpcode();
   InstructionCost ScalarOpCost, VectorOpCost;
-  if (isa<CmpInst>(I)) {
-    CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
+  if (CI) {
+    CmpInst::Predicate Pred = CI->getPredicate();
     ScalarOpCost = TTI.getCmpSelInstrCost(
         Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
     VectorOpCost = TTI.getCmpSelInstrCost(
         Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
-  } else if (isa<UnaryOperator, BinaryOperator>(I)) {
+  } else if (UO || BO) {
     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
   } else {
-    auto *II = cast<IntrinsicInst>(&I);
     IntrinsicCostAttributes ScalarICA(
         II->getIntrinsicID(), ScalarTy,
         SmallVector<Type *>(II->arg_size(), ScalarTy));
@@ -1113,20 +1115,16 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
   // Fold the vector constants in the original vectors into a new base vector to
   // get more accurate cost modelling.
   Value *NewVecC = nullptr;
-  if (auto *CI = dyn_cast<CmpInst>(&I))
+  if (CI)
     NewVecC = ConstantFoldCompareInstOperands(CI->getPredicate(), VecCs[0],
                                               VecCs[1], *DL);
-  else if (isa<UnaryOperator>(I))
-    NewVecC = ConstantFoldUnaryOpOperand((Instruction::UnaryOps)Opcode,
-                                         VecCs[0], *DL);
-  else if (isa<BinaryOperator>(I))
-    NewVecC = ConstantFoldBinaryOpOperands((Instruction::BinaryOps)Opcode,
-                                           VecCs[0], VecCs[1], *DL);
-  else if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
-    if (II->arg_size() == 2)
-      NewVecC = ConstantFoldBinaryIntrinsic(II->getIntrinsicID(), VecCs[0],
-                                            VecCs[1], II->getType(), II);
-  }
+  else if (UO)
+    NewVecC = ConstantFoldUnaryOpOperand(Opcode, VecCs[0], *DL);
+  else if (BO)
+    NewVecC = ConstantFoldBinaryOpOperands(Opcode, VecCs[0], VecCs[1], *DL);
+  else if (II->arg_size() == 2)
+    NewVecC = ConstantFoldBinaryIntrinsic(II->getIntrinsicID(), VecCs[0],
+                                          VecCs[1], II->getType(), II);
 
   // Get cost estimate for the insert element. This cost will factor into
   // both sequences.
@@ -1149,11 +1147,11 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
 
   // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
   // inselt NewVecC, (scalar_op V0, V1), Index
-  if (isa<CmpInst>(I))
+  if (CI)
     ++NumScalarCmp;
-  else if (isa<UnaryOperator, BinaryOperator>(I))
+  else if (UO || BO)
     ++NumScalarOps;
-  else if (isa<IntrinsicInst>(I))
+  else
     ++NumScalarIntrinsic;
 
   // For constant cases, extract the scalar element, this should constant fold.
@@ -1163,13 +1161,12 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
           cast<Constant>(VecC), Builder.getInt64(*Index));
 
   Value *Scalar;
-  if (auto *CI = dyn_cast<CmpInst>(&I))
+  if (CI)
     Scalar = Builder.CreateCmp(CI->getPredicate(), ScalarOps[0], ScalarOps[1]);
-  else if (isa<UnaryOperator, BinaryOperator>(I))
+  else if (UO || BO)
     Scalar = Builder.CreateNAryOp(Opcode, ScalarOps);
   else
-    Scalar = Builder.CreateIntrinsic(
-        ScalarTy, cast<IntrinsicInst>(I).getIntrinsicID(), ScalarOps);
+    Scalar = Builder.CreateIntrinsic(ScalarTy, II->getIntrinsicID(), ScalarOps);
 
   Scalar->setName(I.getName() + ".scalar");
 
@@ -1183,13 +1180,13 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
     SmallVector<Value *> VecCValues;
     VecCValues.reserve(VecCs.size());
     append_range(VecCValues, VecCs);
-    if (auto *CI = dyn_cast<CmpInst>(&I))
+    if (CI)
       NewVecC = Builder.CreateCmp(CI->getPredicate(), VecCs[0], VecCs[1]);
-    else if (isa<UnaryOperator, BinaryOperator>(I))
+    else if (UO || BO)
       NewVecC = Builder.CreateNAryOp(Opcode, VecCValues);
     else
-      NewVecC = Builder.CreateIntrinsic(
-          VecTy, cast<IntrinsicInst>(I).getIntrinsicID(), VecCValues);
+      NewVecC =
+          Builder.CreateIntrinsic(VecTy, II->getIntrinsicID(), VecCValues);
   }
   Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, *Index);
   replaceValue(I, *Insert);



More information about the llvm-commits mailing list