[llvm] 2b9ded6 - [VectorCombine] Support nary operands and intrinsics in scalarizeOpOrCmp (#138406)
Author: Luke Lau
Date: 2025-05-28T09:45:54+01:00
New Revision: 2b9ded64b0221f4159ab603518c5f88edb8bf958
URL: https://github.com/llvm/llvm-project/commit/2b9ded64b0221f4159ab603518c5f88edb8bf958
DIFF: https://github.com/llvm/llvm-project/commit/2b9ded64b0221f4159ab603518c5f88edb8bf958.diff
LOG: [VectorCombine] Support nary operands and intrinsics in scalarizeOpOrCmp (#138406)
This adds support for unary operators, and for unary and ternary intrinsics, in
scalarizeOpOrCmp (formerly scalarizeBinopOrCmp).
The motivation behind this is to scalarize more intrinsics in
VectorCombine rather than in DAGCombine, so we can sink splats across
basic blocks: see https://github.com/llvm/llvm-project/pull/137786
The main change required is to generalize the existing VecC0/VecC1 rules
across n-ary ops (see the sketch after this list):
- An operand can either be a constant vector or an insert of a scalar
into a constant vector
- If it's an insert, the index needs to be static and in bounds
- If it's an insert, all indices need to be the same across all operands
- If all the operands are constant vectors, bail, since the whole op will be
constant folded anyway
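To illustrate the generalized matching on a ternary intrinsic, here is a
sketch of the transform (it mirrors the fma_fixed test added below; the
%base name is chosen for readability, the pass itself emits an unnamed
temporary). Before, every operand is an insert of a scalar into a constant
vector at the same static, in-bounds index:

  %x.insert = insertelement <4 x float> poison, float %x, i32 0
  %y.insert = insertelement <4 x float> poison, float %y, i32 0
  %z.insert = insertelement <4 x float> poison, float %z, i32 0
  %v = call <4 x float> @llvm.fma(<4 x float> %x.insert, <4 x float> %y.insert, <4 x float> %z.insert)

After, the intrinsic runs on the scalars and the result is reinserted at the
common index; the vector call on the constant base vectors survives because
only binary intrinsics are constant folded here:

  ; scalarized op + single insert at the common index
  %v.scalar = call float @llvm.fma.f32(float %x, float %y, float %z)
  %base = call <4 x float> @llvm.fma.v4f32(<4 x float> poison, <4 x float> poison, <4 x float> poison)
  %v = insertelement <4 x float> %base, float %v.scalar, i64 0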
Added:
llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll
Modified:
llvm/lib/Transforms/Vectorize/VectorCombine.cpp
llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 4413284aa3c2a..7336de442f370 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -47,7 +47,7 @@ STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
-STATISTIC(NumScalarBO, "Number of scalar binops formed");
+STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
STATISTIC(NumScalarCmp, "Number of scalar compares formed");
STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");
@@ -114,7 +114,7 @@ class VectorCombine {
bool foldInsExtBinop(Instruction &I);
bool foldInsExtVectorToShuffle(Instruction &I);
bool foldBitcastShuffle(Instruction &I);
- bool scalarizeBinopOrCmp(Instruction &I);
+ bool scalarizeOpOrCmp(Instruction &I);
bool scalarizeVPIntrinsic(Instruction &I);
bool foldExtractedCmps(Instruction &I);
bool foldBinopOfReductions(Instruction &I);
@@ -1018,91 +1018,90 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
return true;
}
-/// Match a vector binop, compare or binop-like intrinsic with at least one
-/// inserted scalar operand and convert to scalar binop/cmp/intrinsic followed
+/// Match a vector op/compare/intrinsic with at least one
+/// inserted scalar operand and convert to scalar op/cmp/intrinsic followed
/// by insertelement.
-bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
- CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
- Value *Ins0, *Ins1;
- if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
- !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1)))) {
- // TODO: Allow unary and ternary intrinsics
- // TODO: Allow intrinsics with different argument types
- // TODO: Allow intrinsics with scalar arguments
- if (auto *II = dyn_cast<IntrinsicInst>(&I);
- II && II->arg_size() == 2 &&
- isTriviallyVectorizable(II->getIntrinsicID()) &&
- all_of(II->args(),
- [&II](Value *Arg) { return Arg->getType() == II->getType(); })) {
- Ins0 = II->getArgOperand(0);
- Ins1 = II->getArgOperand(1);
- } else {
- return false;
- }
- }
+bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
+ auto *UO = dyn_cast<UnaryOperator>(&I);
+ auto *BO = dyn_cast<BinaryOperator>(&I);
+ auto *CI = dyn_cast<CmpInst>(&I);
+ auto *II = dyn_cast<IntrinsicInst>(&I);
+ if (!UO && !BO && !CI && !II)
+ return false;
+
+ // TODO: Allow intrinsics with different argument types
+ // TODO: Allow intrinsics with scalar arguments
+ if (II && (!isTriviallyVectorizable(II->getIntrinsicID()) ||
+ !all_of(II->args(), [&II](Value *Arg) {
+ return Arg->getType() == II->getType();
+ })))
+ return false;
// Do not convert the vector condition of a vector select into a scalar
// condition. That may cause problems for codegen because of differences in
// boolean formats and register-file transfers.
// TODO: Can we account for that in the cost model?
- if (isa<CmpInst>(I))
+ if (CI)
for (User *U : I.users())
if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
return false;
- // Match against one or both scalar values being inserted into constant
- // vectors:
- // vec_op VecC0, (inselt VecC1, V1, Index)
- // vec_op (inselt VecC0, V0, Index), VecC1
- // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
- // TODO: Deal with mismatched index constants and variable indexes?
- Constant *VecC0 = nullptr, *VecC1 = nullptr;
- Value *V0 = nullptr, *V1 = nullptr;
- uint64_t Index0 = 0, Index1 = 0;
- if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
- m_ConstantInt(Index0))) &&
- !match(Ins0, m_Constant(VecC0)))
- return false;
- if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
- m_ConstantInt(Index1))) &&
- !match(Ins1, m_Constant(VecC1)))
- return false;
-
- bool IsConst0 = !V0;
- bool IsConst1 = !V1;
- if (IsConst0 && IsConst1)
- return false;
- if (!IsConst0 && !IsConst1 && Index0 != Index1)
- return false;
+ // Match constant vectors or scalars being inserted into constant vectors:
+ // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
+ SmallVector<Constant *> VecCs;
+ SmallVector<Value *> ScalarOps;
+ std::optional<uint64_t> Index;
+
+ auto Ops = II ? II->args() : I.operand_values();
+ for (Value *Op : Ops) {
+ Constant *VecC;
+ Value *V;
+ uint64_t InsIdx = 0;
+ VectorType *OpTy = cast<VectorType>(Op->getType());
+ if (match(Op, m_InsertElt(m_Constant(VecC), m_Value(V),
+ m_ConstantInt(InsIdx)))) {
+ // Bail if any inserts are out of bounds.
+ if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
+ return false;
+ // All inserts must have the same index.
+ // TODO: Deal with mismatched index constants and variable indexes?
+ if (!Index)
+ Index = InsIdx;
+ else if (InsIdx != *Index)
+ return false;
+ VecCs.push_back(VecC);
+ ScalarOps.push_back(V);
+ } else if (match(Op, m_Constant(VecC))) {
+ VecCs.push_back(VecC);
+ ScalarOps.push_back(nullptr);
+ } else {
+ return false;
+ }
+ }
- auto *VecTy0 = cast<VectorType>(Ins0->getType());
- auto *VecTy1 = cast<VectorType>(Ins1->getType());
- if (VecTy0->getElementCount().getKnownMinValue() <= Index0 ||
- VecTy1->getElementCount().getKnownMinValue() <= Index1)
+ // Bail if all operands are constant.
+ if (!Index.has_value())
return false;
- uint64_t Index = IsConst0 ? Index1 : Index0;
- Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
- Type *VecTy = I.getType();
+ VectorType *VecTy = cast<VectorType>(I.getType());
+ Type *ScalarTy = VecTy->getScalarType();
assert(VecTy->isVectorTy() &&
- (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
(ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
ScalarTy->isPointerTy()) &&
"Unexpected types for insert element into binop or cmp");
unsigned Opcode = I.getOpcode();
InstructionCost ScalarOpCost, VectorOpCost;
- if (isa<CmpInst>(I)) {
- CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
+ if (CI) {
+ CmpInst::Predicate Pred = CI->getPredicate();
ScalarOpCost = TTI.getCmpSelInstrCost(
Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
VectorOpCost = TTI.getCmpSelInstrCost(
Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
- } else if (isa<BinaryOperator>(I)) {
+ } else if (UO || BO) {
ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
} else {
- auto *II = cast<IntrinsicInst>(&I);
IntrinsicCostAttributes ScalarICA(
II->getIntrinsicID(), ScalarTy,
SmallVector<Type *>(II->arg_size(), ScalarTy));
@@ -1115,56 +1114,59 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
// Fold the vector constants in the original vectors into a new base vector to
// get more accurate cost modelling.
- Value *NewVecC;
- if (isa<CmpInst>(I))
- NewVecC = ConstantFoldCompareInstOperands(Pred, VecC0, VecC1, *DL);
- else if (isa<BinaryOperator>(I))
- NewVecC = ConstantFoldBinaryOpOperands((Instruction::BinaryOps)Opcode,
- VecC0, VecC1, *DL);
- else
- NewVecC = ConstantFoldBinaryIntrinsic(
- cast<IntrinsicInst>(I).getIntrinsicID(), VecC0, VecC1, I.getType(), &I);
+ Value *NewVecC = nullptr;
+ if (CI)
+ NewVecC = ConstantFoldCompareInstOperands(CI->getPredicate(), VecCs[0],
+ VecCs[1], *DL);
+ else if (UO)
+ NewVecC = ConstantFoldUnaryOpOperand(Opcode, VecCs[0], *DL);
+ else if (BO)
+ NewVecC = ConstantFoldBinaryOpOperands(Opcode, VecCs[0], VecCs[1], *DL);
+ else if (II->arg_size() == 2)
+ NewVecC = ConstantFoldBinaryIntrinsic(II->getIntrinsicID(), VecCs[0],
+ VecCs[1], II->getType(), II);
// Get cost estimate for the insert element. This cost will factor into
// both sequences.
- InstructionCost InsertCostNewVecC = TTI.getVectorInstrCost(
- Instruction::InsertElement, VecTy, CostKind, Index, NewVecC);
- InstructionCost InsertCostV0 = TTI.getVectorInstrCost(
- Instruction::InsertElement, VecTy, CostKind, Index, VecC0, V0);
- InstructionCost InsertCostV1 = TTI.getVectorInstrCost(
- Instruction::InsertElement, VecTy, CostKind, Index, VecC1, V1);
- InstructionCost OldCost = (IsConst0 ? 0 : InsertCostV0) +
- (IsConst1 ? 0 : InsertCostV1) + VectorOpCost;
- InstructionCost NewCost = ScalarOpCost + InsertCostNewVecC +
- (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCostV0) +
- (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCostV1);
+ InstructionCost OldCost = VectorOpCost;
+ InstructionCost NewCost =
+ ScalarOpCost + TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
+ CostKind, *Index, NewVecC);
+ for (auto [Op, VecC, Scalar] : zip(Ops, VecCs, ScalarOps)) {
+ if (!Scalar)
+ continue;
+ InstructionCost InsertCost = TTI.getVectorInstrCost(
+ Instruction::InsertElement, VecTy, CostKind, *Index, VecC, Scalar);
+ OldCost += InsertCost;
+ NewCost += !Op->hasOneUse() * InsertCost;
+ }
+
// We want to scalarize unless the vector variant actually has lower cost.
if (OldCost < NewCost || !NewCost.isValid())
return false;
// vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
// inselt NewVecC, (scalar_op V0, V1), Index
- if (isa<CmpInst>(I))
+ if (CI)
++NumScalarCmp;
- else if (isa<BinaryOperator>(I))
- ++NumScalarBO;
- else if (isa<IntrinsicInst>(I))
+ else if (UO || BO)
+ ++NumScalarOps;
+ else
++NumScalarIntrinsic;
// For constant cases, extract the scalar element, this should constant fold.
- if (IsConst0)
- V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
- if (IsConst1)
- V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));
+ for (auto [OpIdx, Scalar, VecC] : enumerate(ScalarOps, VecCs))
+ if (!Scalar)
+ ScalarOps[OpIdx] = ConstantExpr::getExtractElement(
+ cast<Constant>(VecC), Builder.getInt64(*Index));
Value *Scalar;
- if (isa<CmpInst>(I))
- Scalar = Builder.CreateCmp(Pred, V0, V1);
- else if (isa<BinaryOperator>(I))
- Scalar = Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);
+ if (CI)
+ Scalar = Builder.CreateCmp(CI->getPredicate(), ScalarOps[0], ScalarOps[1]);
+ else if (UO || BO)
+ Scalar = Builder.CreateNAryOp(Opcode, ScalarOps);
else
- Scalar = Builder.CreateIntrinsic(
- ScalarTy, cast<IntrinsicInst>(I).getIntrinsicID(), {V0, V1});
+ Scalar = Builder.CreateIntrinsic(ScalarTy, II->getIntrinsicID(), ScalarOps);
Scalar->setName(I.getName() + ".scalar");
@@ -1175,16 +1177,18 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
// Create a new base vector if the constant folding failed.
if (!NewVecC) {
- if (isa<CmpInst>(I))
- NewVecC = Builder.CreateCmp(Pred, VecC0, VecC1);
- else if (isa<BinaryOperator>(I))
- NewVecC =
- Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1);
+ SmallVector<Value *> VecCValues;
+ VecCValues.reserve(VecCs.size());
+ append_range(VecCValues, VecCs);
+ if (CI)
+ NewVecC = Builder.CreateCmp(CI->getPredicate(), VecCs[0], VecCs[1]);
+ else if (UO || BO)
+ NewVecC = Builder.CreateNAryOp(Opcode, VecCValues);
else
- NewVecC = Builder.CreateIntrinsic(
- VecTy, cast<IntrinsicInst>(I).getIntrinsicID(), {VecC0, VecC1});
+ NewVecC =
+ Builder.CreateIntrinsic(VecTy, II->getIntrinsicID(), VecCValues);
}
- Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
+ Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, *Index);
replaceValue(I, *Insert);
return true;
}
@@ -3570,7 +3574,7 @@ bool VectorCombine::run() {
// This transform works with scalable and fixed vectors
// TODO: Identify and allow other scalable transforms
if (IsVectorType) {
- MadeChange |= scalarizeBinopOrCmp(I);
+ MadeChange |= scalarizeOpOrCmp(I);
MadeChange |= scalarizeLoadExtract(I);
MadeChange |= scalarizeVPIntrinsic(I);
MadeChange |= foldInterleaveIntrinsics(I);
diff --git a/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll b/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
index e7683d72a052d..58b7f8de004d0 100644
--- a/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
+++ b/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
@@ -96,6 +96,62 @@ define <4 x i32> @non_trivially_vectorizable(i32 %x, i32 %y) {
ret <4 x i32> %v
}
+define <4 x float> @fabs_fixed(float %x) {
+; CHECK-LABEL: define <4 x float> @fabs_fixed(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[V_SCALAR:%.*]] = call float @llvm.fabs.f32(float [[X]])
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x float> @llvm.fabs.v4f32(<4 x float> poison)
+; CHECK-NEXT: [[V:%.*]] = insertelement <4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
+; CHECK-NEXT: ret <4 x float> [[V]]
+;
+ %x.insert = insertelement <4 x float> poison, float %x, i32 0
+ %v = call <4 x float> @llvm.fabs(<4 x float> %x.insert)
+ ret <4 x float> %v
+}
+
+define <vscale x 4 x float> @fabs_scalable(float %x) {
+; CHECK-LABEL: define <vscale x 4 x float> @fabs_scalable(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[V_SCALAR:%.*]] = call float @llvm.fabs.f32(float [[X]])
+; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 4 x float> @llvm.fabs.nxv4f32(<vscale x 4 x float> poison)
+; CHECK-NEXT: [[V:%.*]] = insertelement <vscale x 4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
+; CHECK-NEXT: ret <vscale x 4 x float> [[V]]
+;
+ %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0
+ %v = call <vscale x 4 x float> @llvm.fabs(<vscale x 4 x float> %x.insert)
+ ret <vscale x 4 x float> %v
+}
+
+define <4 x float> @fma_fixed(float %x, float %y, float %z) {
+; CHECK-LABEL: define <4 x float> @fma_fixed(
+; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]], float [[Z:%.*]]) {
+; CHECK-NEXT: [[V_SCALAR:%.*]] = call float @llvm.fma.f32(float [[X]], float [[Y]], float [[Z]])
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> poison, <4 x float> poison, <4 x float> poison)
+; CHECK-NEXT: [[V:%.*]] = insertelement <4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
+; CHECK-NEXT: ret <4 x float> [[V]]
+;
+ %x.insert = insertelement <4 x float> poison, float %x, i32 0
+ %y.insert = insertelement <4 x float> poison, float %y, i32 0
+ %z.insert = insertelement <4 x float> poison, float %z, i32 0
+ %v = call <4 x float> @llvm.fma(<4 x float> %x.insert, <4 x float> %y.insert, <4 x float> %z.insert)
+ ret <4 x float> %v
+}
+
+define <vscale x 4 x float> @fma_scalable(float %x, float %y, float %z) {
+; CHECK-LABEL: define <vscale x 4 x float> @fma_scalable(
+; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]], float [[Z:%.*]]) {
+; CHECK-NEXT: [[V_SCALAR:%.*]] = call float @llvm.fma.f32(float [[X]], float [[Y]], float [[Z]])
+; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> poison, <vscale x 4 x float> poison, <vscale x 4 x float> poison)
+; CHECK-NEXT: [[V:%.*]] = insertelement <vscale x 4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
+; CHECK-NEXT: ret <vscale x 4 x float> [[V]]
+;
+ %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0
+ %y.insert = insertelement <vscale x 4 x float> poison, float %y, i32 0
+ %z.insert = insertelement <vscale x 4 x float> poison, float %z, i32 0
+ %v = call <vscale x 4 x float> @llvm.fma(<vscale x 4 x float> %x.insert, <vscale x 4 x float> %y.insert, <vscale x 4 x float> %z.insert)
+ ret <vscale x 4 x float> %v
+}
+
; TODO: We should be able to scalarize this if we preserve the scalar argument.
define <4 x float> @scalar_argument(float %x) {
; CHECK-LABEL: define <4 x float> @scalar_argument(
diff --git a/llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll b/llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll
new file mode 100644
index 0000000000000..45d53c84c870d
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll
@@ -0,0 +1,26 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -S -p vector-combine | FileCheck %s
+
+define <4 x float> @fneg_fixed(float %x) {
+; CHECK-LABEL: define <4 x float> @fneg_fixed(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[V_SCALAR:%.*]] = fneg float [[X]]
+; CHECK-NEXT: [[V:%.*]] = insertelement <4 x float> poison, float [[V_SCALAR]], i64 0
+; CHECK-NEXT: ret <4 x float> [[V]]
+;
+ %x.insert = insertelement <4 x float> poison, float %x, i32 0
+ %v = fneg <4 x float> %x.insert
+ ret <4 x float> %v
+}
+
+define <vscale x 4 x float> @fneg_scalable(float %x) {
+; CHECK-LABEL: define <vscale x 4 x float> @fneg_scalable(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[V_SCALAR:%.*]] = fneg float [[X]]
+; CHECK-NEXT: [[V:%.*]] = insertelement <vscale x 4 x float> poison, float [[V_SCALAR]], i64 0
+; CHECK-NEXT: ret <vscale x 4 x float> [[V]]
+;
+ %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0
+ %v = fneg <vscale x 4 x float> %x.insert
+ ret <vscale x 4 x float> %v
+}