[llvm] [RISCV] Handle zeroinitializer of vector tuple Type (PR #113995)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 28 20:32:02 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-ir

Author: Brandon Wu (4vtomat)

<details>
<summary>Changes</summary>

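It doesn't make sense to add a new generic ISD node just to handle the RISC-V
vector tuple type. Instead, we use `ISD::SPLAT_VECTOR` and further lower it to a
`RISCVISD::VMV_V_X_VL` splat that is inserted into each field of the tuple with
`RISCVISD::TUPLE_INSERT`.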

Note: if the generic DAG combiner has (or later gains) a `visitSPLAT_VECTOR`
combine, it must skip RISC-V vector tuple types.


---
Full diff: https://github.com/llvm/llvm-project/pull/113995.diff


5 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+3) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+7) 
- (modified) llvm/lib/IR/Type.cpp (+2-1) 
- (modified) llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp (+19) 
- (added) llvm/test/CodeGen/RISCV/vector-tuple-zeroinitializer.ll (+40) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 1a86b3b51234d1..0ba0055f99c094 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6288,6 +6288,9 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
       return getNode(ISD::VECREDUCE_AND, DL, VT, N1);
     break;
   case ISD::SPLAT_VECTOR:
+    // RISC-V vector tuple type is not a vector type.
+    if (VT.isRISCVVectorTuple())
+      break;
     assert(VT.isVector() && "Wrong return type!");
     // FIXME: Hexagon uses i32 scalar for a floating point zero vector so allow
     // that for now.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 8450553743074c..ca63b289ee20cb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -1900,6 +1900,13 @@ SDValue SelectionDAGBuilder::getValueImpl(const Value *V) {
                          DAG.getConstant(0, getCurSDLoc(), MVT::nxv16i1));
     }
 
+    if (VT.isRISCVVectorTuple()) {
+      assert(C->isNullValue() && "Can only zero this target type!");
+      return NodeMap[V] = DAG.getNode(
+                 ISD::SPLAT_VECTOR, getCurSDLoc(), VT,
+                 DAG.getConstant(0, getCurSDLoc(), MVT::getIntegerVT(8)));
+    }
+
     VectorType *VecTy = cast<VectorType>(V->getType());
 
     // Now that we know the number and type of the elements, get that number of
diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp
index e311cde415174a..e5fe85dbd18e39 100644
--- a/llvm/lib/IR/Type.cpp
+++ b/llvm/lib/IR/Type.cpp
@@ -880,7 +880,8 @@ static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) {
                  RISCV::RVVBitsPerBlock / 8) *
         Ty->getIntParameter(0);
     return TargetTypeInfo(
-        ScalableVectorType::get(Type::getInt8Ty(C), TotalNumElts));
+        ScalableVectorType::get(Type::getInt8Ty(C), TotalNumElts),
+        TargetExtType::HasZeroInit);
   }
 
   // DirectX resources
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index dc3f8254cb4e00..998f1e2188ff0b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -66,6 +66,25 @@ void RISCVDAGToDAGISel::PreprocessISelDAG() {
           VT.isInteger() ? RISCVISD::VMV_V_X_VL : RISCVISD::VFMV_V_F_VL;
       SDLoc DL(N);
       SDValue VL = CurDAG->getRegister(RISCV::X0, Subtarget->getXLenVT());
+
+      if (VT.isRISCVVectorTuple()) {
+        unsigned NF = VT.getRISCVVectorTupleNumFields();
+        unsigned NumScalElts = VT.getSizeInBits() / (NF * 8);
+        SDValue EltVal = CurDAG->getConstant(0, DL, Subtarget->getXLenVT());
+        MVT ScalTy =
+            MVT::getScalableVectorVT(MVT::getIntegerVT(8), NumScalElts);
+
+        SDValue Splat = CurDAG->getNode(RISCVISD::VMV_V_X_VL, DL, ScalTy,
+                                        CurDAG->getUNDEF(ScalTy), EltVal, VL);
+
+        Result = CurDAG->getUNDEF(VT);
+        for (unsigned i = 0; i < NF; ++i)
+          Result = CurDAG->getNode(RISCVISD::TUPLE_INSERT, DL, VT, Result,
+                                   Splat, CurDAG->getVectorIdxConstant(i, DL));
+
+        break;
+      }
+
       SDValue Src = N->getOperand(0);
       if (VT.isInteger())
         Src = CurDAG->getNode(ISD::ANY_EXTEND, DL, Subtarget->getXLenVT(),
diff --git a/llvm/test/CodeGen/RISCV/vector-tuple-zeroinitializer.ll b/llvm/test/CodeGen/RISCV/vector-tuple-zeroinitializer.ll
new file mode 100644
index 00000000000000..88e1315d560b07
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/vector-tuple-zeroinitializer.ll
@@ -0,0 +1,40 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: sed 's/iXLen/i32/g' %s | llc -mtriple=riscv32 -mattr=+v \
+; RUN:   -verify-machineinstrs | FileCheck %s --check-prefixes=CHECK
+; RUN: sed 's/iXLen/i64/g' %s | llc -mtriple=riscv64 -mattr=+v \
+; RUN:   -verify-machineinstrs | FileCheck %s --check-prefixes=CHECK
+
+define target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @test_tuple_zero0() {
+; CHECK-LABEL: test_tuple_zero0:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    vmv.v.i v10, 0
+; CHECK-NEXT:    ret
+entry:
+  ret target("riscv.vector.tuple", <vscale x 16 x i8>, 2) zeroinitializer
+}
+
+define target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @test_tuple_zero1(<vscale x 4 x i32> %a) {
+; CHECK-LABEL: test_tuple_zero1:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
+; CHECK-NEXT:    vmv.v.i v10, 0
+; CHECK-NEXT:    ret
+entry:
+  %1 = call target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @llvm.riscv.tuple.insert.triscv.vector.tuple_nxv16i8_2t.nxv4i32(target("riscv.vector.tuple", <vscale x 16 x i8>, 2) zeroinitializer, <vscale x 4 x i32> %a, i32 0)
+  ret target("riscv.vector.tuple", <vscale x 16 x i8>, 2) %1
+}
+
+define target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @test_tuple_zero2(<vscale x 4 x i32> %a) {
+; CHECK-LABEL: test_tuple_zero2:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
+; CHECK-NEXT:    vmv.v.i v6, 0
+; CHECK-NEXT:    vmv2r.v v10, v8
+; CHECK-NEXT:    vmv2r.v v8, v6
+; CHECK-NEXT:    ret
+entry:
+  %1 = call target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @llvm.riscv.tuple.insert.triscv.vector.tuple_nxv16i8_2t.nxv4i32(target("riscv.vector.tuple", <vscale x 16 x i8>, 2) zeroinitializer, <vscale x 4 x i32> %a, i32 1)
+  ret target("riscv.vector.tuple", <vscale x 16 x i8>, 2) %1
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/113995


More information about the llvm-commits mailing list