[llvm] [NVPTX] Fix crash caused by ComputePTXValueVTs (PR #104524)
Justin Fargnoli via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 15 16:38:51 PDT 2024
https://github.com/justinfargnoli created https://github.com/llvm/llvm-project/pull/104524
When [lowering return values](https://github.com/llvm/llvm-project/blob/99a10f1fe8a7e4b0fdb4c6dd5e7f24f87e0d3695/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L3422) from LLVM IR to SelectionDAG, we check that [the number of values `SelectionDAG` tells us to return is equal to the number of values that `ComputePTXValueVTs()` tells us to return](https://github.com/llvm/llvm-project/blob/99a10f1fe8a7e4b0fdb4c6dd5e7f24f87e0d3695/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L3441). However, this check can fail on valid IR. For example:
```
define <6 x half> @foo() {
ret <6 x half> zeroinitializer
}
```
`ComputePTXValueVTs()` tells us to return ***3*** `v2f16` values, while `SelectionDAG` tells us to return ***6*** `f16` values.
`ComputePTXValueVTs()` [supports all `half` element vectors with an even number of elements](https://github.com/llvm/llvm-project/blob/99a10f1fe8a7e4b0fdb4c6dd5e7f24f87e0d3695/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L213), whereas `SelectionDAG` [only supports power-of-2 sized vectors](https://github.com/llvm/llvm-project/blob/4e078e3797098daa40d254447c499bcf61415308/llvm/lib/CodeGen/TargetLoweringBase.cpp#L1580).
Assuming that the developers who added the code to `ComputePTXValueVTs()` overlooked this, I've restricted `ComputePTXValueVTs()` to compute the same number of return values as `SelectionDAG` instead of extending `SelectionDAG` to support non-power-of-2 vector sizes.
From 0c7cdb6aaec65c9c2e75cef7bd10988a4813ed81 Mon Sep 17 00:00:00 2001
From: Justin Fargnoli <jfargnoli at nvidia.com>
Date: Thu, 15 Aug 2024 13:27:46 -0700
Subject: [PATCH 1/3] [NVPTX] Fix crash caused by diff between
ComputePTXValueVTs and SelectionDAG
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 10 ++---
.../CodeGen/NVPTX/compute-ptx-value-vts.ll | 37 +++++++++++++++++++
2 files changed, 42 insertions(+), 5 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 43a3fbf4d1306a..81d5fcde8e4617 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -207,10 +207,10 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
if (VT.isVector()) {
unsigned NumElts = VT.getVectorNumElements();
EVT EltVT = VT.getVectorElementType();
- // Vectors with an even number of f16 elements will be passed to
- // us as an array of v2f16/v2bf16 elements. We must match this so we
- // stay in sync with Ins/Outs.
- if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0) {
+ if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 && isPowerOf2_32(NumElts)) {
+ // Vectors with an even number of f16 elements will be passed to
+ // us as an array of v2f16/v2bf16 elements. We must match this so we
+ // stay in sync with Ins/Outs.
switch (EltVT.getSimpleVT().SimpleTy) {
case MVT::f16:
EltVT = MVT::v2f16;
@@ -226,7 +226,7 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
}
NumElts /= 2;
} else if (EltVT.getSimpleVT() == MVT::i8 &&
- (NumElts % 4 == 0 || NumElts == 3)) {
+ ((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) || NumElts == 3)) {
// v*i8 are formally lowered as v4i8
EltVT = MVT::v4i8;
NumElts = (NumElts + 3) / 4;
diff --git a/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
new file mode 100644
index 00000000000000..08960ee2aad373
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
@@ -0,0 +1,37 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_20
+
+define <6 x half> @half6() {
+ ret <6 x half> zeroinitializer
+}
+
+define <10 x half> @half10() {
+ ret <10 x half> zeroinitializer
+}
+
+define <14 x half> @half14() {
+ ret <14 x half> zeroinitializer
+}
+
+define <18 x half> @half18() {
+ ret <18 x half> zeroinitializer
+}
+
+define <12 x i8> @byte12() {
+ ret <12 x i8> zeroinitializer
+}
+
+define <20 x i8> @byte20() {
+ ret <20 x i8> zeroinitializer
+}
+
+define <24 x i8> @byte24() {
+ ret <24 x i8> zeroinitializer
+}
+
+define <28 x i8> @byte28() {
+ ret <28 x i8> zeroinitializer
+}
+
+define <36 x i8> @byte36() {
+ ret <36 x i8> zeroinitializer
+}
From 60a9e5507843bc268174c7c294639c8e683a7cf7 Mon Sep 17 00:00:00 2001
From: Justin Fargnoli <jfargnoli at nvidia.com>
Date: Thu, 15 Aug 2024 13:28:07 -0700
Subject: [PATCH 2/3] clang-format
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 81d5fcde8e4617..878c792a9b06ca 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -207,7 +207,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
if (VT.isVector()) {
unsigned NumElts = VT.getVectorNumElements();
EVT EltVT = VT.getVectorElementType();
- if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 && isPowerOf2_32(NumElts)) {
+ if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 &&
+ isPowerOf2_32(NumElts)) {
// Vectors with an even number of f16 elements will be passed to
// us as an array of v2f16/v2bf16 elements. We must match this so we
// stay in sync with Ins/Outs.
@@ -226,7 +227,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
}
NumElts /= 2;
} else if (EltVT.getSimpleVT() == MVT::i8 &&
- ((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) || NumElts == 3)) {
+ ((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) ||
+ NumElts == 3)) {
// v*i8 are formally lowered as v4i8
EltVT = MVT::v4i8;
NumElts = (NumElts + 3) / 4;
From c670aa9f43acd7ddd83fcf4612c19470cb16bca9 Mon Sep 17 00:00:00 2001
From: Justin Fargnoli <jfargnoli at nvidia.com>
Date: Thu, 15 Aug 2024 13:37:41 -0700
Subject: [PATCH 3/3] Add large test to prevent special casing
---
llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
index 08960ee2aad373..4c56b5fb5a34c3 100644
--- a/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
+++ b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
@@ -16,6 +16,10 @@ define <18 x half> @half18() {
ret <18 x half> zeroinitializer
}
+define <998 x half> @half998() {
+ ret <998 x half> zeroinitializer
+}
+
define <12 x i8> @byte12() {
ret <12 x i8> zeroinitializer
}
@@ -35,3 +39,7 @@ define <28 x i8> @byte28() {
define <36 x i8> @byte36() {
ret <36 x i8> zeroinitializer
}
+
+define <996 x i8> @byte996() {
+ ret <996 x i8> zeroinitializer
+}
More information about the llvm-commits
mailing list