[llvm] 21bca79 - [RISCV] Use switch in RISCVTargetTransformInfo::getShuffleCost [nfc]

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 13 08:41:08 PDT 2023


Author: Philip Reames
Date: 2023-03-13T08:40:47-07:00
New Revision: 21bca796d7f06006037bc57a578131751474adea

URL: https://github.com/llvm/llvm-project/commit/21bca796d7f06006037bc57a578131751474adea
DIFF: https://github.com/llvm/llvm-project/commit/21bca796d7f06006037bc57a578131751474adea.diff

LOG: [RISCV] Use switch in RISCVTargetTransformInfo::getShuffleCost [nfc]

Hoist the type-legalization query and restructure the fixed-length vector cases as a switch, in preparation for an upcoming semantic change.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index b68080dc4b18e..ebf80a3e13e9e 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -257,8 +257,9 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
                                              TTI::TargetCostKind CostKind,
                                              int Index, VectorType *SubTp,
                                              ArrayRef<const Value *> Args) {
+  std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
+
   if (isa<ScalableVectorType>(Tp)) {
-    std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
     switch (Kind) {
     default:
       // Fallthrough to generic handling.
@@ -287,71 +288,73 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
     }
   }
 
-  if (isa<FixedVectorType>(Tp) && Kind == TargetTransformInfo::SK_Broadcast) {
-    std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
-    bool HasScalar = (Args.size() > 0) && (Operator::getOpcode(Args[0]) ==
-                                           Instruction::InsertElement);
-    if (LT.second.getScalarSizeInBits() == 1) {
+  if (isa<FixedVectorType>(Tp)) {
+    switch (Kind) {
+    default:
+      break;
+    case TargetTransformInfo::SK_Broadcast: {
+      bool HasScalar = (Args.size() > 0) && (Operator::getOpcode(Args[0]) ==
+                                             Instruction::InsertElement);
+      if (LT.second.getScalarSizeInBits() == 1) {
+        if (HasScalar) {
+          // Example sequence:
+          //   andi a0, a0, 1
+          //   vsetivli zero, 2, e8, mf8, ta, ma (ignored)
+          //   vmv.v.x v8, a0
+          //   vmsne.vi v0, v8, 0
+          return LT.first * getLMULCost(LT.second) * 3;
+        }
+        // Example sequence:
+        //   vsetivli  zero, 2, e8, mf8, ta, mu (ignored)
+        //   vmv.v.i v8, 0
+        //   vmerge.vim      v8, v8, 1, v0
+        //   vmv.x.s a0, v8
+        //   andi    a0, a0, 1
+        //   vmv.v.x v8, a0
+        //   vmsne.vi  v0, v8, 0
+
+        return LT.first * getLMULCost(LT.second) * 6;
+      }
+
       if (HasScalar) {
         // Example sequence:
-        //   andi a0, a0, 1
-        //   vsetivli zero, 2, e8, mf8, ta, ma (ignored)
         //   vmv.v.x v8, a0
-        //   vmsne.vi v0, v8, 0
-        return LT.first * getLMULCost(LT.second) * 3;
+        return LT.first * getLMULCost(LT.second);
       }
-      // Example sequence:
-      //   vsetivli  zero, 2, e8, mf8, ta, mu (ignored)
-      //   vmv.v.i v8, 0
-      //   vmerge.vim      v8, v8, 1, v0
-      //   vmv.x.s a0, v8
-      //   andi    a0, a0, 1
-      //   vmv.v.x v8, a0
-      //   vmsne.vi  v0, v8, 0
-
-      return LT.first * getLMULCost(LT.second) * 6;
-    }
 
-    if (HasScalar) {
       // Example sequence:
-      //   vmv.v.x v8, a0
+      //   vrgather.vi     v9, v8, 0
+      // TODO: vrgather could be slower than vmv.v.x. It is
+      // implementation-dependent.
       return LT.first * getLMULCost(LT.second);
     }
-
-    // Example sequence:
-    //   vrgather.vi     v9, v8, 0
-    // TODO: vrgather could be slower than vmv.v.x. It is
-    // implementation-dependent.
-    return LT.first * getLMULCost(LT.second);
-  }
-
-  if (isa<FixedVectorType>(Tp) && Kind == TTI::SK_PermuteSingleSrc &&
-      Mask.size() >= 2) {
-    std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
-    if (LT.second.isFixedLengthVector()) {
-      MVT EltTp = LT.second.getVectorElementType();
-      // If the size of the element is < ELEN then shuffles of interleaves and
-      // deinterleaves of 2 vectors can be lowered into the following sequences
-      if (EltTp.getScalarSizeInBits() < ST->getELEN()) {
-        auto InterleaveMask = createInterleaveMask(Mask.size() / 2, 2);
-        // Example sequence:
-        //   vsetivli	zero, 4, e8, mf4, ta, ma (ignored)
-        //   vwaddu.vv	v10, v8, v9
-        //   li		a0, -1                   (ignored)
-        //   vwmaccu.vx	v10, a0, v9
-        if (equal(InterleaveMask, Mask))
-          return 2 * LT.first * getLMULCost(LT.second);
-
-        if (Mask[0] == 0 || Mask[0] == 1) {
-          auto DeinterleaveMask = createStrideMask(Mask[0], 2, Mask.size());
+    case TTI::SK_PermuteSingleSrc: {
+      if (Mask.size() >= 2 && LT.second.isFixedLengthVector()) {
+        MVT EltTp = LT.second.getVectorElementType();
+        // If the size of the element is < ELEN then shuffles of interleaves and
+        // deinterleaves of 2 vectors can be lowered into the following sequences
+        if (EltTp.getScalarSizeInBits() < ST->getELEN()) {
+          auto InterleaveMask = createInterleaveMask(Mask.size() / 2, 2);
           // Example sequence:
-          //   vnsrl.wi	v10, v8, 0
-          if (equal(DeinterleaveMask, Mask))
-            return LT.first * getLMULCost(LT.second);
+          //   vsetivli	zero, 4, e8, mf4, ta, ma (ignored)
+          //   vwaddu.vv	v10, v8, v9
+          //   li		a0, -1                   (ignored)
+          //   vwmaccu.vx	v10, a0, v9
+          if (equal(InterleaveMask, Mask))
+            return 2 * LT.first * getLMULCost(LT.second);
+
+          if (Mask[0] == 0 || Mask[0] == 1) {
+            auto DeinterleaveMask = createStrideMask(Mask[0], 2, Mask.size());
+            // Example sequence:
+            //   vnsrl.wi	v10, v8, 0
+            if (equal(DeinterleaveMask, Mask))
+              return LT.first * getLMULCost(LT.second);
+          }
         }
       }
     }
-  }
+    }
+  };
 
   return BaseT::getShuffleCost(Kind, Tp, Mask, CostKind, Index, SubTp);
 }


        


More information about the llvm-commits mailing list