[llvm] [NVPTX] Propagate ISD::TRUNCATE to operands to reduce register pressure (PR #98666)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 12 10:35:27 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Justin Fargnoli (justinfargnoli)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/98666.diff


2 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+54-1) 
- (added) llvm/test/CodeGen/NVPTX/combine-truncate.ll (+90) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 476a532db0a37..26729c7adb020 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -725,7 +725,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
-                       ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM,
-                       ISD::VSELECT});
+                       ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::TRUNCATE,
+                       ISD::UREM, ISD::VSELECT});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5541,6 +5541,53 @@ static SDValue PerformREMCombine(SDNode *N,
   return SDValue();
 }
 
+// truncate (binop x, y) --> binop (truncate x), (truncate y)
+// Narrows an i64 ADD/SUB/MUL/AND/OR/XOR to i32 when every user of the wide
+// op truncates it to i32: the 64-bit value then dies, which reduces register
+// pressure. N is either an ISD::TRUNCATE or a CVT_u32_u64 machine node.
+static SDValue PerformTruncCombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI) {
+  if (!DCI.isBeforeLegalizeOps())
+    return SDValue();
+
+  // The low 32 bits of these ops depend only on the low 32 bits of the inputs.
+  SDValue BinOp = N->getOperand(0);
+  unsigned Opc = BinOp.getOpcode();
+  if (Opc != ISD::ADD && Opc != ISD::SUB && Opc != ISD::MUL &&
+      Opc != ISD::AND && Opc != ISD::OR && Opc != ISD::XOR)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  SelectionDAG &DAG = DCI.DAG;
+  if (VT != MVT::i32 || BinOp.getValueType() != MVT::i64 ||
+      !DAG.getTargetLoweringInfo().isOperationLegal(Opc, VT))
+    return SDValue();
+
+  // A user that truncates to some other width keeps the i64 op alive.
+  if (!all_of(BinOp.getNode()->uses(), [&](SDNode *U) {
+        if (U->getValueType(0) != VT)
+          return false;
+        return U->isMachineOpcode()
+                   ? U->getMachineOpcode() == NVPTX::CVT_u32_u64
+                   : U->getOpcode() == ISD::TRUNCATE;
+      }))
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue CVTNone =
+      DAG.getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32);
+  // Constants get a real TRUNCATE so they fold to an i32 immediate instead of
+  // being materialized in a 64-bit register just to feed a cvt.
+  auto Narrow = [&](SDValue Op) {
+    if (isa<ConstantSDNode>(Op))
+      return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
+    return SDValue(
+        DAG.getMachineNode(NVPTX::CVT_u32_u64, DL, VT, Op, CVTNone), 0);
+  };
+  return DAG.getNode(Opc, DL, VT, Narrow(BinOp.getOperand(0)),
+                     Narrow(BinOp.getOperand(1)));
+}
+
 enum OperandSignedness {
   Signed = 0,
   Unsigned,
@@ -5957,6 +6004,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     case ISD::UREM:
     case ISD::SREM:
       return PerformREMCombine(N, DCI, OptLevel);
+    case ISD::TRUNCATE:
+      return PerformTruncCombine(N, DCI);
     case ISD::SETCC:
       return PerformSETCCCombine(N, DCI, STI.getSmVersion());
     case ISD::LOAD:
@@ -5974,6 +6023,12 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     case ISD::VSELECT:
       return PerformVSELECTCombine(N, DCI);
   }
+
+  // CVT_u32_u64 nodes are machine nodes, so no ISD case above matches them.
+  // Route them to PerformTruncCombine so narrowing continues into their input.
+  if (N->isMachineOpcode() && N->getMachineOpcode() == NVPTX::CVT_u32_u64)
+    return PerformTruncCombine(N, DCI);
+
   return SDValue();
 }
 
diff --git a/llvm/test/CodeGen/NVPTX/combine-truncate.ll b/llvm/test/CodeGen/NVPTX/combine-truncate.ll
new file mode 100644
index 0000000000000..30e415ebe9527
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/combine-truncate.ll
@@ -0,0 +1,98 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+; The truncate is the only user of the i64 `or`, so the `or` is done in 32
+; bits on operands narrowed with cvt.u32.u64.
+define i32 @trunc(i64 %a, i64 %b) {
+; CHECK-LABEL: trunc(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u64 %rd1, [trunc_param_0];
+; CHECK-NEXT:    ld.param.u64 %rd2, [trunc_param_1];
+; CHECK-NEXT:    cvt.u32.u64 %r1, %rd2;
+; CHECK-NEXT:    cvt.u32.u64 %r2, %rd1;
+; CHECK-NEXT:    or.b32 %r3, %r2, %r1;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r3;
+; CHECK-NEXT:    ret;
+  %or = or i64 %a, %b
+  %trunc = trunc i64 %or to i32
+  ret i32 %trunc
+}
+
+; The i64 `or` is also stored, so it must stay 64-bit: it is computed with
+; or.b64 and only its result is truncated.
+define i32 @trunc_not(i64 %a, i64 %b) {
+; CHECK-LABEL: trunc_not(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u64 %rd1, [trunc_not_param_0];
+; CHECK-NEXT:    ld.param.u64 %rd2, [trunc_not_param_1];
+; CHECK-NEXT:    or.b64 %rd3, %rd1, %rd2;
+; CHECK-NEXT:    cvt.u32.u64 %r1, %rd3;
+; CHECK-NEXT:    mov.u64 %rd4, 0;
+; CHECK-NEXT:    st.u64 [%rd4], %rd3;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r1;
+; CHECK-NEXT:    ret;
+  %or = or i64 %a, %b
+  %trunc = trunc i64 %or to i32
+  store i64 %or, ptr null
+  ret i32 %trunc
+}
+
+; Narrowing cascades: once the `or` is narrowed, the `add` is only used by a
+; cvt.u32.u64, so it is narrowed to add.s32 as well.
+define i32 @trunc_cvt(i64 %a, i64 %b) {
+; CHECK-LABEL: trunc_cvt(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u64 %rd1, [trunc_cvt_param_0];
+; CHECK-NEXT:    ld.param.u64 %rd2, [trunc_cvt_param_1];
+; CHECK-NEXT:    cvt.u32.u64 %r1, %rd2;
+; CHECK-NEXT:    cvt.u32.u64 %r2, %rd1;
+; CHECK-NEXT:    add.s32 %r3, %r2, %r1;
+; CHECK-NEXT:    or.b32 %r4, %r3, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
+; CHECK-NEXT:    ret;
+  %add = add i64 %a, %b
+  %or = or i64 %add, %a
+  %trunc = trunc i64 %or to i32
+  ret i32 %trunc
+}
+
+; The `or` is narrowed, but the `add` is also stored, so it stays add.s64 and
+; its result is truncated with cvt.u32.u64 instead.
+define i32 @trunc_cvt_not(i64 %a, i64 %b) {
+; CHECK-LABEL: trunc_cvt_not(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-NEXT:    .reg .b64 %rd<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u64 %rd1, [trunc_cvt_not_param_0];
+; CHECK-NEXT:    ld.param.u64 %rd2, [trunc_cvt_not_param_1];
+; CHECK-NEXT:    add.s64 %rd3, %rd1, %rd2;
+; CHECK-NEXT:    mov.u64 %rd4, 0;
+; CHECK-NEXT:    st.u64 [%rd4], %rd3;
+; CHECK-NEXT:    cvt.u32.u64 %r1, %rd3;
+; CHECK-NEXT:    cvt.u32.u64 %r2, %rd1;
+; CHECK-NEXT:    or.b32 %r3, %r1, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r3;
+; CHECK-NEXT:    ret;
+  %add = add i64 %a, %b
+  store i64 %add, ptr null
+  %or = or i64 %add, %a
+  %trunc = trunc i64 %or to i32
+  ret i32 %trunc
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/98666


More information about the llvm-commits mailing list