[llvm] [NVPTX] Improve folding to mad with immediate 1 (PR #93628)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Wed May 29 15:56:31 PDT 2024


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/93628

>From 5b9b98a3d9f75ea225cb40eb8d0092d21089a00f Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Wed, 29 May 2024 00:42:25 +0000
Subject: [PATCH 1/3] [NVPTX] Improve folding to mad with immediate 1

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp |  93 ++++++++++++++++--
 llvm/test/CodeGen/NVPTX/combine-mad.ll      | 101 ++++++++++++++++++++
 2 files changed, 188 insertions(+), 6 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/combine-mad.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 1e7477cf9d60e..304d1984edd54 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5614,17 +5614,98 @@ static SDValue TryMULWIDECombine(SDNode *N,
   return DCI.DAG.getNode(Opc, DL, MulType, TruncLHS, TruncRHS);
 }
 
+static SDValue matchMADConstOnePattern(SDValue X, SDValue Add) {
+  if (Add->getOpcode() != ISD::ADD)
+    return SDValue();
+
+  SDValue Y = Add->getOperand(0);
+  ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Add->getOperand(1));
+  if (!Const || Const->getZExtValue() != 1)
+    return SDValue();
+
+  return Y;
+}
+
+static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
+                                  TargetLowering::DAGCombinerInfo &DCI) {
+
+  if (SDValue Y = matchMADConstOnePattern(X, Add))
+    return DCI.DAG.getNode(NVPTXISD::IMAD, DL, VT, X, Y, X);
+
+  return SDValue();
+}
+
+static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
+                                        SDLoc DL,
+                                        TargetLowering::DAGCombinerInfo &DCI) {
+  if (Select->getOpcode() != ISD::SELECT)
+    return SDValue();
+
+  SDValue Cond = Select->getOperand(0);
+
+  unsigned ConstOpNo = 1;
+  auto *Const = dyn_cast<ConstantSDNode>(Select->getOperand(ConstOpNo));
+  if (!Const || Const->getZExtValue() != 1) {
+    ConstOpNo = 2;
+    Const = dyn_cast<ConstantSDNode>(Select->getOperand(ConstOpNo));
+    if (!Const || Const->getZExtValue() != 1)
+      return SDValue();
+  }
+
+  SDValue Y = Select->getOperand((ConstOpNo == 1) ? 2 : 1);
+
+  // Do not combine if the resulting sequence is not obviously profitable.
+  if (!matchMADConstOnePattern(X, Y))
+    return SDValue();
+
+  SDValue NewMul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);
+
+  return DCI.DAG.getNode(ISD::SELECT, DL, VT, Cond,
+                         (ConstOpNo == 1) ? X : NewMul,
+                         (ConstOpNo == 1) ? NewMul : X);
+}
+
+static SDValue
+PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
+                              TargetLowering::DAGCombinerInfo &DCI) {
+  // Scalar i16/i32/i64 only; each fold is tried with both operand orders.
+  EVT VT = N0.getValueType();
+  if (VT.isVector())
+    return SDValue();
+
+  if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
+    return SDValue();
+
+  SDLoc DL(N);
+
+  // (mul x, (add y, 1)) -> (mad x, y, x)
+  if (SDValue Res = combineMADConstOne(N0, N1, VT, DL, DCI))
+    return Res;
+  if (SDValue Res = combineMADConstOne(N1, N0, VT, DL, DCI))
+    return Res;
+
+  // (mul x, (select c, 1, (add y, 1))) -> (select c, x, (mul x, (add y, 1)))
+  if (SDValue Res = combineMulSelectConstOne(N0, N1, VT, DL, DCI))
+    return Res;
+  if (SDValue Res = combineMulSelectConstOne(N1, N0, VT, DL, DCI))
+    return Res;
+
+  return SDValue();
+}
+
 /// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
 static SDValue PerformMULCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  CodeGenOptLevel OptLevel) {
-  if (OptLevel > CodeGenOptLevel::None) {
-    // Try mul.wide combining at OptLevel > 0
-    if (SDValue Ret = TryMULWIDECombine(N, DCI))
-      return Ret;
-  }
+  if (OptLevel == CodeGenOptLevel::None)
+    return SDValue();
 
-  return SDValue();
+  if (SDValue Ret = TryMULWIDECombine(N, DCI))
+    return Ret;
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  return PerformMULCombineWithOperands(N, N0, N1, DCI);
 }
 
 /// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
diff --git a/llvm/test/CodeGen/NVPTX/combine-mad.ll b/llvm/test/CodeGen/NVPTX/combine-mad.ll
new file mode 100644
index 0000000000000..382856dfe76c3
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/combine-mad.ll
@@ -0,0 +1,101 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 -O1 | FileCheck %s
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -O1 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_20 -O1 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_20 -O1 | %ptxas-verify %}
+
+define i32 @test1(i32 %n, i32 %m) {
+;
+; CHECK: ld.param.u32   %[[N:r[0-9]+]], [test1_param_0];
+; CHECK: ld.param.u32   %[[M:r[0-9]+]], [test1_param_1];
+; CHECK: mad.lo.s32     %[[MAD:r[0-9]+]], %[[M]], %[[N]], %[[M]];
+; CHECK: st.param.b32   [func_retval0+0], %[[MAD]];
+;
+  %add = add i32 %n, 1
+  %mul = mul i32 %add, %m
+  ret i32 %mul
+}
+
+define i32 @test1_rev(i32 %n, i32 %m) {
+;
+; CHECK: ld.param.u32   %[[N:r[0-9]+]], [test1_rev_param_0];
+; CHECK: ld.param.u32   %[[M:r[0-9]+]], [test1_rev_param_1];
+; CHECK: mad.lo.s32     %[[MAD:r[0-9]+]], %[[M]], %[[N]], %[[M]];
+; CHECK: st.param.b32   [func_retval0+0], %[[MAD]];
+;
+  %add = add i32 %n, 1
+  %mul = mul i32 %m, %add
+  ret i32 %mul
+}
+
+; Transpose (mul (select)) if it can then be folded to mad
+define i32 @test2(i32 %n, i32 %m, i32 %s) {
+;
+; CHECK: ld.param.u32   %[[N:r[0-9]+]], [test2_param_0];
+; CHECK: ld.param.u32   %[[M:r[0-9]+]], [test2_param_1];
+; CHECK: ld.param.u32   %[[S:r[0-9]+]], [test2_param_2];
+; CHECK: setp.lt.s32    %[[COND:p[0-9]+]], %[[S]], 1;
+; CHECK: mad.lo.s32     %[[MAD:r[0-9]+]], %[[M]], %[[N]], %[[M]];
+; CHECK: selp.b32       %[[SEL:r[0-9]+]], %[[M]], %[[MAD]], %[[COND]];
+; CHECK: st.param.b32   [func_retval0+0], %[[SEL]];
+;
+  %add = add i32 %n, 1
+  %cond = icmp slt i32 %s, 1
+  %sel = select i1 %cond, i32 1, i32 %add
+  %mul = mul i32 %sel, %m
+  ret i32 %mul
+}
+
+;; Transpose (mul (select)) if it can then be folded to mad
+define i32 @test2_rev1(i32 %n, i32 %m, i32 %s) {
+;
+; CHECK: ld.param.u32   %[[N:r[0-9]+]], [test2_rev1_param_0];
+; CHECK: ld.param.u32   %[[M:r[0-9]+]], [test2_rev1_param_1];
+; CHECK: ld.param.u32   %[[S:r[0-9]+]], [test2_rev1_param_2];
+; CHECK: setp.lt.s32    %[[COND:p[0-9]+]], %[[S]], 1;
+; CHECK: mad.lo.s32     %[[MAD:r[0-9]+]], %[[M]], %[[N]], %[[M]];
+; CHECK: selp.b32       %[[SEL:r[0-9]+]], %[[MAD]], %[[M]], %[[COND]];
+; CHECK: st.param.b32   [func_retval0+0], %[[SEL]];
+;
+  %add = add i32 %n, 1
+  %cond = icmp slt i32 %s, 1
+  %sel = select i1 %cond, i32 %add, i32 1
+  %mul = mul i32 %sel, %m
+  ret i32 %mul
+}
+
+;; Transpose (mul (select)) if it can then be folded to mad
+define i32 @test2_rev2(i32 %n, i32 %m, i32 %s) {
+;
+; CHECK: ld.param.u32   %[[N:r[0-9]+]], [test2_rev2_param_0];
+; CHECK: ld.param.u32   %[[M:r[0-9]+]], [test2_rev2_param_1];
+; CHECK: ld.param.u32   %[[S:r[0-9]+]], [test2_rev2_param_2];
+; CHECK: setp.lt.s32    %[[COND:p[0-9]+]], %[[S]], 1;
+; CHECK: mad.lo.s32     %[[MAD:r[0-9]+]], %[[M]], %[[N]], %[[M]];
+; CHECK: selp.b32       %[[SEL:r[0-9]+]], %[[MAD]], %[[M]], %[[COND]];
+; CHECK: st.param.b32   [func_retval0+0], %[[SEL]];
+;
+  %add = add i32 %n, 1
+  %cond = icmp slt i32 %s, 1
+  %sel = select i1 %cond, i32 %add, i32 1
+  %mul = mul i32  %m, %sel
+  ret i32 %mul
+}
+
+;; Leave (mul (select)) intact if it transposing is not profitable
+define i32 @test3(i32 %n, i32 %m, i32 %s) {
+;
+; CHECK: ld.param.u32   %[[N:r[0-9]+]], [test3_param_0];
+; CHECK: add.s32        %[[ADD:r[0-9]+]], %[[N]], 3;
+; CHECK: ld.param.u32   %[[M:r[0-9]+]], [test3_param_1];
+; CHECK: ld.param.u32   %[[S:r[0-9]+]], [test3_param_2];
+; CHECK: setp.lt.s32    %[[COND:p[0-9]+]], %[[S]], 1;
+; CHECK: selp.b32       %[[SEL:r[0-9]+]], 1, %[[ADD]], %[[COND]];
+; CHECK: mul.lo.s32     %[[MUL:r[0-9]+]], %[[SEL]], %[[M]];
+; CHECK: st.param.b32   [func_retval0+0], %[[MUL]];
+;
+  %add = add i32 %n, 3
+  %cond = icmp slt i32 %s, 1
+  %sel = select i1 %cond, i32 1, i32 %add
+  %mul = mul i32 %sel, %m
+  ret i32 %mul
+}

>From af084ad72fded7e83f2de75c5dbc2c7bb1d79a73 Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Wed, 29 May 2024 15:40:53 +0000
Subject: [PATCH 2/3] address comments: match 1 on either add operand, autogenerate test checks

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp |  19 +--
 llvm/test/CodeGen/NVPTX/combine-mad.ll      | 129 +++++++++++++-------
 2 files changed, 93 insertions(+), 55 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 304d1984edd54..30eb742658ff5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5614,22 +5614,25 @@ static SDValue TryMULWIDECombine(SDNode *N,
   return DCI.DAG.getNode(Opc, DL, MulType, TruncLHS, TruncRHS);
 }
 
-static SDValue matchMADConstOnePattern(SDValue X, SDValue Add) {
+static SDValue matchMADConstOnePattern(SDValue Add) {
   if (Add->getOpcode() != ISD::ADD)
     return SDValue();
 
-  SDValue Y = Add->getOperand(0);
-  ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Add->getOperand(1));
-  if (!Const || Const->getZExtValue() != 1)
-    return SDValue();
+  if (const auto *Const0 = dyn_cast<ConstantSDNode>(Add->getOperand(0)))
+    if (Const0->getZExtValue() == 1)
+      return Add->getOperand(1);
+
+  if (const auto *Const1 = dyn_cast<ConstantSDNode>(Add->getOperand(1)))
+    if (Const1->getZExtValue() == 1)
+      return Add->getOperand(0);
 
-  return Y;
+  return SDValue();
 }
 
 static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
                                   TargetLowering::DAGCombinerInfo &DCI) {
 
-  if (SDValue Y = matchMADConstOnePattern(X, Add))
+  if (SDValue Y = matchMADConstOnePattern(Add))
     return DCI.DAG.getNode(NVPTXISD::IMAD, DL, VT, X, Y, X);
 
   return SDValue();
@@ -5655,7 +5658,7 @@ static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
   SDValue Y = Select->getOperand((ConstOpNo == 1) ? 2 : 1);
 
   // Do not combine if the resulting sequence is not obviously profitable.
-  if (!matchMADConstOnePattern(X, Y))
+  if (!matchMADConstOnePattern(Y))
     return SDValue();
 
   SDValue NewMul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);
diff --git a/llvm/test/CodeGen/NVPTX/combine-mad.ll b/llvm/test/CodeGen/NVPTX/combine-mad.ll
index 382856dfe76c3..fba389afdca39 100644
--- a/llvm/test/CodeGen/NVPTX/combine-mad.ll
+++ b/llvm/test/CodeGen/NVPTX/combine-mad.ll
@@ -1,15 +1,21 @@
-; RUN: llc < %s -march=nvptx -mcpu=sm_20 -O1 | FileCheck %s
-; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -O1 | FileCheck %s
-; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_20 -O1 | %ptxas-verify %}
-; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_20 -O1 | %ptxas-verify %}
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx -mcpu=sm_20 -O1 | FileCheck %s
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -O1 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx -mcpu=sm_20 -O1 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -O1 | %ptxas-verify %}
 
 define i32 @test1(i32 %n, i32 %m) {
 ;
-; CHECK: ld.param.u32   %[[N:r[0-9]+]], [test1_param_0];
-; CHECK: ld.param.u32   %[[M:r[0-9]+]], [test1_param_1];
-; CHECK: mad.lo.s32     %[[MAD:r[0-9]+]], %[[M]], %[[N]], %[[M]];
-; CHECK: st.param.b32   [func_retval0+0], %[[MAD]];
-;
+; CHECK-LABEL: test1(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test1_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test1_param_1];
+; CHECK-NEXT:    mad.lo.s32 %r3, %r2, %r1, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r3;
+; CHECK-NEXT:    ret;
   %add = add i32 %n, 1
   %mul = mul i32 %add, %m
   ret i32 %mul
@@ -17,11 +23,16 @@ define i32 @test1(i32 %n, i32 %m) {
 
 define i32 @test1_rev(i32 %n, i32 %m) {
 ;
-; CHECK: ld.param.u32   %[[N:r[0-9]+]], [test1_rev_param_0];
-; CHECK: ld.param.u32   %[[M:r[0-9]+]], [test1_rev_param_1];
-; CHECK: mad.lo.s32     %[[MAD:r[0-9]+]], %[[M]], %[[N]], %[[M]];
-; CHECK: st.param.b32   [func_retval0+0], %[[MAD]];
-;
+; CHECK-LABEL: test1_rev(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test1_rev_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test1_rev_param_1];
+; CHECK-NEXT:    mad.lo.s32 %r3, %r2, %r1, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r3;
+; CHECK-NEXT:    ret;
   %add = add i32 %n, 1
   %mul = mul i32 %m, %add
   ret i32 %mul
@@ -30,14 +41,20 @@ define i32 @test1_rev(i32 %n, i32 %m) {
 ; Transpose (mul (select)) if it can then be folded to mad
 define i32 @test2(i32 %n, i32 %m, i32 %s) {
 ;
-; CHECK: ld.param.u32   %[[N:r[0-9]+]], [test2_param_0];
-; CHECK: ld.param.u32   %[[M:r[0-9]+]], [test2_param_1];
-; CHECK: ld.param.u32   %[[S:r[0-9]+]], [test2_param_2];
-; CHECK: setp.lt.s32    %[[COND:p[0-9]+]], %[[S]], 1;
-; CHECK: mad.lo.s32     %[[MAD:r[0-9]+]], %[[M]], %[[N]], %[[M]];
-; CHECK: selp.b32       %[[SEL:r[0-9]+]], %[[M]], %[[MAD]], %[[COND]];
-; CHECK: st.param.b32   [func_retval0+0], %[[SEL]];
-;
+; CHECK-LABEL: test2(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test2_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test2_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test2_param_2];
+; CHECK-NEXT:    setp.lt.s32 %p1, %r3, 1;
+; CHECK-NEXT:    mad.lo.s32 %r4, %r2, %r1, %r2;
+; CHECK-NEXT:    selp.b32 %r5, %r2, %r4, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r5;
+; CHECK-NEXT:    ret;
   %add = add i32 %n, 1
   %cond = icmp slt i32 %s, 1
   %sel = select i1 %cond, i32 1, i32 %add
@@ -48,14 +65,20 @@ define i32 @test2(i32 %n, i32 %m, i32 %s) {
 ;; Transpose (mul (select)) if it can then be folded to mad
 define i32 @test2_rev1(i32 %n, i32 %m, i32 %s) {
 ;
-; CHECK: ld.param.u32   %[[N:r[0-9]+]], [test2_rev1_param_0];
-; CHECK: ld.param.u32   %[[M:r[0-9]+]], [test2_rev1_param_1];
-; CHECK: ld.param.u32   %[[S:r[0-9]+]], [test2_rev1_param_2];
-; CHECK: setp.lt.s32    %[[COND:p[0-9]+]], %[[S]], 1;
-; CHECK: mad.lo.s32     %[[MAD:r[0-9]+]], %[[M]], %[[N]], %[[M]];
-; CHECK: selp.b32       %[[SEL:r[0-9]+]], %[[MAD]], %[[M]], %[[COND]];
-; CHECK: st.param.b32   [func_retval0+0], %[[SEL]];
-;
+; CHECK-LABEL: test2_rev1(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test2_rev1_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test2_rev1_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test2_rev1_param_2];
+; CHECK-NEXT:    setp.lt.s32 %p1, %r3, 1;
+; CHECK-NEXT:    mad.lo.s32 %r4, %r2, %r1, %r2;
+; CHECK-NEXT:    selp.b32 %r5, %r4, %r2, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r5;
+; CHECK-NEXT:    ret;
   %add = add i32 %n, 1
   %cond = icmp slt i32 %s, 1
   %sel = select i1 %cond, i32 %add, i32 1
@@ -66,14 +89,20 @@ define i32 @test2_rev1(i32 %n, i32 %m, i32 %s) {
 ;; Transpose (mul (select)) if it can then be folded to mad
 define i32 @test2_rev2(i32 %n, i32 %m, i32 %s) {
 ;
-; CHECK: ld.param.u32   %[[N:r[0-9]+]], [test2_rev2_param_0];
-; CHECK: ld.param.u32   %[[M:r[0-9]+]], [test2_rev2_param_1];
-; CHECK: ld.param.u32   %[[S:r[0-9]+]], [test2_rev2_param_2];
-; CHECK: setp.lt.s32    %[[COND:p[0-9]+]], %[[S]], 1;
-; CHECK: mad.lo.s32     %[[MAD:r[0-9]+]], %[[M]], %[[N]], %[[M]];
-; CHECK: selp.b32       %[[SEL:r[0-9]+]], %[[MAD]], %[[M]], %[[COND]];
-; CHECK: st.param.b32   [func_retval0+0], %[[SEL]];
-;
+; CHECK-LABEL: test2_rev2(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test2_rev2_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test2_rev2_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test2_rev2_param_2];
+; CHECK-NEXT:    setp.lt.s32 %p1, %r3, 1;
+; CHECK-NEXT:    mad.lo.s32 %r4, %r2, %r1, %r2;
+; CHECK-NEXT:    selp.b32 %r5, %r4, %r2, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r5;
+; CHECK-NEXT:    ret;
   %add = add i32 %n, 1
   %cond = icmp slt i32 %s, 1
   %sel = select i1 %cond, i32 %add, i32 1
@@ -84,15 +113,21 @@ define i32 @test2_rev2(i32 %n, i32 %m, i32 %s) {
 ;; Leave (mul (select)) intact if it transposing is not profitable
 define i32 @test3(i32 %n, i32 %m, i32 %s) {
 ;
-; CHECK: ld.param.u32   %[[N:r[0-9]+]], [test3_param_0];
-; CHECK: add.s32        %[[ADD:r[0-9]+]], %[[N]], 3;
-; CHECK: ld.param.u32   %[[M:r[0-9]+]], [test3_param_1];
-; CHECK: ld.param.u32   %[[S:r[0-9]+]], [test3_param_2];
-; CHECK: setp.lt.s32    %[[COND:p[0-9]+]], %[[S]], 1;
-; CHECK: selp.b32       %[[SEL:r[0-9]+]], 1, %[[ADD]], %[[COND]];
-; CHECK: mul.lo.s32     %[[MUL:r[0-9]+]], %[[SEL]], %[[M]];
-; CHECK: st.param.b32   [func_retval0+0], %[[MUL]];
-;
+; CHECK-LABEL: test3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b32 %r<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test3_param_0];
+; CHECK-NEXT:    add.s32 %r2, %r1, 3;
+; CHECK-NEXT:    ld.param.u32 %r3, [test3_param_1];
+; CHECK-NEXT:    ld.param.u32 %r4, [test3_param_2];
+; CHECK-NEXT:    setp.lt.s32 %p1, %r4, 1;
+; CHECK-NEXT:    selp.b32 %r5, 1, %r2, %p1;
+; CHECK-NEXT:    mul.lo.s32 %r6, %r5, %r3;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r6;
+; CHECK-NEXT:    ret;
   %add = add i32 %n, 3
   %cond = icmp slt i32 %s, 1
   %sel = select i1 %cond, i32 1, i32 %add

>From d0c2760a85eca13d89365d7c2af9862a2aaecdc5 Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Wed, 29 May 2024 22:56:03 +0000
Subject: [PATCH 3/3] address comments: factor out isConstOne helper

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 28 +++++++++++----------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 30eb742658ff5..f4ef7c9914f13 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5614,17 +5614,20 @@ static SDValue TryMULWIDECombine(SDNode *N,
   return DCI.DAG.getNode(Opc, DL, MulType, TruncLHS, TruncRHS);
 }
 
+static bool isConstOne(const SDValue &Operand) {
+  const auto *Const = dyn_cast<ConstantSDNode>(Operand);
+  return Const && Const->getZExtValue() == 1;
+}
+
 static SDValue matchMADConstOnePattern(SDValue Add) {
   if (Add->getOpcode() != ISD::ADD)
     return SDValue();
 
-  if (const auto *Const0 = dyn_cast<ConstantSDNode>(Add->getOperand(0)))
-    if (Const0->getZExtValue() == 1)
-      return Add->getOperand(1);
+  if (isConstOne(Add->getOperand(0)))
+    return Add->getOperand(1);
 
-  if (const auto *Const1 = dyn_cast<ConstantSDNode>(Add->getOperand(1)))
-    if (Const1->getZExtValue() == 1)
-      return Add->getOperand(0);
+  if (isConstOne(Add->getOperand(1)))
+    return Add->getOperand(0);
 
   return SDValue();
 }
@@ -5646,14 +5649,13 @@ static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
 
   SDValue Cond = Select->getOperand(0);
 
-  unsigned ConstOpNo = 1;
-  auto *Const = dyn_cast<ConstantSDNode>(Select->getOperand(ConstOpNo));
-  if (!Const || Const->getZExtValue() != 1) {
+  unsigned ConstOpNo;
+  if (isConstOne(Select->getOperand(1)))
+    ConstOpNo = 1;
+  else if (isConstOne(Select->getOperand(2)))
     ConstOpNo = 2;
-    Const = dyn_cast<ConstantSDNode>(Select->getOperand(ConstOpNo));
-    if (!Const || Const->getZExtValue() != 1)
-      return SDValue();
-  }
+  else
+    return SDValue();
 
   SDValue Y = Select->getOperand((ConstOpNo == 1) ? 2 : 1);
 



More information about the llvm-commits mailing list