[llvm] [NVPTX] Customize getScalarizationOverhead (PR #128077)

via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 21 11:48:34 PST 2025


================
@@ -100,6 +101,42 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
       TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
       ArrayRef<const Value *> Args = {}, const Instruction *CxtI = nullptr);
 
+  InstructionCost getScalarizationOverhead(VectorType *InTy,
+                                           const APInt &DemandedElts,
+                                           bool Insert, bool Extract,
+                                           TTI::TargetCostKind CostKind,
+                                           ArrayRef<Value *> VL = {}) {
+    if (!InTy->getElementCount().isFixed())
+      return InstructionCost::getInvalid();
+
+    auto VT = getTLI()->getValueType(DL, InTy);
+    auto NumElements = InTy->getElementCount().getFixedValue();
+    InstructionCost Cost = 0;
+    if (Insert && !VL.empty()) {
+      bool AllConstant = all_of(seq(NumElements), [&](int Idx) {
+        return !DemandedElts[Idx] || isa<Constant>(VL[Idx]);
+      });
+      if (AllConstant) {
+        Cost += TTI::TCC_Free;
+        Insert = false;
+      }
+    }
+    if (Insert && Isv2x16VT(VT)) {
+      // Can be built in a single mov
+      Cost += 1;
+      Insert = false;
+    }
+    if (Insert && VT == MVT::v4i8) {
+      InstructionCost Cost = 3; // 3 x PRMT
+      for (auto Idx : seq(NumElements))
+        if (DemandedElts[Idx])
+          Cost += 1; // zext operand to i32
----------------
peterbell10 wrote:

We zext the inputs to the first two `PRMT` ops; see the relevant lowering here:
https://github.com/llvm/llvm-project/blob/b5bbe4eef3823facf83e85d2c11a97ce01882ea2/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L2147-L2165

https://github.com/llvm/llvm-project/pull/128077


More information about the llvm-commits mailing list