[llvm] [NVPTX] Fix crash caused by ComputePTXValueVTs (PR #104524)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 15 16:39:26 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-backend-nvptx

Author: Justin Fargnoli (justinfargnoli)

<details>
<summary>Changes</summary>

When [lowering return values](https://github.com/llvm/llvm-project/blob/99a10f1fe8a7e4b0fdb4c6dd5e7f24f87e0d3695/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L3422) from LLVM IR to SelectionDAG, we check that [the number of values `SelectionDAG` tells us to return is equal to the number of values that `ComputePTXValueVTs()` tells us to return](https://github.com/llvm/llvm-project/blob/99a10f1fe8a7e4b0fdb4c6dd5e7f24f87e0d3695/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L3441). However, this check can fail on valid IR. For example: 

```
define <6 x half> @foo() {
  ret <6 x half> zeroinitializer
}
```

`ComputePTXValueVTs()` tells us to return ***3*** `v2f16` values, while `SelectionDAG` tells us to return ***6*** `f16` values. 

`ComputePTXValueVTs()` [supports every vector of `half` elements that has an even number of elements](https://github.com/llvm/llvm-project/blob/99a10f1fe8a7e4b0fdb4c6dd5e7f24f87e0d3695/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L213), whereas `SelectionDAG` [only supports power-of-2 sized vectors](https://github.com/llvm/llvm-project/blob/4e078e3797098daa40d254447c499bcf61415308/llvm/lib/CodeGen/TargetLoweringBase.cpp#L1580).

Assuming that the developers who added the code to `ComputePTXValueVTs()` overlooked this, I've restricted `ComputePTXValueVTs()` to compute the same number of return values as `SelectionDAG` instead of extending `SelectionDAG` to support non-power-of-2 vector sizes.
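
As a rough illustration of the effect of this restriction, here is a standalone sketch that models only the element-packing conditions from the diff. It is not the real `ComputePTXValueVTs()`: `Elt`, `partsBefore`, `partsAfter`, and the local `isPowerOf2` helper are hypothetical names introduced for this example, and the sketch assumes that a vector which is not packed yields one value per element.

```cpp
#include <cstdint>

// Local stand-in for llvm::isPowerOf2_32 so the sketch builds without LLVM.
constexpr bool isPowerOf2(uint32_t N) { return N != 0 && (N & (N - 1)) == 0; }

// Hypothetical element kinds; the real code inspects an EVT instead.
enum class Elt { F16, I8 };

// Number of values computed for a vector return BEFORE the patch
// (models only the two packing branches shown in the diff).
constexpr unsigned partsBefore(Elt E, unsigned NumElts) {
  if (E == Elt::F16 && NumElts % 2 == 0)
    return NumElts / 2; // packed as v2f16
  if (E == Elt::I8 && (NumElts % 4 == 0 || NumElts == 3))
    return (NumElts + 3) / 4; // packed as v4i8
  return NumElts; // assumption: one value per element otherwise
}

// Number of values computed AFTER the patch: packing additionally
// requires a power-of-2 element count (the NumElts == 3 case is kept).
constexpr unsigned partsAfter(Elt E, unsigned NumElts) {
  if (E == Elt::F16 && NumElts % 2 == 0 && isPowerOf2(NumElts))
    return NumElts / 2;
  if (E == Elt::I8 &&
      ((NumElts % 4 == 0 && isPowerOf2(NumElts)) || NumElts == 3))
    return (NumElts + 3) / 4;
  return NumElts;
}

// <6 x half>: 3 values before the patch, 6 after, matching SelectionDAG.
static_assert(partsBefore(Elt::F16, 6) == 3, "old: 3 x v2f16");
static_assert(partsAfter(Elt::F16, 6) == 6, "new: 6 x f16");
// Power-of-2 sizes keep the packed form.
static_assert(partsAfter(Elt::F16, 8) == 4, "new: 4 x v2f16");
static_assert(partsAfter(Elt::I8, 12) == 12, "new: 12 x i8");
```

Under these assumptions, only the even but non-power-of-2 `half` counts and the multiple-of-4 but non-power-of-2 `i8` counts change, which are the kinds of sizes the new test file exercises.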

---
Full diff: https://github.com/llvm/llvm-project/pull/104524.diff


2 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+7-5) 
- (added) llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll (+45) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 43a3fbf4d1306a..878c792a9b06ca 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -207,10 +207,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
     if (VT.isVector()) {
       unsigned NumElts = VT.getVectorNumElements();
       EVT EltVT = VT.getVectorElementType();
-      // Vectors with an even number of f16 elements will be passed to
-      // us as an array of v2f16/v2bf16 elements. We must match this so we
-      // stay in sync with Ins/Outs.
-      if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0) {
+      if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 &&
+          isPowerOf2_32(NumElts)) {
+        // Vectors with an even number of f16 elements will be passed to
+        // us as an array of v2f16/v2bf16 elements. We must match this so we
+        // stay in sync with Ins/Outs.
         switch (EltVT.getSimpleVT().SimpleTy) {
         case MVT::f16:
           EltVT = MVT::v2f16;
@@ -226,7 +227,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
         }
         NumElts /= 2;
       } else if (EltVT.getSimpleVT() == MVT::i8 &&
-                 (NumElts % 4 == 0 || NumElts == 3)) {
+                 ((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) ||
+                  NumElts == 3)) {
         // v*i8 are formally lowered as v4i8
         EltVT = MVT::v4i8;
         NumElts = (NumElts + 3) / 4;
diff --git a/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
new file mode 100644
index 00000000000000..4c56b5fb5a34c3
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
@@ -0,0 +1,45 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_20
+
+define <6 x half> @half6() {
+  ret <6 x half> zeroinitializer
+}
+
+define <10 x half> @half10() {
+  ret <10 x half> zeroinitializer
+}
+
+define <14 x half> @half14() {
+  ret <14 x half> zeroinitializer
+}
+
+define <18 x half> @half18() {
+  ret <18 x half> zeroinitializer
+}
+
+define <998 x half> @half998() {
+  ret <998 x half> zeroinitializer
+}
+
+define <12 x i8> @byte12() {
+  ret <12 x i8> zeroinitializer
+}
+
+define <20 x i8> @byte20() {
+  ret <20 x i8> zeroinitializer
+}
+
+define <24 x i8> @byte24() {
+  ret <24 x i8> zeroinitializer
+}
+
+define <28 x i8> @byte28() {
+  ret <28 x i8> zeroinitializer
+}
+
+define <36 x i8> @byte36() {
+  ret <36 x i8> zeroinitializer
+}
+
+define <996 x i8> @byte996() {
+  ret <996 x i8> zeroinitializer
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/104524

