[llvm] [InstCombine][RISCV] Convert VPIntrinsics with splat operands to splats (PR #65706)
Luke Lau via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 11 10:03:00 PDT 2023
================
@@ -729,6 +730,172 @@ bool VectorCombine::foldBitcastShuf(Instruction &I) {
return true;
}
+/// VP Intrinsics whose vector operands are both splat values may be simplified
+/// into the scalar version of the operation and the result is splatted. This
+/// can lead to scalarization down the line.
+bool VectorCombine::scalarizeVPIntrinsic(VPIntrinsic &VPI) {
+ Value *Op0 = VPI.getArgOperand(0);
+ Value *Op1 = VPI.getArgOperand(1);
+
+ if (!isSplatValue(Op0) || !isSplatValue(Op1))
+ return false;
+
+ // For the binary VP intrinsics supported here, the result on disabled lanes
+ // is a poison value. For now, only do this simplification if all lanes
+ // are active.
+ // TODO: Relax the condition that all lanes are active by using insertelement
+ // on inactive lanes.
+ auto IsAllTrueMask = [](Value *MaskVal) {
+ if (Value *SplattedVal = getSplatValue(MaskVal))
+ if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
+ return ConstValue->isAllOnesValue();
+ return false;
+ };
+ if (!IsAllTrueMask(VPI.getArgOperand(2)))
+ return false;
+
+ DenseMap<Intrinsic::ID, unsigned> VPIntrinsicToScalar(
+ {{Intrinsic::vp_add, Instruction::Add},
+ {Intrinsic::vp_sub, Instruction::Sub},
+ {Intrinsic::vp_mul, Instruction::Mul},
+ {Intrinsic::vp_ashr, Instruction::AShr},
+ {Intrinsic::vp_lshr, Instruction::LShr},
+ {Intrinsic::vp_shl, Instruction::Shl},
+ {Intrinsic::vp_or, Instruction::Or},
+ {Intrinsic::vp_and, Instruction::And},
+ {Intrinsic::vp_xor, Instruction::Xor},
+ {Intrinsic::vp_fadd, Instruction::FAdd},
+ {Intrinsic::vp_fsub, Instruction::FSub},
+ {Intrinsic::vp_fmul, Instruction::FMul},
+ {Intrinsic::vp_sdiv, Instruction::SDiv},
+ {Intrinsic::vp_udiv, Instruction::UDiv},
+ {Intrinsic::vp_srem, Instruction::SRem},
+ {Intrinsic::vp_urem, Instruction::URem}});
+
+ // Check to make sure we support scalarization of the intrinsic
+ Intrinsic::ID IntrID = VPI.getIntrinsicID();
+ if (!VPIntrinsicToScalar.contains(IntrID))
+ return false;
+
+ // Calculate cost of splatting both operands into vectors and the vector
+ // intrinsic
+ VectorType *VecTy = cast<VectorType>(VPI.getType());
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ InstructionCost SplatCost =
+ TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
+ TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy);
+
+ // Calculate the cost of the VP Intrinsic
+ SmallVector<Type *, 4> Args;
+ for (Value *V : VPI.args())
+ Args.push_back(V->getType());
+ IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
+ InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
+ InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
+
+ // Calculate cost of scalarizing
+ InstructionCost ScalarOpCost = TTI.getArithmeticInstrCost(
+ VPIntrinsicToScalar[IntrID], VecTy->getScalarType());
+ InstructionCost NewCost = ScalarOpCost + SplatCost;
+
+ LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
+ << "\n");
+ LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
+ << ", Cost of scalarizing:" << NewCost << "\n");
+
+ // We want to scalarize unless the vector variant actually has lower cost.
+ if (OldCost < NewCost || !NewCost.isValid())
+ return false;
+
+ // Scalarize the intrinsic
+ ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount();
+ Value *EVL = VPI.getArgOperand(3);
+ const DataLayout &DL = VPI.getModule()->getDataLayout();
+ switch (IntrID) {
+ case Intrinsic::vp_add:
+ replaceValue(VPI, *Builder.CreateVectorSplat(
+ EC, Builder.CreateAdd(getSplatValue(Op0), getSplatValue(Op1))));
+ return true;
+ case Intrinsic::vp_sub:
+ replaceValue(VPI, *Builder.CreateVectorSplat(
+ EC, Builder.CreateSub(getSplatValue(Op0), getSplatValue(Op1))));
+ return true;
+ case Intrinsic::vp_mul:
+ replaceValue(VPI, *Builder.CreateVectorSplat(
+ EC, Builder.CreateMul(getSplatValue(Op0), getSplatValue(Op1))));
+ return true;
+ case Intrinsic::vp_ashr:
+ replaceValue(VPI, *Builder.CreateVectorSplat(
+ EC, Builder.CreateAShr(getSplatValue(Op0), getSplatValue(Op1))));
+ return true;
+ case Intrinsic::vp_lshr:
+ replaceValue(VPI, *Builder.CreateVectorSplat(
+ EC, Builder.CreateLShr(getSplatValue(Op0), getSplatValue(Op1))));
+ return true;
+ case Intrinsic::vp_shl:
+ replaceValue(VPI, *Builder.CreateVectorSplat(
+ EC, Builder.CreateShl(getSplatValue(Op0), getSplatValue(Op1))));
+ return true;
+ case Intrinsic::vp_or:
+ replaceValue(VPI, *Builder.CreateVectorSplat(
+ EC, Builder.CreateOr(getSplatValue(Op0), getSplatValue(Op1))));
+ return true;
+ case Intrinsic::vp_and:
+ replaceValue(VPI, *Builder.CreateVectorSplat(
+ EC, Builder.CreateAnd(getSplatValue(Op0), getSplatValue(Op1))));
+ return true;
+ case Intrinsic::vp_xor:
+ replaceValue(VPI, *Builder.CreateVectorSplat(
+ EC, Builder.CreateXor(getSplatValue(Op0), getSplatValue(Op1))));
+ return true;
+ case Intrinsic::vp_fadd:
+ replaceValue(VPI, *Builder.CreateVectorSplat(
+ EC, Builder.CreateFAdd(getSplatValue(Op0), getSplatValue(Op1))));
+ return true;
+ case Intrinsic::vp_fsub:
+ replaceValue(VPI, *Builder.CreateVectorSplat(
+ EC, Builder.CreateFSub(getSplatValue(Op0), getSplatValue(Op1))));
+ return true;
+ case Intrinsic::vp_fmul:
+ replaceValue(VPI, *Builder.CreateVectorSplat(
+ EC, Builder.CreateFMul(getSplatValue(Op0), getSplatValue(Op1))));
+ return true;
----------------
lukel97 wrote:
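Can we replace this switch statement with a single `Builder.CreateBinOp((Instruction::BinaryOps)VPIntrinsicToScalar[IntrID], getSplatValue(Op0), getSplatValue(Op1))` call? We could check above whether the intrinsic is a div/rem and return early for those, since a scalar division or remainder can trap and needs extra checks.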
https://github.com/llvm/llvm-project/pull/65706
More information about the llvm-commits mailing list