[llvm] [RISCV][llvm] Select splat_vector(constant) with PLI (PR #168204)

Brandon Wu via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 17 00:02:49 PST 2025


https://github.com/4vtomat updated https://github.com/llvm/llvm-project/pull/168204

>From 142db52bbbc33f608317dbd84a1650f2a0b356b9 Mon Sep 17 00:00:00 2001
From: Brandon Wu <songwu0813 at gmail.com>
Date: Sat, 15 Nov 2025 04:18:35 -0800
Subject: [PATCH 1/2] [RISCV][llvm] Select splat_vector(constant) with PLI

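The default DAG combiner turns a BUILD_VECTOR whose elements are all the
same into a SPLAT_VECTOR once SPLAT_VECTOR is legal for the type, so mark
it legal for the P-extension vector types and select constant splats
directly to PLI when the immediate fits.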
---
 llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp |  2 ++
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 33 +--------------------
 llvm/lib/Target/RISCV/RISCVInstrInfoP.td    | 14 ++++-----
 3 files changed, 8 insertions(+), 41 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 1024e55f912c7..5025122db3681 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -51,6 +51,12 @@ void RISCVDAGToDAGISel::PreprocessISelDAG() {
     SDValue Result;
     switch (N->getOpcode()) {
     case ISD::SPLAT_VECTOR: {
+      // P-extension packed SIMD vectors are fixed-length; their splats are
+      // selected by the splat_vector patterns (PLI_*, PADD_*S). Only skip
+      // fixed-length types so scalable RVV splats are still rewritten below.
+      if (Subtarget->enablePExtCodeGen() &&
+          N->getValueType(0).isFixedLengthVector())
+        break;
       // Convert integer SPLAT_VECTOR to VMV_V_X_VL and floating-point
       // SPLAT_VECTOR to VFMV_V_F_VL to reduce isel burden.
       MVT VT = N->getSimpleValueType(0);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 38cce26e44af4..37c8d7c045443 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -525,7 +525,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::SSUBSAT, VTs, Legal);
     setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VTs, Legal);
     setOperationAction({ISD::ABDS, ISD::ABDU}, VTs, Legal);
-    setOperationAction(ISD::BUILD_VECTOR, VTs, Custom);
+    // Legal SPLAT_VECTOR lets the DAG combiner fold splat BUILD_VECTORs into
+    // it; those are then selected by splat_vector patterns (PLI_*, PADD_*S).
+    setOperationAction(ISD::SPLAT_VECTOR, VTs, Legal);
     setOperationAction(ISD::BITCAST, VTs, Custom);
     setOperationAction(ISD::EXTRACT_VECTOR_ELT, VTs, Custom);
   }
@@ -4437,37 +4439,6 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
   MVT XLenVT = Subtarget.getXLenVT();
 
   SDLoc DL(Op);
-  // Handle P extension packed vector BUILD_VECTOR with PLI for splat constants
-  if (Subtarget.enablePExtCodeGen()) {
-    bool IsPExtVector =
-        (VT == MVT::v2i16 || VT == MVT::v4i8) ||
-        (Subtarget.is64Bit() &&
-         (VT == MVT::v4i16 || VT == MVT::v8i8 || VT == MVT::v2i32));
-    if (IsPExtVector) {
-      if (SDValue SplatValue = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
-        if (auto *C = dyn_cast<ConstantSDNode>(SplatValue)) {
-          int64_t SplatImm = C->getSExtValue();
-          bool IsValidImm = false;
-
-          // Check immediate range based on vector type
-          if (VT == MVT::v8i8 || VT == MVT::v4i8) {
-            // PLI_B uses 8-bit unsigned or unsigned immediate
-            IsValidImm = isUInt<8>(SplatImm) || isInt<8>(SplatImm);
-            if (isUInt<8>(SplatImm))
-              SplatImm = (int8_t)SplatImm;
-          } else {
-            // PLI_H and PLI_W use 10-bit signed immediate
-            IsValidImm = isInt<10>(SplatImm);
-          }
-
-          if (IsValidImm) {
-            SDValue Imm = DAG.getSignedTargetConstant(SplatImm, DL, XLenVT);
-            return DAG.getNode(RISCVISD::PLI, DL, VT, Imm);
-          }
-        }
-      }
-    }
-  }
 
   // Proper support for f16 requires Zvfh. bf16 always requires special
   // handling. We need to cast the scalar to integer and create an integer
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 126a39996c741..2f289f89e8859 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -18,7 +18,7 @@
 // Operand and SDNode transformation definitions.
 //===----------------------------------------------------------------------===//
 
-def simm10 : RISCVSImmOp<10>, TImmLeaf<XLenVT, "return isInt<10>(Imm);">;
+def simm10 : RISCVSImmOp<10>, ImmLeaf<XLenVT, "return isInt<10>(Imm);">;
 
 def SImm8UnsignedAsmOperand : SImmAsmOperand<8, "Unsigned"> {
   let RenderMethod = "addSImm8UnsignedOperands";
@@ -26,7 +26,7 @@ def SImm8UnsignedAsmOperand : SImmAsmOperand<8, "Unsigned"> {
 
 // A 8-bit signed immediate allowing range [-128, 255]
 // but represented as [-128, 127].
-def simm8_unsigned : RISCVOp, TImmLeaf<XLenVT, "return isInt<8>(Imm);"> {
+def simm8_unsigned : RISCVOp, ImmLeaf<XLenVT, "return isInt<8>(Imm);"> {
   let ParserMatchClass = SImm8UnsignedAsmOperand;
   let EncoderMethod = "getImmOpValue";
   let DecoderMethod = "decodeSImmOperand<8>";
@@ -1463,10 +1463,6 @@ let Predicates = [HasStdExtP, IsRV32] in {
 
 def riscv_absw : RVSDNode<"ABSW", SDTIntUnaryOp>;
 
-def SDT_RISCVPLI : SDTypeProfile<1, 1, [SDTCisVec<0>,
-                                        SDTCisInt<0>,
-                                        SDTCisInt<1>]>;
-def riscv_pli : RVSDNode<"PLI", SDT_RISCVPLI>;
 def SDT_RISCVPASUB : SDTypeProfile<1, 2, [SDTCisVec<0>,
                                           SDTCisInt<0>,
                                           SDTCisSameAs<0, 1>,
@@ -1519,9 +1515,9 @@ let Predicates = [HasStdExtP] in {
   
 
   // 8-bit PLI SD node pattern
-  def: Pat<(XLenVecI8VT (riscv_pli simm8_unsigned:$imm8)), (PLI_B simm8_unsigned:$imm8)>;
+  def: Pat<(XLenVecI8VT (splat_vector simm8_unsigned:$imm8)), (PLI_B simm8_unsigned:$imm8)>;
   // 16-bit PLI SD node pattern
-  def: Pat<(XLenVecI16VT (riscv_pli simm10:$imm10)), (PLI_H simm10:$imm10)>;
+  def: Pat<(XLenVecI16VT (splat_vector simm10:$imm10)), (PLI_H simm10:$imm10)>;
 
 } // Predicates = [HasStdExtP]
 
@@ -1537,7 +1533,7 @@ let Predicates = [HasStdExtP, IsRV64] in {
   def : PatGpr<riscv_absw, ABSW>;
 
   // 32-bit PLI SD node pattern
-  def: Pat<(v2i32 (riscv_pli simm10:$imm10)), (PLI_W simm10:$imm10)>;
+  def: Pat<(v2i32 (splat_vector simm10:$imm10)), (PLI_W simm10:$imm10)>;
 
   // Basic 32-bit arithmetic patterns
   def: Pat<(v2i32 (add GPR:$rs1, GPR:$rs2)), (PADD_W GPR:$rs1, GPR:$rs2)>;

>From b1843c75799188d754ca56796ecb83c72b48cb13 Mon Sep 17 00:00:00 2001
From: Brandon Wu <songwu0813 at gmail.com>
Date: Mon, 17 Nov 2025 00:02:36 -0800
Subject: [PATCH 2/2] fixup! Add non-const splat pattern

---
 llvm/lib/Target/RISCV/RISCVInstrInfoP.td |  6 ++++++
 llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll  | 27 ++++++++++++++++++++++++
 llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll  | 14 ++++++++++++
 3 files changed, 47 insertions(+)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 2f289f89e8859..764e3c9c58355 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1519,6 +1519,9 @@ let Predicates = [HasStdExtP] in {
   // 16-bit PLI SD node pattern
   def: Pat<(XLenVecI16VT (splat_vector simm10:$imm10)), (PLI_H simm10:$imm10)>;
 
+  // Non-constant splats: PADD_*S adds the GPR scalar to each lane of x0.
+  def: Pat<(XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))), (PADD_BS (XLenVT X0), GPR:$rs2)>;
+  def: Pat<(XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))), (PADD_HS (XLenVT X0), GPR:$rs2)>;
 } // Predicates = [HasStdExtP]
 
 let Predicates = [HasStdExtP, IsRV32] in {
@@ -1553,6 +1556,9 @@ let Predicates = [HasStdExtP, IsRV64] in {
   def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>;
   def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>;
 
+  // Non-constant splat: PADD_WS adds the GPR scalar to each lane of x0.
+  def: Pat<(v2i32 (splat_vector (XLenVT GPR:$rs2))), (PADD_WS (XLenVT X0), GPR:$rs2)>;
+
   // Load/Store patterns
   def : StPat<store, SD, GPR, v8i8>;
   def : StPat<store, SD, GPR, v4i16>;
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index 46d5e9f9a538f..bb3e691311cd8 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -496,6 +496,31 @@ define void @test_extract_vector_8(ptr %ret_ptr, ptr %a_ptr) {
   ret void
 }
 
+; Non-constant splats should be selected to padd.{b,h}s rd, zero, rs.
+define void @test_non_const_splat_i8(ptr %ret_ptr, ptr %a_ptr, i8 %elt) {
+; CHECK-LABEL: test_non_const_splat_i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    padd.bs a1, zero, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %insert = insertelement <4 x i8> poison, i8 %elt, i32 0
+  %splat = shufflevector <4 x i8> %insert, <4 x i8> poison, <4 x i32> zeroinitializer
+  store <4 x i8> %splat, ptr %ret_ptr
+  ret void
+}
+
+define void @test_non_const_splat_i16(ptr %ret_ptr, ptr %a_ptr, i16 %elt) {
+; CHECK-LABEL: test_non_const_splat_i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    padd.hs a1, zero, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %insert = insertelement <2 x i16> poison, i16 %elt, i32 0
+  %splat = shufflevector <2 x i16> %insert, <2 x i16> poison, <2 x i32> zeroinitializer
+  store <2 x i16> %splat, ptr %ret_ptr
+  ret void
+}
+
 ; Intrinsic declarations
 declare <2 x i16> @llvm.sadd.sat.v2i16(<2 x i16>, <2 x i16>)
 declare <2 x i16> @llvm.uadd.sat.v2i16(<2 x i16>, <2 x i16>)
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index 353039e9482e9..f989b025a12dc 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -671,6 +671,19 @@ define void @test_pasubu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
   ret void
 }
 
+; Non-constant splat should be selected to padd.ws rd, zero, rs.
+define void @test_non_const_splat_i32(ptr %ret_ptr, ptr %a_ptr, i32 %elt) {
+; CHECK-LABEL: test_non_const_splat_i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    padd.ws a1, zero, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %insert = insertelement <2 x i32> poison, i32 %elt, i32 0
+  %splat = shufflevector <2 x i32> %insert, <2 x i32> poison, <2 x i32> zeroinitializer
+  store <2 x i32> %splat, ptr %ret_ptr
+  ret void
+}
+
 ; Intrinsic declarations
 declare <4 x i16> @llvm.sadd.sat.v4i16(<4 x i16>, <4 x i16>)
 declare <4 x i16> @llvm.uadd.sat.v4i16(<4 x i16>, <4 x i16>)



More information about the llvm-commits mailing list