[llvm] [RISCV] Add RISCVISD opcodes for PSHL/PSRL/PSRA and lower to them. (PR #184836)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 5 11:00:49 PST 2026


https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/184836

>From 0fb39487637e231edc66dea1fe8d5629e6c97ec4 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Thu, 5 Mar 2026 09:36:40 -0800
Subject: [PATCH] [RISCV] Add RISCVISD opcodes for PSHL/PSRL/PSRA and lower to
 them.

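We only select P extension shift instructions for splat shift amounts;
other vector shifts are unrolled. Previously we checked whether the shift
amount was a splat_vector and, if so, considered the shift legal and
relied on isel patterns that match the splat_vector operand.

I don't think there is a guarantee that the splat_vector will stick
around as a splat_vector until instruction selection. It's safer to
capture the splat during lowering and create a dedicated node that takes
the scalar shift amount directly.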
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 25 ++++++--
 llvm/lib/Target/RISCV/RISCVInstrInfoP.td    | 71 +++++++++++----------
 2 files changed, 55 insertions(+), 41 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 74b70b6642cd9..d13ad9535b186 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8857,14 +8857,27 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   case ISD::SRA:
     if (Op.getSimpleValueType().isFixedLengthVector()) {
       if (Subtarget.hasStdExtP()) {
-        // We have patterns for scalar/immediate shift amount, so no lowering
-        // needed.
-        if (Op.getOperand(1)->getOpcode() == ISD::SPLAT_VECTOR)
-          return Op;
-
         // There's no vector-vector version of shift instruction in P extension
         // so we need to unroll to scalar computation and pack them back.
-        return DAG.UnrollVectorOp(Op.getNode());
+        if (Op.getOperand(1)->getOpcode() != ISD::SPLAT_VECTOR)
+          return DAG.UnrollVectorOp(Op.getNode());
+
+        unsigned Opc;
+        switch (Op.getOpcode()) {
+        default:
+          llvm_unreachable("Unexpected opcode");
+        case ISD::SHL:
+          Opc = RISCVISD::PSHL;
+          break;
+        case ISD::SRL:
+          Opc = RISCVISD::PSRL;
+          break;
+        case ISD::SRA:
+          Opc = RISCVISD::PSRA;
+          break;
+        }
+        return DAG.getNode(Opc, SDLoc(Op), Op.getValueType(), Op.getOperand(0),
+                           Op.getOperand(1).getOperand(0));
       }
       return lowerToScalableOp(Op, DAG);
     }
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index f82ff91eecdb3..5def40a6b168b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1658,6 +1658,13 @@ def riscv_mulhr : RVSDNode<"MULHR", SDTIntBinOp>;
 def riscv_mulhru : RVSDNode<"MULHRU", SDTIntBinOp>;
 def riscv_mulhrsu : RVSDNode<"MULHRSU", SDTIntBinOp>;
 
+def SDT_RISCVPackedShift : SDTypeProfile<1, 2, [SDTCisVec<0>,
+                                                SDTCisSameAs<0, 1>,
+                                                SDTCisVT<2, XLenVT>]>;
+def riscv_pshl : RVSDNode<"PSHL", SDT_RISCVPackedShift>; // shl by scalar amt
+def riscv_psrl : RVSDNode<"PSRL", SDT_RISCVPackedShift>; // lshr by scalar amt
+def riscv_psra : RVSDNode<"PSRA", SDT_RISCVPackedShift>; // ashr by scalar amt
+
 // Bitwise merge: res = (~op0 & op1) | (op0 & op2)
 def SDT_RISCVMERGE : SDTypeProfile<1, 3, [SDTCisInt<0>,
                                           SDTCisSameAs<0, 1>,
@@ -1766,23 +1773,23 @@ let Predicates = [HasStdExtP] in {
   def: Pat<(XLenVecI16VT (riscv_mulhrsu GPR:$rs1, GPR:$rs2)), (PMULHRSU_H GPR:$rs1, GPR:$rs2)>;
 
   // 8-bit logical shift left/right patterns
-  def: Pat<(XLenVecI8VT (shl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
+  def: Pat<(XLenVecI8VT (riscv_pshl GPR:$rs1, uimm3:$shamt)),
            (PSLLI_B GPR:$rs1, uimm3:$shamt)>;
-  def: Pat<(XLenVecI8VT (srl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
+  def: Pat<(XLenVecI8VT (riscv_psrl GPR:$rs1, uimm3:$shamt)),
            (PSRLI_B GPR:$rs1, uimm3:$shamt)>;
 
   // 16-bit logical shift left/right patterns
-  def: Pat<(XLenVecI16VT (shl GPR:$rs1, (XLenVecI16VT (splat_vector uimm4:$shamt)))),
+  def: Pat<(XLenVecI16VT (riscv_pshl GPR:$rs1, uimm4:$shamt)),
            (PSLLI_H GPR:$rs1, uimm4:$shamt)>;
-  def: Pat<(XLenVecI16VT (srl GPR:$rs1, (XLenVecI16VT (splat_vector uimm4:$shamt)))),
+  def: Pat<(XLenVecI16VT (riscv_psrl GPR:$rs1, uimm4:$shamt)),
            (PSRLI_H GPR:$rs1, uimm4:$shamt)>;
 
   // 8-bit arithmetic shift right patterns
-  def: Pat<(XLenVecI8VT (sra GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
+  def: Pat<(XLenVecI8VT (riscv_psra GPR:$rs1, uimm3:$shamt)),
            (PSRAI_B GPR:$rs1, uimm3:$shamt)>;
 
   // 16-bit arithmetic shift right patterns
-  def: Pat<(XLenVecI16VT (sra GPR:$rs1, (XLenVecI16VT (splat_vector uimm4:$shamt)))),
+  def: Pat<(XLenVecI16VT (riscv_psra GPR:$rs1, uimm4:$shamt)),
            (PSRAI_H GPR:$rs1, uimm4:$shamt)>;
 
   // 16-bit signed saturation shift left patterns
@@ -1790,29 +1797,23 @@ let Predicates = [HasStdExtP] in {
            (PSSLAI_H GPR:$rs1, uimm4:$shamt)>;
 
   // 8-bit logical shift left/right
-  def: Pat<(XLenVecI8VT (shl GPR:$rs1,
-                             (XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))))),
+  def: Pat<(XLenVecI8VT (riscv_pshl GPR:$rs1, GPR:$rs2)),
            (PSLL_BS GPR:$rs1, GPR:$rs2)>;
-  def: Pat<(XLenVecI8VT (srl GPR:$rs1,
-                             (XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))))),
+  def: Pat<(XLenVecI8VT (riscv_psrl GPR:$rs1, GPR:$rs2)),
            (PSRL_BS GPR:$rs1, GPR:$rs2)>;
 
   // 8-bit arithmetic shift left/right
-  def: Pat<(XLenVecI8VT (sra GPR:$rs1,
-                             (XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))))),
+  def: Pat<(XLenVecI8VT (riscv_psra GPR:$rs1, GPR:$rs2)),
            (PSRA_BS GPR:$rs1, GPR:$rs2)>;
 
   // 16-bit logical shift left/right
-  def: Pat<(XLenVecI16VT (shl GPR:$rs1,
-                              (XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))))),
+  def: Pat<(XLenVecI16VT (riscv_pshl GPR:$rs1, GPR:$rs2)),
            (PSLL_HS GPR:$rs1, GPR:$rs2)>;
-  def: Pat<(XLenVecI16VT (srl GPR:$rs1,
-                              (XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))))),
+  def: Pat<(XLenVecI16VT (riscv_psrl GPR:$rs1, GPR:$rs2)),
            (PSRL_HS GPR:$rs1, GPR:$rs2)>;
 
   // 16-bit arithmetic shift left/right
-  def: Pat<(XLenVecI16VT (sra GPR:$rs1,
-                              (XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))))),
+  def: Pat<(XLenVecI16VT (riscv_psra GPR:$rs1, GPR:$rs2)),
            (PSRA_HS GPR:$rs1, GPR:$rs2)>;
 
   // 8-bit PLI SD node pattern
@@ -1972,14 +1973,28 @@ let Predicates = [HasStdExtP, IsRV64] in {
   def: Pat<(v2i32 (mul GPR:$rs1, GPR:$rs2)),
            (PACK (MUL_W00 GPR:$rs1, GPR:$rs2), (MUL_W11 GPR:$rs1, GPR:$rs2))>;
 
+  // 32-bit logical shift left/right patterns
+  def: Pat<(v2i32 (riscv_pshl GPR:$rs1, uimm5:$shamt)),
+           (PSLLI_W GPR:$rs1, uimm5:$shamt)>;
+  def: Pat<(v2i32 (riscv_psrl GPR:$rs1, uimm5:$shamt)),
+           (PSRLI_W GPR:$rs1, uimm5:$shamt)>;
+
+  // 32-bit arithmetic shift right patterns
+  def: Pat<(v2i32 (riscv_psra GPR:$rs1, uimm5:$shamt)),
+           (PSRAI_W GPR:$rs1, uimm5:$shamt)>;
+
+  // 32-bit signed saturation shift left patterns
+  def: Pat<(v2i32 (sshlsat GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))),
+           (PSSLAI_W GPR:$rs1, uimm5:$shamt)>;
+
   // 32-bit logical shift left/right
-  def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
+  def: Pat<(v2i32 (riscv_pshl GPR:$rs1, GPR:$rs2)),
            (PSLL_WS GPR:$rs1, GPR:$rs2)>;
-  def: Pat<(v2i32 (srl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
+  def: Pat<(v2i32 (riscv_psrl GPR:$rs1, GPR:$rs2)),
            (PSRL_WS GPR:$rs1, GPR:$rs2)>;
 
   // 32-bit arithmetic shift left/right
-  def: Pat<(v2i32 (sra GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
+  def: Pat<(v2i32 (riscv_psra GPR:$rs1, GPR:$rs2)),
            (PSRA_WS GPR:$rs1, GPR:$rs2)>;
 
   // splat pattern
@@ -2006,20 +2021,6 @@ let Predicates = [HasStdExtP, IsRV64] in {
   def: Pat<(v2i32 (smax GPR:$rs1, GPR:$rs2)), (PMAX_W GPR:$rs1, GPR:$rs2)>;
   def: Pat<(v2i32 (umax GPR:$rs1, GPR:$rs2)), (PMAXU_W GPR:$rs1, GPR:$rs2)>;
 
-  // 32-bit logical shift left/right patterns
-  def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))),
-           (PSLLI_W GPR:$rs1, uimm5:$shamt)>;
-  def: Pat<(v2i32 (srl GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))),
-           (PSRLI_W GPR:$rs1, uimm5:$shamt)>;
-
-  // 32-bit arithmetic shift left/right patterns
-  def: Pat<(v2i32 (sra GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))),
-           (PSRAI_W GPR:$rs1, uimm5:$shamt)>;
-
-  // 32-bit signed saturation shift left patterns
-  def: Pat<(v2i32 (sshlsat GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))),
-           (PSSLAI_W GPR:$rs1, uimm5:$shamt)>;
-
   // 32-bit vselect patterns
   def: Pat<(v2i32 (vselect (v2i32 GPR:$mask), GPR:$true_v, GPR:$false_v)),
            (MERGE GPR:$mask, GPR:$false_v, GPR:$true_v)>;



More information about the llvm-commits mailing list