[llvm] [NVPTX] Fix crash caused by ComputePTXValueVTs (PR #104524)

Justin Fargnoli via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 16 12:39:22 PDT 2024


https://github.com/justinfargnoli updated https://github.com/llvm/llvm-project/pull/104524

>From 0c7cdb6aaec65c9c2e75cef7bd10988a4813ed81 Mon Sep 17 00:00:00 2001
From: Justin Fargnoli <jfargnoli at nvidia.com>
Date: Thu, 15 Aug 2024 13:27:46 -0700
Subject: [PATCH 1/4] [NVPTX] Fix crash caused by diff between
 ComputePTXValueVTs and SelectionDAG

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 10 ++---
 .../CodeGen/NVPTX/compute-ptx-value-vts.ll    | 37 +++++++++++++++++++
 2 files changed, 42 insertions(+), 5 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 43a3fbf4d1306a..81d5fcde8e4617 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -207,10 +207,10 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
     if (VT.isVector()) {
       unsigned NumElts = VT.getVectorNumElements();
       EVT EltVT = VT.getVectorElementType();
-      // Vectors with an even number of f16 elements will be passed to
-      // us as an array of v2f16/v2bf16 elements. We must match this so we
-      // stay in sync with Ins/Outs.
-      if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0) {
+      if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 && isPowerOf2_32(NumElts)) {
+        // Vectors with an even number of f16 elements will be passed to
+        // us as an array of v2f16/v2bf16 elements. We must match this so we
+        // stay in sync with Ins/Outs.
         switch (EltVT.getSimpleVT().SimpleTy) {
         case MVT::f16:
           EltVT = MVT::v2f16;
@@ -226,7 +226,7 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
         }
         NumElts /= 2;
       } else if (EltVT.getSimpleVT() == MVT::i8 &&
-                 (NumElts % 4 == 0 || NumElts == 3)) {
+                 ((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) || NumElts == 3)) {
         // v*i8 are formally lowered as v4i8
         EltVT = MVT::v4i8;
         NumElts = (NumElts + 3) / 4;
diff --git a/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
new file mode 100644
index 00000000000000..08960ee2aad373
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
@@ -0,0 +1,37 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_20
+
+define <6 x half> @half6() {
+  ret <6 x half> zeroinitializer
+}
+
+define <10 x half> @half10() {
+  ret <10 x half> zeroinitializer
+}
+
+define <14 x half> @half14() {
+  ret <14 x half> zeroinitializer
+}
+
+define <18 x half> @half18() {
+  ret <18 x half> zeroinitializer
+}
+
+define <12 x i8> @byte12() {
+  ret <12 x i8> zeroinitializer
+}
+
+define <20 x i8> @byte20() {
+  ret <20 x i8> zeroinitializer
+}
+
+define <24 x i8> @byte24() {
+  ret <24 x i8> zeroinitializer
+}
+
+define <28 x i8> @byte28() {
+  ret <28 x i8> zeroinitializer
+}
+
+define <36 x i8> @byte36() {
+  ret <36 x i8> zeroinitializer
+}

>From 60a9e5507843bc268174c7c294639c8e683a7cf7 Mon Sep 17 00:00:00 2001
From: Justin Fargnoli <jfargnoli at nvidia.com>
Date: Thu, 15 Aug 2024 13:28:07 -0700
Subject: [PATCH 2/4] clang-format

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 81d5fcde8e4617..878c792a9b06ca 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -207,7 +207,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
     if (VT.isVector()) {
       unsigned NumElts = VT.getVectorNumElements();
       EVT EltVT = VT.getVectorElementType();
-      if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 && isPowerOf2_32(NumElts)) {
+      if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 &&
+          isPowerOf2_32(NumElts)) {
         // Vectors with an even number of f16 elements will be passed to
         // us as an array of v2f16/v2bf16 elements. We must match this so we
         // stay in sync with Ins/Outs.
@@ -226,7 +227,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
         }
         NumElts /= 2;
       } else if (EltVT.getSimpleVT() == MVT::i8 &&
-                 ((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) || NumElts == 3)) {
+                 ((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) ||
+                  NumElts == 3)) {
         // v*i8 are formally lowered as v4i8
         EltVT = MVT::v4i8;
         NumElts = (NumElts + 3) / 4;

>From c670aa9f43acd7ddd83fcf4612c19470cb16bca9 Mon Sep 17 00:00:00 2001
From: Justin Fargnoli <jfargnoli at nvidia.com>
Date: Thu, 15 Aug 2024 13:37:41 -0700
Subject: [PATCH 3/4] Add large-vector tests to guard against special-casing specific sizes

---
 llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
index 08960ee2aad373..4c56b5fb5a34c3 100644
--- a/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
+++ b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
@@ -16,6 +16,10 @@ define <18 x half> @half18() {
   ret <18 x half> zeroinitializer
 }
 
+define <998 x half> @half998() {
+  ret <998 x half> zeroinitializer
+}
+
 define <12 x i8> @byte12() {
   ret <12 x i8> zeroinitializer
 }
@@ -35,3 +39,7 @@ define <28 x i8> @byte28() {
 define <36 x i8> @byte36() {
   ret <36 x i8> zeroinitializer
 }
+
+define <996 x i8> @byte996() {
+  ret <996 x i8> zeroinitializer
+}

>From 376a3d3e5b470346973411ee01fc36f765794dee Mon Sep 17 00:00:00 2001
From: Justin Fargnoli <jfargnoli at nvidia.com>
Date: Fri, 16 Aug 2024 12:39:10 -0700
Subject: [PATCH 4/4] Update test: autogenerate FileCheck assertions and drop redundant cases

---
 .../CodeGen/NVPTX/compute-ptx-value-vts.ll    | 74 +++++++++++--------
 1 file changed, 45 insertions(+), 29 deletions(-)

diff --git a/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
index 4c56b5fb5a34c3..a88c5637f089b1 100644
--- a/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
+++ b/llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll
@@ -1,45 +1,61 @@
-; RUN: llc < %s -march=nvptx64 -mcpu=sm_20
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 | FileCheck %s
+
+target triple = "nvptx-nvidia-cuda"
 
 define <6 x half> @half6() {
+; CHECK-LABEL: half6(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    mov.b16 %rs1, 0x0000;
+; CHECK-NEXT:    st.param.v4.b16 [func_retval0+0], {%rs1, %rs1, %rs1, %rs1};
+; CHECK-NEXT:    st.param.v2.b16 [func_retval0+8], {%rs1, %rs1};
+; CHECK-NEXT:    ret;
   ret <6 x half> zeroinitializer
 }
 
 define <10 x half> @half10() {
+; CHECK-LABEL: half10(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    mov.b16 %rs1, 0x0000;
+; CHECK-NEXT:    st.param.v4.b16 [func_retval0+0], {%rs1, %rs1, %rs1, %rs1};
+; CHECK-NEXT:    st.param.v4.b16 [func_retval0+8], {%rs1, %rs1, %rs1, %rs1};
+; CHECK-NEXT:    st.param.v2.b16 [func_retval0+16], {%rs1, %rs1};
+; CHECK-NEXT:    ret;
   ret <10 x half> zeroinitializer
 }
 
-define <14 x half> @half14() {
-  ret <14 x half> zeroinitializer
-}
-
-define <18 x half> @half18() {
-  ret <18 x half> zeroinitializer
-}
-
-define <998 x half> @half998() {
-  ret <998 x half> zeroinitializer
-}
-
 define <12 x i8> @byte12() {
+; CHECK-LABEL: byte12(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    mov.u16 %rs1, 0;
+; CHECK-NEXT:    st.param.v4.b8 [func_retval0+0], {%rs1, %rs1, %rs1, %rs1};
+; CHECK-NEXT:    st.param.v4.b8 [func_retval0+4], {%rs1, %rs1, %rs1, %rs1};
+; CHECK-NEXT:    st.param.v4.b8 [func_retval0+8], {%rs1, %rs1, %rs1, %rs1};
+; CHECK-NEXT:    ret;
   ret <12 x i8> zeroinitializer
 }
 
 define <20 x i8> @byte20() {
+; CHECK-LABEL: byte20(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    mov.u16 %rs1, 0;
+; CHECK-NEXT:    st.param.v4.b8 [func_retval0+0], {%rs1, %rs1, %rs1, %rs1};
+; CHECK-NEXT:    st.param.v4.b8 [func_retval0+4], {%rs1, %rs1, %rs1, %rs1};
+; CHECK-NEXT:    st.param.v4.b8 [func_retval0+8], {%rs1, %rs1, %rs1, %rs1};
+; CHECK-NEXT:    st.param.v4.b8 [func_retval0+12], {%rs1, %rs1, %rs1, %rs1};
+; CHECK-NEXT:    st.param.v4.b8 [func_retval0+16], {%rs1, %rs1, %rs1, %rs1};
+; CHECK-NEXT:    ret;
   ret <20 x i8> zeroinitializer
 }
-
-define <24 x i8> @byte24() {
-  ret <24 x i8> zeroinitializer
-}
-
-define <28 x i8> @byte28() {
-  ret <28 x i8> zeroinitializer
-}
-
-define <36 x i8> @byte36() {
-  ret <36 x i8> zeroinitializer
-}
-
-define <996 x i8> @byte996() {
-  ret <996 x i8> zeroinitializer
-}



More information about the llvm-commits mailing list