[llvm] [NVPTX] Implement isTruncateFree and isZExtFree for i32/i64 Optimizations (PR #114683)

via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 6 00:19:27 PST 2024


https://github.com/Quark-69 updated https://github.com/llvm/llvm-project/pull/114683

>From f43171b27dd59a4d71de4a891cad4006d7fc1a6c Mon Sep 17 00:00:00 2001
From: ujjawalk <ujjawal.kchouhan.ece23 at itbhu.ac.in>
Date: Sat, 2 Nov 2024 15:07:34 +0530
Subject: [PATCH 1/3] Implement isZExtFree and isTruncateFree in NVPTX target
 lowering.

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 23 +++++++++++++++++++++
 llvm/lib/Target/NVPTX/NVPTXISelLowering.h   |  4 ++++
 llvm/test/CodeGen/NVPTX/truncate_zext.ll    | 17 +++++++++++++++
 3 files changed, 44 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/truncate_zext.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d3bf0ecfe2cc92..b5fc975d71dfa8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3340,6 +3340,29 @@ bool NVPTXTargetLowering::splitValueIntoRegisterParts(
   return false;
 }
 
+bool llvm::NVPTXTargetLowering::isTruncateFree(EVT FromVT, EVT ToVT) const {
+
+  if (!FromVT.isSimple() || !ToVT.isSimple()) {
+    return false;
+  }
+
+  return (FromVT.getSimpleVT() == MVT::i64 && ToVT.getSimpleVT() == MVT::i32);
+}
+
+bool llvm::NVPTXTargetLowering::isZExtFree(EVT FromVT, EVT ToVT) const {
+  if (!FromVT.isSimple() || !ToVT.isSimple()) {
+    return false;
+  }
+  return (FromVT.getSimpleVT() == MVT::i32 && ToVT.getSimpleVT() == MVT::i64);
+}
+
+bool llvm::NVPTXTargetLowering::isZExtFree(Type *SrcTy, Type *DstTy) const {
+  if (!SrcTy->isIntegerTy() || !DstTy->isIntegerTy())
+    return false;
+  return SrcTy->getPrimitiveSizeInBits() == 32 &&
+         DstTy->getPrimitiveSizeInBits() == 64;
+}
+
 // This creates target external symbol for a function parameter.
 // Name of the symbol is composed from its index and the function name.
 // Negative index corresponds to special parameter (unsized array) used for
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index c8b589ae39413e..fa73938a35a168 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -616,6 +616,10 @@ class NVPTXTargetLowering : public TargetLowering {
     return true;
   }
 
+  bool isTruncateFree(EVT FromVT, EVT ToVT) const override;
+  bool isZExtFree(EVT FromVT, EVT ToVT) const override;
+  bool isZExtFree(Type *SrcTy, Type *DstTy) const override;
+
 private:
   const NVPTXSubtarget &STI; // cache the subtarget here
   SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT) const;
diff --git a/llvm/test/CodeGen/NVPTX/truncate_zext.ll b/llvm/test/CodeGen/NVPTX/truncate_zext.ll
new file mode 100644
index 00000000000000..decc02c5840491
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/truncate_zext.ll
@@ -0,0 +1,17 @@
+; RUN: llc -march=nvptx64 < %s | FileCheck %s
+
+; Test for truncation from i64 to i32
+define i32 @test_trunc_i64_to_i32(i64 %val) {
+  ; CHECK-LABEL: test_trunc_i64_to_i32
+  ; CHECK: trunc
+  %trunc = trunc i64 %val to i32
+  ret i32 %trunc
+}
+
+; Test for zero-extension from i32 to i64
+define i64 @test_zext_i32_to_i64(i32 %val) {
+  ; CHECK-LABEL: test_zext_i32_to_i64
+  ; CHECK: zext
+  %zext = zext i32 %val to i64
+  ret i64 %zext
+}
\ No newline at end of file

>From 401f834f6ecfdf3b335181b783a1575c1697cf8b Mon Sep 17 00:00:00 2001
From: ujjawalk <ujjawal.kchouhan.ece23 at itbhu.ac.in>
Date: Sun, 3 Nov 2024 02:11:56 +0530
Subject: [PATCH 2/3] Fix isTruncateFree and isZExtFree to match the NVPTX
 architecture.

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 16 +++++-----------
 1 file changed, 5 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b5fc975d71dfa8..d208caebbd1151 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3341,26 +3341,20 @@ bool NVPTXTargetLowering::splitValueIntoRegisterParts(
 }
 
 bool llvm::NVPTXTargetLowering::isTruncateFree(EVT FromVT, EVT ToVT) const {
-
-  if (!FromVT.isSimple() || !ToVT.isSimple()) {
+  if (FromVT.isVector() || ToVT.isVector() || !FromVT.isInteger() ||
+      !ToVT.isInteger()) {
     return false;
   }
 
-  return (FromVT.getSimpleVT() == MVT::i64 && ToVT.getSimpleVT() == MVT::i32);
+  return FromVT.getSizeInBits() == 64 && ToVT.getSizeInBits() == 32;
 }
 
 bool llvm::NVPTXTargetLowering::isZExtFree(EVT FromVT, EVT ToVT) const {
-  if (!FromVT.isSimple() || !ToVT.isSimple()) {
-    return false;
-  }
-  return (FromVT.getSimpleVT() == MVT::i32 && ToVT.getSimpleVT() == MVT::i64);
+  return false;
 }
 
 bool llvm::NVPTXTargetLowering::isZExtFree(Type *SrcTy, Type *DstTy) const {
-  if (!SrcTy->isIntegerTy() || !DstTy->isIntegerTy())
-    return false;
-  return SrcTy->getPrimitiveSizeInBits() == 32 &&
-         DstTy->getPrimitiveSizeInBits() == 64;
+  return false;
 }
 
 // This creates target external symbol for a function parameter.

>From 65500444989f9b2accd8b0e140fe58bc6cec2c00 Mon Sep 17 00:00:00 2001
From: ujjawalk <ujjawal.kchouhan.ece23 at itbhu.ac.in>
Date: Wed, 6 Nov 2024 13:47:33 +0530
Subject: [PATCH 3/3] Address review feedback: define the overrides inline in
 the header and extend the test.

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 17 --------
 llvm/lib/Target/NVPTX/NVPTXISelLowering.h   | 15 +++++--
 llvm/test/CodeGen/NVPTX/truncate_zext.ll    | 47 ++++++++++++++++++---
 3 files changed, 53 insertions(+), 26 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 1233add1d0374a..18b05b23da220b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3340,23 +3340,6 @@ bool NVPTXTargetLowering::splitValueIntoRegisterParts(
   return false;
 }
 
-bool llvm::NVPTXTargetLowering::isTruncateFree(EVT FromVT, EVT ToVT) const {
-  if (FromVT.isVector() || ToVT.isVector() || !FromVT.isInteger() ||
-      !ToVT.isInteger()) {
-    return false;
-  }
-
-  return FromVT.getSizeInBits() == 64 && ToVT.getSizeInBits() == 32;
-}
-
-bool llvm::NVPTXTargetLowering::isZExtFree(EVT FromVT, EVT ToVT) const {
-  return false;
-}
-
-bool llvm::NVPTXTargetLowering::isZExtFree(Type *SrcTy, Type *DstTy) const {
-  return false;
-}
-
 // This creates target external symbol for a function parameter.
 // Name of the symbol is composed from its index and the function name.
 // Negative index corresponds to special parameter (unsized array) used for
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index fa73938a35a168..ee7633eb9231a4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -506,6 +506,17 @@ class NVPTXTargetLowering : public TargetLowering {
            DstTy->getPrimitiveSizeInBits() == 32;
   }
 
+  // SelectionDAG-level truncation query: only i64 -> i32 is reported as free.
+  // Every scalar integer EVT of width 64/32 is the simple MVT::i64/MVT::i32,
+  // so comparing against those MVTs is exact and also rejects vectors.
+  bool isTruncateFree(EVT FromVT, EVT ToVT) const override {
+    return FromVT == MVT::i64 && ToVT == MVT::i32;
+  }
+
+  // isZExtFree is left to the TargetLowering default, which already returns
+  // false; re-declaring it here would be redundant and would hide the base
+  // class's other isZExtFree overloads from lookup on NVPTXTargetLowering.
+
   EVT getSetCCResultType(const DataLayout &DL, LLVMContext &Ctx,
                          EVT VT) const override {
     if (VT.isVector())
@@ -616,10 +627,6 @@ class NVPTXTargetLowering : public TargetLowering {
     return true;
   }
 
-  bool isTruncateFree(EVT FromVT, EVT ToVT) const override;
-  bool isZExtFree(EVT FromVT, EVT ToVT) const override;
-  bool isZExtFree(Type *SrcTy, Type *DstTy) const override;
-
 private:
   const NVPTXSubtarget &STI; // cache the subtarget here
   SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT) const;
diff --git a/llvm/test/CodeGen/NVPTX/truncate_zext.ll b/llvm/test/CodeGen/NVPTX/truncate_zext.ll
index decc02c5840491..15159916b1850b 100644
--- a/llvm/test/CodeGen/NVPTX/truncate_zext.ll
+++ b/llvm/test/CodeGen/NVPTX/truncate_zext.ll
@@ -1,17 +1,54 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
 ; RUN: llc -march=nvptx64 < %s | FileCheck %s
 
 ; Test for truncation from i64 to i32
 define i32 @test_trunc_i64_to_i32(i64 %val) {
-  ; CHECK-LABEL: test_trunc_i64_to_i32
-  ; CHECK: trunc
+; CHECK-LABEL: test_trunc_i64_to_i32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_trunc_i64_to_i32_param_0];
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
   %trunc = trunc i64 %val to i32
   ret i32 %trunc
 }
 
 ; Test for zero-extension from i32 to i64
 define i64 @test_zext_i32_to_i64(i32 %val) {
-  ; CHECK-LABEL: test_zext_i32_to_i64
-  ; CHECK: zext
+; CHECK-LABEL: test_zext_i32_to_i64(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %rd1, [test_zext_i32_to_i64_param_0];
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd1;
+; CHECK-NEXT:    ret;
   %zext = zext i32 %val to i64
   ret i64 %zext
-}
\ No newline at end of file
+}
+
+; i64 operands truncated before a select: the select should be done in i32.
+define i32 @test_select_truncate_free(i1 %cond, i64 %a, i64 %b) {
+; CHECK-LABEL: test_select_truncate_free(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u8 %rs1, [test_select_truncate_free_param_0];
+; CHECK-NEXT:    and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT:    setp.eq.b16 %p1, %rs2, 1;
+; CHECK-NEXT:    ld.param.u32 %r1, [test_select_truncate_free_param_1];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_select_truncate_free_param_2];
+; CHECK-NEXT:    selp.b32 %r3, %r1, %r2, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+
+  %trunc_a = trunc i64 %a to i32
+  %trunc_b = trunc i64 %b to i32
+  %result = select i1 %cond, i32 %trunc_a, i32 %trunc_b
+  ret i32 %result
+}



More information about the llvm-commits mailing list