[llvm] [NVPTX] Fix crash caused by ComputePTXValueVTs (PR #104524)

Justin Fargnoli via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 15 16:38:51 PDT 2024


https://github.com/justinfargnoli created https://github.com/llvm/llvm-project/pull/104524

When [lowering return values](https://github.com/llvm/llvm-project/blob/99a10f1fe8a7e4b0fdb4c6dd5e7f24f87e0d3695/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L3422) from LLVM IR to SelectionDAG, we check that [the number of values `SelectionDAG` tells us to return is equal to the number of values that `ComputePTXValueVTs()` tells us to return](https://github.com/llvm/llvm-project/blob/99a10f1fe8a7e4b0fdb4c6dd5e7f24f87e0d3695/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L3441). However, this check can fail on valid IR. For example: 

```
define <6 x half> @foo() {
  ret <6 x half> zeroinitializer
}
```

`ComputePTXValueVTs()` tells us to return ***3*** `v2f16` values, while `SelectionDAG` tells us to return ***6*** `f16` values. 

`ComputePTXValueVTs()` [supports all `half` element vectors with an even number of elements](https://github.com/llvm/llvm-project/blob/99a10f1fe8a7e4b0fdb4c6dd5e7f24f87e0d3695/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L213), whereas `SelectionDAG` [only supports power-of-2 sized vectors](https://github.com/llvm/llvm-project/blob/4e078e3797098daa40d254447c499bcf61415308/llvm/lib/CodeGen/TargetLoweringBase.cpp#L1580) and breaks any other size into scalars.

Assuming that the developers who added this code to `ComputePTXValueVTs()` overlooked the power-of-2 restriction, I've restricted `ComputePTXValueVTs()` to compute the same number of return values as `SelectionDAG`, rather than extending `SelectionDAG` to support non-power-of-2 vector sizes.

From 0c7cdb6aaec65c9c2e75cef7bd10988a4813ed81 Mon Sep 17 00:00:00 2001
From: Justin Fargnoli <jfargnoli at nvidia.com>
Date: Thu, 15 Aug 2024 13:27:46 -0700
Subject: [PATCH 1/3] [NVPTX] Fix crash caused by diff between
 ComputePTXValueVTs and SelectionDAG

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 10 ++---
 .../CodeGen/NVPTX/compute-ptx-value-vts.ll    | 37 +++++++++++++++++++
 2 files changed, 42 insertions(+), 5 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 43a3fbf4d1306a..81d5fcde8e4617 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -207,10 +207,10 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
     if (VT.isVector()) {
       unsigned NumElts = VT.getVectorNumElements();
       EVT EltVT = VT.getVectorElementType();
-      // Vectors with an even number of f16 elements will be passed to
-      // us as an array of v2f16/v2bf16 elements. We must match this so we
-      // stay in sync with Ins/Outs.
-      if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0) {
+      if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 && isPowerOf2_32(NumElts)) {
+        // Vectors with an even number of f16 elements will be passed to
+        // us as an array of v2f16/v2bf16 elements. We must match this so we
+        // stay in sync with Ins/Outs.
         switch (EltVT.getSimpleVT().SimpleTy) {
         case MVT::f16:
           EltVT = MVT::v2f16;
@@ -226,7 +226,7 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
         }
         NumElts /= 2;
       } else if (EltVT.getSimpleVT() == MVT::i8 &&
-                 (NumElts % 4 == 0 || NumElts == 3)) {
+                 ((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) || NumElts == 3)) {
         // v*i8 are formally lowered as v4i8
         EltVT = MVT::v4i8;
         NumElts = (NumElts + 3) / 4;
diff --git a/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
new file mode 100644
index 00000000000000..08960ee2aad373
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
@@ -0,0 +1,37 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_20
+
+define <6 x half> @half6() {
+  ret <6 x half> zeroinitializer
+}
+
+define <10 x half> @half10() {
+  ret <10 x half> zeroinitializer
+}
+
+define <14 x half> @half14() {
+  ret <14 x half> zeroinitializer
+}
+
+define <18 x half> @half18() {
+  ret <18 x half> zeroinitializer
+}
+
+define <12 x i8> @byte12() {
+  ret <12 x i8> zeroinitializer
+}
+
+define <20 x i8> @byte20() {
+  ret <20 x i8> zeroinitializer
+}
+
+define <24 x i8> @byte24() {
+  ret <24 x i8> zeroinitializer
+}
+
+define <28 x i8> @byte28() {
+  ret <28 x i8> zeroinitializer
+}
+
+define <36 x i8> @byte36() {
+  ret <36 x i8> zeroinitializer
+}

From 60a9e5507843bc268174c7c294639c8e683a7cf7 Mon Sep 17 00:00:00 2001
From: Justin Fargnoli <jfargnoli at nvidia.com>
Date: Thu, 15 Aug 2024 13:28:07 -0700
Subject: [PATCH 2/3] clang-format

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 81d5fcde8e4617..878c792a9b06ca 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -207,7 +207,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
     if (VT.isVector()) {
       unsigned NumElts = VT.getVectorNumElements();
       EVT EltVT = VT.getVectorElementType();
-      if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 && isPowerOf2_32(NumElts)) {
+      if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 &&
+          isPowerOf2_32(NumElts)) {
         // Vectors with an even number of f16 elements will be passed to
         // us as an array of v2f16/v2bf16 elements. We must match this so we
         // stay in sync with Ins/Outs.
@@ -226,7 +227,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
         }
         NumElts /= 2;
       } else if (EltVT.getSimpleVT() == MVT::i8 &&
-                 ((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) || NumElts == 3)) {
+                 ((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) ||
+                  NumElts == 3)) {
         // v*i8 are formally lowered as v4i8
         EltVT = MVT::v4i8;
         NumElts = (NumElts + 3) / 4;

From c670aa9f43acd7ddd83fcf4612c19470cb16bca9 Mon Sep 17 00:00:00 2001
From: Justin Fargnoli <jfargnoli at nvidia.com>
Date: Thu, 15 Aug 2024 13:37:41 -0700
Subject: [PATCH 3/3] Add large test to prevent special casing

---
 llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
index 08960ee2aad373..4c56b5fb5a34c3 100644
--- a/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
+++ b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
@@ -16,6 +16,10 @@ define <18 x half> @half18() {
   ret <18 x half> zeroinitializer
 }
 
+define <998 x half> @half998() {
+  ret <998 x half> zeroinitializer
+}
+
 define <12 x i8> @byte12() {
   ret <12 x i8> zeroinitializer
 }
@@ -35,3 +39,7 @@ define <28 x i8> @byte28() {
 define <36 x i8> @byte36() {
   ret <36 x i8> zeroinitializer
 }
+
+define <996 x i8> @byte996() {
+  ret <996 x i8> zeroinitializer
+}
