[llvm] [NVPTX] Fix crash caused by ComputePTXValueVTs (PR #104524)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 15 16:39:26 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-backend-nvptx
Author: Justin Fargnoli (justinfargnoli)
<details>
<summary>Changes</summary>
When [lowering return values](https://github.com/llvm/llvm-project/blob/99a10f1fe8a7e4b0fdb4c6dd5e7f24f87e0d3695/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L3422) from LLVM IR to SelectionDAG, we check that [the number of values `SelectionDAG` tells us to return is equal to the number of values that `ComputePTXValueVTs()` tells us to return](https://github.com/llvm/llvm-project/blob/99a10f1fe8a7e4b0fdb4c6dd5e7f24f87e0d3695/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L3441). However, this check can fail on valid IR. For example:
```
define <6 x half> @foo() {
ret <6 x half> zeroinitializer
}
```
`ComputePTXValueVTs()` tells us to return ***3*** `v2f16` values, while `SelectionDAG` tells us to return ***6*** `f16` values.
`ComputePTXValueVTs()` [supports all `half` element vectors with an even number of elements](https://github.com/llvm/llvm-project/blob/99a10f1fe8a7e4b0fdb4c6dd5e7f24f87e0d3695/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L213), whereas `SelectionDAG` [only supports power-of-2 sized vectors](https://github.com/llvm/llvm-project/blob/4e078e3797098daa40d254447c499bcf61415308/llvm/lib/CodeGen/TargetLoweringBase.cpp#L1580).
Assuming that the developers who added this code to `ComputePTXValueVTs()` overlooked the mismatch, I've restricted `ComputePTXValueVTs()` to compute the same number of return values as `SelectionDAG`, rather than extending `SelectionDAG` to support non-power-of-2 vector sizes.
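
To make the count mismatch concrete, here is a minimal standalone sketch of the element-count arithmetic before and after this change. It is not LLVM code: the function names (`halfPartsBefore`, `halfPartsAfter`, `i8PartsBefore`, `i8PartsAfter`) and the local `isPowerOf2` helper are hypothetical stand-ins, and the fallback of one value per element when no packing applies is an assumption based on the `<6 x half>` example above.

```cpp
#include <cassert>

// Local stand-in for llvm::isPowerOf2_32 (hypothetical helper, not the LLVM one).
constexpr bool isPowerOf2(unsigned N) { return N != 0 && (N & (N - 1)) == 0; }

// Values computed for a vector of NumElts 16-bit float elements BEFORE the
// patch: any even count is packed into v2f16/v2bf16 pairs.
// Assumption: otherwise one scalar value is emitted per element.
constexpr unsigned halfPartsBefore(unsigned NumElts) {
  return NumElts % 2 == 0 ? NumElts / 2 : NumElts;
}

// Same computation AFTER the patch: pairs are only formed when the element
// count is also a power of two.
constexpr unsigned halfPartsAfter(unsigned NumElts) {
  return (NumElts % 2 == 0 && isPowerOf2(NumElts)) ? NumElts / 2 : NumElts;
}

// i8 vectors BEFORE the patch: multiples of 4 (and the special case 3) are
// packed into v4i8 chunks.
constexpr unsigned i8PartsBefore(unsigned NumElts) {
  return (NumElts % 4 == 0 || NumElts == 3) ? (NumElts + 3) / 4 : NumElts;
}

// i8 vectors AFTER the patch: multiples of 4 must also be powers of two.
constexpr unsigned i8PartsAfter(unsigned NumElts) {
  return ((NumElts % 4 == 0 && isPowerOf2(NumElts)) || NumElts == 3)
             ? (NumElts + 3) / 4
             : NumElts;
}

// <6 x half>: 3 packed values before the patch, 6 scalar values after it.
static_assert(halfPartsBefore(6) == 3, "pre-patch count for <6 x half>");
static_assert(halfPartsAfter(6) == 6, "post-patch count for <6 x half>");
// Power-of-2 sizes keep the packed form.
static_assert(halfPartsAfter(8) == 4, "post-patch count for <8 x half>");
```

Under this model, `<6 x half>` goes from 3 values to 6, matching the count the description attributes to `SelectionDAG`, while power-of-2 sizes such as `<8 x half>` and `<16 x i8>` are unaffected.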
---
Full diff: https://github.com/llvm/llvm-project/pull/104524.diff
2 Files Affected:
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+7-5)
- (added) llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll (+45)
``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 43a3fbf4d1306a..878c792a9b06ca 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -207,10 +207,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
if (VT.isVector()) {
unsigned NumElts = VT.getVectorNumElements();
EVT EltVT = VT.getVectorElementType();
- // Vectors with an even number of f16 elements will be passed to
- // us as an array of v2f16/v2bf16 elements. We must match this so we
- // stay in sync with Ins/Outs.
- if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0) {
+ if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 &&
+ isPowerOf2_32(NumElts)) {
+ // Vectors with an even number of f16 elements will be passed to
+ // us as an array of v2f16/v2bf16 elements. We must match this so we
+ // stay in sync with Ins/Outs.
switch (EltVT.getSimpleVT().SimpleTy) {
case MVT::f16:
EltVT = MVT::v2f16;
@@ -226,7 +227,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
}
NumElts /= 2;
} else if (EltVT.getSimpleVT() == MVT::i8 &&
- (NumElts % 4 == 0 || NumElts == 3)) {
+ ((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) ||
+ NumElts == 3)) {
// v*i8 are formally lowered as v4i8
EltVT = MVT::v4i8;
NumElts = (NumElts + 3) / 4;
diff --git a/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
new file mode 100644
index 00000000000000..4c56b5fb5a34c3
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
@@ -0,0 +1,45 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_20
+
+define <6 x half> @half6() {
+ ret <6 x half> zeroinitializer
+}
+
+define <10 x half> @half10() {
+ ret <10 x half> zeroinitializer
+}
+
+define <14 x half> @half14() {
+ ret <14 x half> zeroinitializer
+}
+
+define <18 x half> @half18() {
+ ret <18 x half> zeroinitializer
+}
+
+define <998 x half> @half998() {
+ ret <998 x half> zeroinitializer
+}
+
+define <12 x i8> @byte12() {
+ ret <12 x i8> zeroinitializer
+}
+
+define <20 x i8> @byte20() {
+ ret <20 x i8> zeroinitializer
+}
+
+define <24 x i8> @byte24() {
+ ret <24 x i8> zeroinitializer
+}
+
+define <28 x i8> @byte28() {
+ ret <28 x i8> zeroinitializer
+}
+
+define <36 x i8> @byte36() {
+ ret <36 x i8> zeroinitializer
+}
+
+define <996 x i8> @byte996() {
+ ret <996 x i8> zeroinitializer
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/104524