[llvm] 3e5fafd - [RISCV][llvm] Select splat_vector(constant) with PLI (#168204)

via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 19 23:02:44 PST 2025


Author: Brandon Wu
Date: 2025-11-20T15:02:40+08:00
New Revision: 3e5fafdc223a937f371d22dc05d4ab8398b13f3f

URL: https://github.com/llvm/llvm-project/commit/3e5fafdc223a937f371d22dc05d4ab8398b13f3f
DIFF: https://github.com/llvm/llvm-project/commit/3e5fafdc223a937f371d22dc05d4ab8398b13f3f.diff

LOG: [RISCV][llvm] Select splat_vector(constant) with PLI (#168204)

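The default DAG combiner combines a BUILD_VECTOR whose elements are all
identical into a SPLAT_VECTOR, so a constant splat can simply be mapped
to PLI when its immediate fits. Non-constant splats are selected as
PADD_BS/PADD_HS/PADD_WS with X0 as the first source operand.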

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVInstrInfoP.td
    llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
    llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 1024e55f912c7..5025122db3681 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -51,6 +51,8 @@ void RISCVDAGToDAGISel::PreprocessISelDAG() {
     SDValue Result;
     switch (N->getOpcode()) {
     case ISD::SPLAT_VECTOR: {
+      if (Subtarget->enablePExtCodeGen())
+        break;
       // Convert integer SPLAT_VECTOR to VMV_V_X_VL and floating-point
       // SPLAT_VECTOR to VFMV_V_F_VL to reduce isel burden.
       MVT VT = N->getSimpleValueType(0);

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3fbfdfb565e53..6020fb6ca16ce 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -526,7 +526,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::SSUBSAT, VTs, Legal);
     setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VTs, Legal);
     setOperationAction({ISD::ABDS, ISD::ABDU}, VTs, Legal);
-    setOperationAction(ISD::BUILD_VECTOR, VTs, Custom);
+    setOperationAction(ISD::SPLAT_VECTOR, VTs, Legal);
     setOperationAction(ISD::BITCAST, VTs, Custom);
     setOperationAction(ISD::EXTRACT_VECTOR_ELT, VTs, Custom);
   }
@@ -4433,37 +4433,6 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
   MVT XLenVT = Subtarget.getXLenVT();
 
   SDLoc DL(Op);
-  // Handle P extension packed vector BUILD_VECTOR with PLI for splat constants
-  if (Subtarget.enablePExtCodeGen()) {
-    bool IsPExtVector =
-        (VT == MVT::v2i16 || VT == MVT::v4i8) ||
-        (Subtarget.is64Bit() &&
-         (VT == MVT::v4i16 || VT == MVT::v8i8 || VT == MVT::v2i32));
-    if (IsPExtVector) {
-      if (SDValue SplatValue = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
-        if (auto *C = dyn_cast<ConstantSDNode>(SplatValue)) {
-          int64_t SplatImm = C->getSExtValue();
-          bool IsValidImm = false;
-
-          // Check immediate range based on vector type
-          if (VT == MVT::v8i8 || VT == MVT::v4i8) {
-            // PLI_B uses 8-bit unsigned or unsigned immediate
-            IsValidImm = isUInt<8>(SplatImm) || isInt<8>(SplatImm);
-            if (isUInt<8>(SplatImm))
-              SplatImm = (int8_t)SplatImm;
-          } else {
-            // PLI_H and PLI_W use 10-bit signed immediate
-            IsValidImm = isInt<10>(SplatImm);
-          }
-
-          if (IsValidImm) {
-            SDValue Imm = DAG.getSignedTargetConstant(SplatImm, DL, XLenVT);
-            return DAG.getNode(RISCVISD::PLI, DL, VT, Imm);
-          }
-        }
-      }
-    }
-  }
 
   // Proper support for f16 requires Zvfh. bf16 always requires special
   // handling. We need to cast the scalar to integer and create an integer

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 126a39996c741..764e3c9c58355 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -18,7 +18,7 @@
 // Operand and SDNode transformation definitions.
 //===----------------------------------------------------------------------===//
 
-def simm10 : RISCVSImmOp<10>, TImmLeaf<XLenVT, "return isInt<10>(Imm);">;
+def simm10 : RISCVSImmOp<10>, ImmLeaf<XLenVT, "return isInt<10>(Imm);">;
 
 def SImm8UnsignedAsmOperand : SImmAsmOperand<8, "Unsigned"> {
   let RenderMethod = "addSImm8UnsignedOperands";
@@ -26,7 +26,7 @@ def SImm8UnsignedAsmOperand : SImmAsmOperand<8, "Unsigned"> {
 
 // A 8-bit signed immediate allowing range [-128, 255]
 // but represented as [-128, 127].
-def simm8_unsigned : RISCVOp, TImmLeaf<XLenVT, "return isInt<8>(Imm);"> {
+def simm8_unsigned : RISCVOp, ImmLeaf<XLenVT, "return isInt<8>(Imm);"> {
   let ParserMatchClass = SImm8UnsignedAsmOperand;
   let EncoderMethod = "getImmOpValue";
   let DecoderMethod = "decodeSImmOperand<8>";
@@ -1463,10 +1463,6 @@ let Predicates = [HasStdExtP, IsRV32] in {
 
 def riscv_absw : RVSDNode<"ABSW", SDTIntUnaryOp>;
 
-def SDT_RISCVPLI : SDTypeProfile<1, 1, [SDTCisVec<0>,
-                                        SDTCisInt<0>,
-                                        SDTCisInt<1>]>;
-def riscv_pli : RVSDNode<"PLI", SDT_RISCVPLI>;
 def SDT_RISCVPASUB : SDTypeProfile<1, 2, [SDTCisVec<0>,
                                           SDTCisInt<0>,
                                           SDTCisSameAs<0, 1>,
@@ -1519,10 +1515,13 @@ let Predicates = [HasStdExtP] in {
   
 
   // 8-bit PLI SD node pattern
-  def: Pat<(XLenVecI8VT (riscv_pli simm8_unsigned:$imm8)), (PLI_B simm8_unsigned:$imm8)>;
+  def: Pat<(XLenVecI8VT (splat_vector simm8_unsigned:$imm8)), (PLI_B simm8_unsigned:$imm8)>;
   // 16-bit PLI SD node pattern
-  def: Pat<(XLenVecI16VT (riscv_pli simm10:$imm10)), (PLI_H simm10:$imm10)>;
+  def: Pat<(XLenVecI16VT (splat_vector simm10:$imm10)), (PLI_H simm10:$imm10)>;
 
+  // splat pattern
+  def: Pat<(XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))), (PADD_BS (XLenVT X0), GPR:$rs2)>;
+  def: Pat<(XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))), (PADD_HS (XLenVT X0), GPR:$rs2)>;
 } // Predicates = [HasStdExtP]
 
 let Predicates = [HasStdExtP, IsRV32] in {
@@ -1537,7 +1536,7 @@ let Predicates = [HasStdExtP, IsRV64] in {
   def : PatGpr<riscv_absw, ABSW>;
 
   // 32-bit PLI SD node pattern
-  def: Pat<(v2i32 (riscv_pli simm10:$imm10)), (PLI_W simm10:$imm10)>;
+  def: Pat<(v2i32 (splat_vector simm10:$imm10)), (PLI_W simm10:$imm10)>;
 
   // Basic 32-bit arithmetic patterns
   def: Pat<(v2i32 (add GPR:$rs1, GPR:$rs2)), (PADD_W GPR:$rs1, GPR:$rs2)>;
@@ -1557,6 +1556,9 @@ let Predicates = [HasStdExtP, IsRV64] in {
   def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>;
   def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>;
 
+  // splat pattern
+  def: Pat<(v2i32 (splat_vector (XLenVT GPR:$rs2))), (PADD_WS (XLenVT X0), GPR:$rs2)>;
+
   // Load/Store patterns
   def : StPat<store, SD, GPR, v8i8>;
   def : StPat<store, SD, GPR, v4i16>;

diff  --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index 46d5e9f9a538f..bb3e691311cd8 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -496,6 +496,33 @@ define void @test_extract_vector_8(ptr %ret_ptr, ptr %a_ptr) {
   ret void
 }
 
+; Test for splat
+define void @test_non_const_splat_i8(ptr %ret_ptr, ptr %a_ptr, i8 %elt) {
+; CHECK-LABEL: test_non_const_splat_i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    padd.bs a1, zero, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i8>, ptr %a_ptr
+  %insert = insertelement <4 x i8> poison, i8 %elt, i32 0
+  %splat = shufflevector <4 x i8> %insert, <4 x i8> poison, <4 x i32> zeroinitializer
+  store <4 x i8> %splat, ptr %ret_ptr
+  ret void
+}
+
+define void @test_non_const_splat_i16(ptr %ret_ptr, ptr %a_ptr, i16 %elt) {
+; CHECK-LABEL: test_non_const_splat_i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    padd.hs a1, zero, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %insert = insertelement <2 x i16> poison, i16 %elt, i32 0
+  %splat = shufflevector <2 x i16> %insert, <2 x i16> poison, <2 x i32> zeroinitializer
+  store <2 x i16> %splat, ptr %ret_ptr
+  ret void
+}
+
 ; Intrinsic declarations
 declare <2 x i16> @llvm.sadd.sat.v2i16(<2 x i16>, <2 x i16>)
 declare <2 x i16> @llvm.uadd.sat.v2i16(<2 x i16>, <2 x i16>)

diff  --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index 353039e9482e9..f989b025a12dc 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -671,6 +671,20 @@ define void @test_pasubu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
   ret void
 }
 
+; Test for splat
+define void @test_non_const_splat_i32(ptr %ret_ptr, ptr %a_ptr, i32 %elt) {
+; CHECK-LABEL: test_non_const_splat_i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    padd.ws a1, zero, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %insert = insertelement <2 x i32> poison, i32 %elt, i32 0
+  %splat = shufflevector <2 x i32> %insert, <2 x i32> poison, <2 x i32> zeroinitializer
+  store <2 x i32> %splat, ptr %ret_ptr
+  ret void
+}
+
 ; Intrinsic declarations
 declare <4 x i16> @llvm.sadd.sat.v4i16(<4 x i16>, <4 x i16>)
 declare <4 x i16> @llvm.uadd.sat.v4i16(<4 x i16>, <4 x i16>)


        


More information about the llvm-commits mailing list