[llvm] [VectorCombine][RISCV] Convert VPIntrinsics with splat operands to splats (PR #65706)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 16 12:23:30 PDT 2023


================
@@ -729,6 +730,112 @@ bool VectorCombine::foldBitcastShuf(Instruction &I) {
   return true;
 }
 
+/// VP Intrinsics whose vector operands are both splat values may be simplified
+/// into the scalar version of the operation and the result splatted. This
+/// can lead to scalarization down the line.
+bool VectorCombine::scalarizeVPIntrinsic(VPIntrinsic &VPI) {
+  Value *Op0 = VPI.getArgOperand(0);
+  Value *Op1 = VPI.getArgOperand(1);
+
+  if (!isSplatValue(Op0) || !isSplatValue(Op1))
+    return false;
+
+  // For the binary VP intrinsics supported here, the result on disabled lanes
+  // is a poison value. For now, only do this simplification if all lanes
+  // are active.
+  // TODO: Relax the condition that all lanes are active by using insertelement
+  // on inactive lanes.
+  auto IsAllTrueMask = [](Value *MaskVal) {
+    if (Value *SplattedVal = getSplatValue(MaskVal))
+      if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
+        return ConstValue->isAllOnesValue();
+    return false;
+  };
+  if (!IsAllTrueMask(VPI.getArgOperand(2)))
+    return false;
+
+  // Check to make sure we support scalarization of the intrinsic
+  Intrinsic::ID IntrID = VPI.getIntrinsicID();
+  if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
+    return false;
+
+  // Calculate cost of splatting both operands into vectors and the vector
+  // intrinsic
+  VectorType *VecTy = cast<VectorType>(VPI.getType());
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  InstructionCost SplatCost =
+      TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
+      TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy);
+
+  // Calculate the cost of the VP Intrinsic
+  SmallVector<Type *, 4> Args;
+  for (Value *V : VPI.args())
+    Args.push_back(V->getType());
+  IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
+  InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
+  InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
+
+  // Determine scalar opcode
+  std::optional<unsigned> FunctionalOpcode =
+      VPI.getFunctionalOpcode();
+  bool ScalarIsIntr = false;
+  std::optional<Intrinsic::ID> ScalarIntrID;
+  if (!FunctionalOpcode) {
+    ScalarIntrID = VPI.getFunctionalIntrinsicID();
+    if (!ScalarIntrID)
+      return false;
+    ScalarIsIntr = true;
+  }
+
+  // Calculate cost of scalarizing
+  InstructionCost ScalarOpCost = 0;
+  if (ScalarIsIntr) {
+    IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
+    ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
+  } else {
+    ScalarOpCost =
+        TTI.getArithmeticInstrCost(*FunctionalOpcode, VecTy->getScalarType());
+  }
+
+  // The existing splats may be kept around if other instructions use them.
+  InstructionCost CostToKeepSplats =
+      (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
+  InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
+
+  LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
+                    << "\n");
+  LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
+                    << ", Cost of scalarizing:" << NewCost << "\n");
+
+  // We want to scalarize unless the vector variant actually has lower cost.
+  if (OldCost < NewCost || !NewCost.isValid())
+    return false;
+
+  // Scalarize the intrinsic
+  ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount();
+  Value *EVL = VPI.getArgOperand(3);
+  const DataLayout &DL = VPI.getModule()->getDataLayout();
+  bool IsKnownNonZeroVL = isKnownNonZero(EVL, DL, 0, &AC, &VPI, &DT);
+  bool MustHaveNonZeroVL =
+      IntrID == Intrinsic::vp_sdiv || IntrID == Intrinsic::vp_udiv ||
+      IntrID == Intrinsic::vp_srem || IntrID == Intrinsic::vp_urem ||
+      IntrID == Intrinsic::vp_fdiv || IntrID  == Intrinsic::vp_frem;
+
+  if ((MustHaveNonZeroVL && IsKnownNonZeroVL) || !MustHaveNonZeroVL) {
----------------
topperc wrote:

This is just `!MustHaveNonZeroVL || IsKnownNonZeroVL`, since `(A && B) || !A` simplifies to `!A || B`.

https://github.com/llvm/llvm-project/pull/65706


More information about the llvm-commits mailing list