[llvm] [RISCV][TTI] Refine reverse shuffle costing for high LMUL (PR #144155)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 16 09:57:32 PDT 2025


================
@@ -840,33 +849,64 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
     return LT.first * getRISCVInstructionCost(Opcodes, LT.second, CostKind);
   }
   case TTI::SK_Reverse: {
+
+    if (!LT.second.isVector())
+      return InstructionCost::getInvalid();
+
     // TODO: Cases to improve here:
     // * Illegal vector types
     // * i64 on RV32
-    // * i1 vector
-    // At low LMUL, most of the cost is producing the vrgather index register.
-    // At high LMUL, the cost of the vrgather itself will dominate.
-    // Example sequence:
-    //   csrr a0, vlenb
-    //   srli a0, a0, 3
-    //   addi a0, a0, -1
-    //   vsetvli a1, zero, e8, mf8, ta, mu (ignored)
-    //   vid.v v9
-    //   vrsub.vx v10, v9, a0
-    //   vrgather.vv v9, v8, v10
-    InstructionCost LenCost = 3;
+    if (Tp->getElementType()->isIntegerTy(1)) {
+      VectorType *WideTy =
+          VectorType::get(IntegerType::get(Tp->getContext(), 8),
+                          cast<VectorType>(Tp)->getElementCount());
+      return getCastInstrCost(Instruction::ZExt, WideTy, Tp,
+                              TTI::CastContextHint::None, CostKind) +
+             getShuffleCost(TTI::SK_Reverse, WideTy, {}, CostKind, 0, nullptr) +
+             getCastInstrCost(Instruction::Trunc, Tp, WideTy,
+                              TTI::CastContextHint::None, CostKind);
+    }
+
+    MVT ContainerVT = LT.second;
     if (LT.second.isFixedLengthVector())
-      // vrsub.vi has a 5 bit immediate field, otherwise an li suffices
-      LenCost = isInt<5>(LT.second.getVectorNumElements() - 1) ? 0 : 1;
-    unsigned Opcodes[] = {RISCV::VID_V, RISCV::VRSUB_VX, RISCV::VRGATHER_VV};
-    if (LT.second.isFixedLengthVector() &&
-        isInt<5>(LT.second.getVectorNumElements() - 1))
-      Opcodes[1] = RISCV::VRSUB_VI;
+      ContainerVT = TLI->getContainerForFixedLengthVector(LT.second);
+    MVT M1VT = getLMUL1VT(ContainerVT);
+    if (ContainerVT.bitsLE(M1VT)) {
+      // Example sequence:
+      //   csrr a0, vlenb
+      //   srli a0, a0, 3
+      //   addi a0, a0, -1
+      //   vsetvli a1, zero, e8, mf8, ta, mu (ignored)
+      //   vid.v v9
----------------
preames wrote:

You're right, the code is using a bound on VLMAX in that case. The actual lowering also narrows the index to 16 bits if possible. I'm okay doing this as a followup, but would prefer to leave that out of this change. Are you okay with that?
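
For readers following the hunk above, here is a minimal scalar sketch of what the quoted vid.v / vrsub.vx / vrgather.vv sequence computes. It is an illustration only, not LLVM code: the function name reverseViaGather is hypothetical, the element type is fixed to 8 bits for simplicity, and the vector length is taken from the input rather than from vlenb.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical scalar model of the reverse-shuffle sequence quoted in the
// diff. Each loop stands in for one vector instruction.
std::vector<std::uint8_t> reverseViaGather(const std::vector<std::uint8_t> &src) {
  const std::size_t vl = src.size();
  if (vl == 0)
    return {};

  // vid.v: idx[i] = i
  std::vector<std::size_t> idx(vl);
  for (std::size_t i = 0; i < vl; ++i)
    idx[i] = i;

  // vrsub.vx: idx[i] = (vl - 1) - idx[i]
  // (the scalar operand vl - 1 is what the csrr/srli/addi lines produce)
  const std::size_t last = vl - 1;
  for (std::size_t i = 0; i < vl; ++i)
    idx[i] = last - idx[i];

  // vrgather.vv: dst[i] = src[idx[i]]
  std::vector<std::uint8_t> dst(vl);
  for (std::size_t i = 0; i < vl; ++i)
    dst[i] = src[idx[i]];
  return dst;
}
```

For example, reverseViaGather({1, 2, 3, 4}) yields {4, 3, 2, 1}. The index vector here uses std::size_t for convenience; the width of the index elements in the real lowering is a separate question, which is the point raised about narrowing to 16 bits.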

https://github.com/llvm/llvm-project/pull/144155