[llvm] [RISCV] Handle zeroinitializer of vector tuple Type (PR #113995)
Brandon Wu via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 18 23:48:44 PST 2024
https://github.com/4vtomat updated https://github.com/llvm/llvm-project/pull/113995
>From 58efc9a37776426bea580b2231b00fb2b668f5b7 Mon Sep 17 00:00:00 2001
From: Brandon Wu <brandon.wu at sifive.com>
Date: Tue, 29 Oct 2024 06:57:57 +0800
Subject: [PATCH 1/2] [RISCV] Handle zeroinitializer of vector tuple Type
It doesn't make sense to add a new generic ISD node just to handle the
RISC-V vector tuple type. Instead, we use `SPLAT_VECTOR` as the ISD node
and further lower it to `VMV_V_X_VL`.
Note: if a `visitSPLAT_VECTOR` is ever added to the generic DAG combiner,
it will need to skip the RISC-V vector tuple type.
---
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 3 ++
.../SelectionDAG/SelectionDAGBuilder.cpp | 7 ++++
llvm/lib/IR/Type.cpp | 3 +-
llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp | 19 +++++++++
.../RISCV/vector-tuple-zeroinitializer.ll | 40 +++++++++++++++++++
5 files changed, 71 insertions(+), 1 deletion(-)
create mode 100644 llvm/test/CodeGen/RISCV/vector-tuple-zeroinitializer.ll
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3a8ec3c6105bc0..9d9f71aab3fde9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6434,6 +6434,9 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
return getNode(ISD::VECREDUCE_AND, DL, VT, N1);
break;
case ISD::SPLAT_VECTOR:
+ // RISC-V vector tuple type is not a vector type.
+ if (VT.isRISCVVectorTuple())
+ break;
assert(VT.isVector() && "Wrong return type!");
// FIXME: Hexagon uses i32 scalar for a floating point zero vector so allow
// that for now.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 9d729d448502d8..e69c490cb7ef43 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -1896,6 +1896,13 @@ SDValue SelectionDAGBuilder::getValueImpl(const Value *V) {
DAG.getConstant(0, getCurSDLoc(), MVT::nxv16i1));
}
+ if (VT.isRISCVVectorTuple()) {
+ assert(C->isNullValue() && "Can only zero this target type!");
+ return NodeMap[V] = DAG.getNode(
+ ISD::SPLAT_VECTOR, getCurSDLoc(), VT,
+ DAG.getConstant(0, getCurSDLoc(), MVT::getIntegerVT(8)));
+ }
+
VectorType *VecTy = cast<VectorType>(V->getType());
// Now that we know the number and type of the elements, get that number of
diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp
index 88ede0d35fa3ee..bb1cb077eb9f6d 100644
--- a/llvm/lib/IR/Type.cpp
+++ b/llvm/lib/IR/Type.cpp
@@ -898,7 +898,8 @@ static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) {
RISCV::RVVBitsPerBlock / 8) *
Ty->getIntParameter(0);
return TargetTypeInfo(
- ScalableVectorType::get(Type::getInt8Ty(C), TotalNumElts));
+ ScalableVectorType::get(Type::getInt8Ty(C), TotalNumElts),
+ TargetExtType::HasZeroInit);
}
// DirectX resources
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index ca368a18c80d64..05cde42770084a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -65,6 +65,25 @@ void RISCVDAGToDAGISel::PreprocessISelDAG() {
VT.isInteger() ? RISCVISD::VMV_V_X_VL : RISCVISD::VFMV_V_F_VL;
SDLoc DL(N);
SDValue VL = CurDAG->getRegister(RISCV::X0, Subtarget->getXLenVT());
+
+ if (VT.isRISCVVectorTuple()) {
+ unsigned NF = VT.getRISCVVectorTupleNumFields();
+ unsigned NumScalElts = VT.getSizeInBits() / (NF * 8);
+ SDValue EltVal = CurDAG->getConstant(0, DL, Subtarget->getXLenVT());
+ MVT ScalTy =
+ MVT::getScalableVectorVT(MVT::getIntegerVT(8), NumScalElts);
+
+ SDValue Splat = CurDAG->getNode(RISCVISD::VMV_V_X_VL, DL, ScalTy,
+ CurDAG->getUNDEF(ScalTy), EltVal, VL);
+
+ Result = CurDAG->getUNDEF(VT);
+ for (unsigned i = 0; i < NF; ++i)
+ Result = CurDAG->getNode(RISCVISD::TUPLE_INSERT, DL, VT, Result,
+ Splat, CurDAG->getVectorIdxConstant(i, DL));
+
+ break;
+ }
+
SDValue Src = N->getOperand(0);
if (VT.isInteger())
Src = CurDAG->getNode(ISD::ANY_EXTEND, DL, Subtarget->getXLenVT(),
diff --git a/llvm/test/CodeGen/RISCV/vector-tuple-zeroinitializer.ll b/llvm/test/CodeGen/RISCV/vector-tuple-zeroinitializer.ll
new file mode 100644
index 00000000000000..88e1315d560b07
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/vector-tuple-zeroinitializer.ll
@@ -0,0 +1,40 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: sed 's/iXLen/i32/g' %s | llc -mtriple=riscv32 -mattr=+v \
+; RUN: -verify-machineinstrs | FileCheck %s --check-prefixes=CHECK
+; RUN: sed 's/iXLen/i64/g' %s | llc -mtriple=riscv64 -mattr=+v \
+; RUN: -verify-machineinstrs | FileCheck %s --check-prefixes=CHECK
+
+define target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @test_tuple_zero0() {
+; CHECK-LABEL: test_tuple_zero0:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; CHECK-NEXT: vmv.v.i v8, 0
+; CHECK-NEXT: vmv.v.i v10, 0
+; CHECK-NEXT: ret
+entry:
+ ret target("riscv.vector.tuple", <vscale x 16 x i8>, 2) zeroinitializer
+}
+
+define target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @test_tuple_zero1(<vscale x 4 x i32> %a) {
+; CHECK-LABEL: test_tuple_zero1:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; CHECK-NEXT: vmv.v.i v10, 0
+; CHECK-NEXT: ret
+entry:
+ %1 = call target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @llvm.riscv.tuple.insert.triscv.vector.tuple_nxv16i8_2t.nxv4i32(target("riscv.vector.tuple", <vscale x 16 x i8>, 2) zeroinitializer, <vscale x 4 x i32> %a, i32 0)
+ ret target("riscv.vector.tuple", <vscale x 16 x i8>, 2) %1
+}
+
+define target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @test_tuple_zero2(<vscale x 4 x i32> %a) {
+; CHECK-LABEL: test_tuple_zero2:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; CHECK-NEXT: vmv.v.i v6, 0
+; CHECK-NEXT: vmv2r.v v10, v8
+; CHECK-NEXT: vmv2r.v v8, v6
+; CHECK-NEXT: ret
+entry:
+ %1 = call target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @llvm.riscv.tuple.insert.triscv.vector.tuple_nxv16i8_2t.nxv4i32(target("riscv.vector.tuple", <vscale x 16 x i8>, 2) zeroinitializer, <vscale x 4 x i32> %a, i32 1)
+ ret target("riscv.vector.tuple", <vscale x 16 x i8>, 2) %1
+}
>From dbe42eb49a8889eee6e6009197e9e0ce6b0bd11f Mon Sep 17 00:00:00 2001
From: Brandon Wu <brandon.wu at sifive.com>
Date: Mon, 18 Nov 2024 23:48:03 -0800
Subject: [PATCH 2/2] fixup! [RISCV] Handle zeroinitializer of vector tuple
Type
---
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 3 ---
.../SelectionDAG/SelectionDAGBuilder.cpp | 8 +++++--
llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp | 24 +++++++++++--------
3 files changed, 20 insertions(+), 15 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 9d9f71aab3fde9..3a8ec3c6105bc0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6434,9 +6434,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
return getNode(ISD::VECREDUCE_AND, DL, VT, N1);
break;
case ISD::SPLAT_VECTOR:
- // RISC-V vector tuple type is not a vector type.
- if (VT.isRISCVVectorTuple())
- break;
assert(VT.isVector() && "Wrong return type!");
// FIXME: Hexagon uses i32 scalar for a floating point zero vector so allow
// that for now.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index e69c490cb7ef43..a4e67284bc77f2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -1899,8 +1899,12 @@ SDValue SelectionDAGBuilder::getValueImpl(const Value *V) {
if (VT.isRISCVVectorTuple()) {
assert(C->isNullValue() && "Can only zero this target type!");
return NodeMap[V] = DAG.getNode(
- ISD::SPLAT_VECTOR, getCurSDLoc(), VT,
- DAG.getConstant(0, getCurSDLoc(), MVT::getIntegerVT(8)));
+ ISD::BITCAST, getCurSDLoc(), VT,
+ DAG.getNode(
+ ISD::SPLAT_VECTOR, getCurSDLoc(),
+ MVT::getScalableVectorVT(
+ MVT::i8, VT.getSizeInBits().getKnownMinValue() / 8),
+ DAG.getConstant(0, getCurSDLoc(), MVT::getIntegerVT(8))));
}
VectorType *VecTy = cast<VectorType>(V->getType());
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 05cde42770084a..526126c715d4f6 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -57,18 +57,14 @@ void RISCVDAGToDAGISel::PreprocessISelDAG() {
SDValue Result;
switch (N->getOpcode()) {
- case ISD::SPLAT_VECTOR: {
- // Convert integer SPLAT_VECTOR to VMV_V_X_VL and floating-point
- // SPLAT_VECTOR to VFMV_V_F_VL to reduce isel burden.
+ case ISD::BITCAST: {
MVT VT = N->getSimpleValueType(0);
- unsigned Opc =
- VT.isInteger() ? RISCVISD::VMV_V_X_VL : RISCVISD::VFMV_V_F_VL;
SDLoc DL(N);
SDValue VL = CurDAG->getRegister(RISCV::X0, Subtarget->getXLenVT());
-
- if (VT.isRISCVVectorTuple()) {
+ if (VT.isRISCVVectorTuple() &&
+ N->getOperand(0)->getOpcode() == ISD::SPLAT_VECTOR) {
unsigned NF = VT.getRISCVVectorTupleNumFields();
- unsigned NumScalElts = VT.getSizeInBits() / (NF * 8);
+ unsigned NumScalElts = VT.getSizeInBits().getKnownMinValue() / (NF * 8);
SDValue EltVal = CurDAG->getConstant(0, DL, Subtarget->getXLenVT());
MVT ScalTy =
MVT::getScalableVectorVT(MVT::getIntegerVT(8), NumScalElts);
@@ -80,9 +76,17 @@ void RISCVDAGToDAGISel::PreprocessISelDAG() {
for (unsigned i = 0; i < NF; ++i)
Result = CurDAG->getNode(RISCVISD::TUPLE_INSERT, DL, VT, Result,
Splat, CurDAG->getVectorIdxConstant(i, DL));
-
- break;
}
+ break;
+ }
+ case ISD::SPLAT_VECTOR: {
+ // Convert integer SPLAT_VECTOR to VMV_V_X_VL and floating-point
+ // SPLAT_VECTOR to VFMV_V_F_VL to reduce isel burden.
+ MVT VT = N->getSimpleValueType(0);
+ unsigned Opc =
+ VT.isInteger() ? RISCVISD::VMV_V_X_VL : RISCVISD::VFMV_V_F_VL;
+ SDLoc DL(N);
+ SDValue VL = CurDAG->getRegister(RISCV::X0, Subtarget->getXLenVT());
SDValue Src = N->getOperand(0);
if (VT.isInteger())
More information about the llvm-commits
mailing list