[llvm] AMDGPU: Improve getShuffleCost accuracy for 8- and 16-bit shuffles (PR #168818)

Jay Foad via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 20 09:45:22 PST 2025


Nicolai Hähnle <nicolai.haehnle at amd.com>,
Nicolai Hähnle <nicolai.haehnle at amd.com>,
Nicolai Hähnle <nicolai.haehnle at amd.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/168818 at github.com>


================
@@ -1241,46 +1241,108 @@ InstructionCost GCNTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
       (ScalarSize == 16 || ScalarSize == 8)) {
     // Larger vector widths may require additional instructions, but are
     // typically cheaper than scalarized versions.
-    unsigned NumVectorElts = cast<FixedVectorType>(SrcTy)->getNumElements();
-    unsigned RequestedElts =
-        count_if(Mask, [](int MaskElt) { return MaskElt != -1; });
-    unsigned EltsPerReg = 32 / ScalarSize;
-    if (RequestedElts == 0)
+    //
+    // We assume that shuffling at a register granularity can be done for free.
+    // This is not true for vectors fed into memory instructions, but it is
+    // effectively true for all other shuffling. The emphasis of the logic here
+    // is to assist generic transform in cleaning up / canonicalizing those
+    // shuffles.
+    unsigned NumDstElts = cast<FixedVectorType>(DstTy)->getNumElements();
+    unsigned NumSrcElts = cast<FixedVectorType>(SrcTy)->getNumElements();
+
+    // With op_sel VOP3P instructions freely can access the low half or high
+    // half of a register, so any swizzle of two elements is free.
+    if (ST->hasVOP3PInsts() && ScalarSize == 16 && NumSrcElts == 2 &&
+        (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Reverse ||
+         Kind == TTI::SK_PermuteSingleSrc))
       return 0;
+
+    unsigned EltsPerReg = 32 / ScalarSize;
     switch (Kind) {
     case TTI::SK_Broadcast:
+      // A single v_perm_b32 can be re-used for all destination registers.
+      return 1;
     case TTI::SK_Reverse:
-    case TTI::SK_PermuteSingleSrc: {
-      // With op_sel VOP3P instructions freely can access the low half or high
-      // half of a register, so any swizzle of two elements is free.
-      if (ST->hasVOP3PInsts() && ScalarSize == 16 && NumVectorElts == 2)
-        return 0;
-      unsigned NumPerms = alignTo(RequestedElts, EltsPerReg) / EltsPerReg;
-      // SK_Broadcast just reuses the same mask
-      unsigned NumPermMasks = Kind == TTI::SK_Broadcast ? 1 : NumPerms;
-      return NumPerms + NumPermMasks;
-    }
+      // One instruction per register.
+      return divideCeil(NumDstElts, EltsPerReg);
     case TTI::SK_ExtractSubvector:
+      if (Index % EltsPerReg == 0)
+        return 0; // Shuffling at register granularity
+      return divideCeil(NumDstElts, EltsPerReg);
     case TTI::SK_InsertSubvector: {
-      // Even aligned accesses are free
-      if (!(Index % 2))
-        return 0;
-      // Insert/extract subvectors only require shifts / extract code to get the
-      // relevant bits
-      return alignTo(RequestedElts, EltsPerReg) / EltsPerReg;
+      unsigned NumInsertElts = cast<FixedVectorType>(SubTp)->getNumElements();
+      unsigned EndIndex = Index + NumInsertElts;
+      unsigned BeginSubIdx = Index % EltsPerReg;
+      unsigned EndSubIdx = EndIndex % EltsPerReg;
+      unsigned Cost = 0;
+
+      if (BeginSubIdx != 0) {
+        // Need to shift the inserted vector into place. The cost is the number
+        // of destination registers overlapped by the inserted vector.
+        Cost = divideCeil(EndIndex, EltsPerReg) - (Index / EltsPerReg);
+      }
+
+      // If the last register overlap is partial, there may be three source
+      // registers feeding into it; that takes an extra instruction.
+      if (EndIndex < NumDstElts && BeginSubIdx < EndSubIdx)
+        Cost += 1;
+
+      return Cost;
     }
-    case TTI::SK_PermuteTwoSrc:
-    case TTI::SK_Splice:
-    case TTI::SK_Select: {
-      unsigned NumPerms = alignTo(RequestedElts, EltsPerReg) / EltsPerReg;
-      // SK_Select just reuses the same mask
-      unsigned NumPermMasks = Kind == TTI::SK_Select ? 1 : NumPerms;
-      return NumPerms + NumPermMasks;
+    case TTI::SK_Splice: {
+      assert(NumDstElts == NumSrcElts);
+      // Determine the sub-region of the result vector that requires
+      // sub-register shuffles / mixing.
+      unsigned EltsFromLHS = NumSrcElts - Index;
+      bool LHSIsAligned = (Index % EltsPerReg) == 0;
+      bool RHSIsAligned = (EltsFromLHS % EltsPerReg) == 0;
+      if (LHSIsAligned && RHSIsAligned)
+        return 0;
+      if (LHSIsAligned && !RHSIsAligned)
+        return divideCeil(NumDstElts - EltsFromLHS, EltsPerReg);
+      if (!LHSIsAligned && RHSIsAligned)
+        return divideCeil(EltsFromLHS, EltsPerReg);
+      return divideCeil(NumDstElts, EltsPerReg);
     }
-
     default:
       break;
     }
+
+    if (!Mask.empty()) {
+      // Generically estimate the cost by assuming that each destination
+      // register is derived from sources via v_perm_b32 instructions if it
+      // can't be copied as-is.
+      //
+      // For each destination register, derive the cost of obtaining it based
+      // on the number of source registers that feed into it.
+      unsigned Cost = 0;
+      for (unsigned DstIdx = 0; DstIdx < Mask.size(); DstIdx += EltsPerReg) {
+        SmallVector<int, 4> Regs;
+        bool Aligned = true;
+        for (unsigned I = 0; I < EltsPerReg && DstIdx + I < Mask.size(); ++I) {
+          int SrcIdx = Mask[DstIdx + I];
+          if (SrcIdx == -1)
+            continue;
+          int Reg;
+          if (SrcIdx < (int)NumSrcElts) {
+            Reg = SrcIdx / EltsPerReg;
+            if (SrcIdx % EltsPerReg != I)
+              Aligned = false;
+          } else {
+            Reg = NumSrcElts + (SrcIdx - NumSrcElts) / EltsPerReg;
+            if ((SrcIdx - NumSrcElts) % EltsPerReg != I)
+              Aligned = false;
+          }
+          if (!llvm::is_contained(Regs, Reg))
+            Regs.push_back(Reg);
+        }
+        if (Regs.size() >= 2)
+          Cost += Regs.size() - 1;
+        else if (!Aligned)
+          Cost += 1;
----------------
jayfoad wrote:

What is this accounting for?

https://github.com/llvm/llvm-project/pull/168818


More information about the llvm-commits mailing list