[llvm] f32ebab - [NVPTX] Improve folding to mad with immediate 1 (#93628)
via llvm-commits
llvm-commits at lists.llvm.org
Wed May 29 18:09:25 PDT 2024
Author: Alex MacLean
Date: 2024-05-29T18:09:21-07:00
New Revision: f32ebabc27655a1bd26ccdede1610d8d1a05315f
URL: https://github.com/llvm/llvm-project/commit/f32ebabc27655a1bd26ccdede1610d8d1a05315f
DIFF: https://github.com/llvm/llvm-project/commit/f32ebabc27655a1bd26ccdede1610d8d1a05315f.diff
LOG: [NVPTX] Improve folding to mad with immediate 1 (#93628)
Extend NVPTX DAG combining logic to distribute a mul instruction across
an add of 1 into a mad where possible. In addition, add support for
transposing a mul through a select with an option of 1, if that would
allow further mul folding.
Added:
llvm/test/CodeGen/NVPTX/combine-mad.ll
Modified:
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 1e7477cf9d60e..f4ef7c9914f13 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5614,17 +5614,103 @@ static SDValue TryMULWIDECombine(SDNode *N,
return DCI.DAG.getNode(Opc, DL, MulType, TruncLHS, TruncRHS);
}
+static bool isConstOne(const SDValue &Operand) {
+ const auto *Const = dyn_cast<ConstantSDNode>(Operand);
+ return Const && Const->getZExtValue() == 1;
+}
+
+static SDValue matchMADConstOnePattern(SDValue Add) {
+ if (Add->getOpcode() != ISD::ADD)
+ return SDValue();
+
+ if (isConstOne(Add->getOperand(0)))
+ return Add->getOperand(1);
+
+ if (isConstOne(Add->getOperand(1)))
+ return Add->getOperand(0);
+
+ return SDValue();
+}
+
+static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
+ TargetLowering::DAGCombinerInfo &DCI) {
+
+ if (SDValue Y = matchMADConstOnePattern(Add))
+ return DCI.DAG.getNode(NVPTXISD::IMAD, DL, VT, X, Y, X);
+
+ return SDValue();
+}
+
+static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
+ SDLoc DL,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ if (Select->getOpcode() != ISD::SELECT)
+ return SDValue();
+
+ SDValue Cond = Select->getOperand(0);
+
+ unsigned ConstOpNo;
+ if (isConstOne(Select->getOperand(1)))
+ ConstOpNo = 1;
+ else if (isConstOne(Select->getOperand(2)))
+ ConstOpNo = 2;
+ else
+ return SDValue();
+
+ SDValue Y = Select->getOperand((ConstOpNo == 1) ? 2 : 1);
+
+ // Do not combine if the resulting sequence is not obviously profitable.
+ if (!matchMADConstOnePattern(Y))
+ return SDValue();
+
+ SDValue NewMul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);
+
+ return DCI.DAG.getNode(ISD::SELECT, DL, VT, Cond,
+ (ConstOpNo == 1) ? X : NewMul,
+ (ConstOpNo == 1) ? NewMul : X);
+}
+
+static SDValue
+PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
+ TargetLowering::DAGCombinerInfo &DCI) {
+
+ EVT VT = N0.getValueType();
+ if (VT.isVector())
+ return SDValue();
+
+ if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
+ return SDValue();
+
+ SDLoc DL(N);
+
+ // (mul x, (add y, 1)) -> (mad x, y, x)
+ if (SDValue Res = combineMADConstOne(N0, N1, VT, DL, DCI))
+ return Res;
+ if (SDValue Res = combineMADConstOne(N1, N0, VT, DL, DCI))
+ return Res;
+
+ // (mul x, (select c, y, 1)) -> (select c, (mul x, y), x), or the mirrored form
+ if (SDValue Res = combineMulSelectConstOne(N0, N1, VT, DL, DCI))
+ return Res;
+ if (SDValue Res = combineMulSelectConstOne(N1, N0, VT, DL, DCI))
+ return Res;
+
+ return SDValue();
+}
+
/// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
static SDValue PerformMULCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
- if (OptLevel > CodeGenOptLevel::None) {
- // Try mul.wide combining at OptLevel > 0
- if (SDValue Ret = TryMULWIDECombine(N, DCI))
- return Ret;
- }
+ if (OptLevel == CodeGenOptLevel::None)
+ return SDValue();
- return SDValue();
+ if (SDValue Ret = TryMULWIDECombine(N, DCI))
+ return Ret;
+
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ return PerformMULCombineWithOperands(N, N0, N1, DCI);
}
/// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
diff --git a/llvm/test/CodeGen/NVPTX/combine-mad.ll b/llvm/test/CodeGen/NVPTX/combine-mad.ll
new file mode 100644
index 0000000000000..fba389afdca39
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/combine-mad.ll
@@ -0,0 +1,136 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx -mcpu=sm_20 -O1 | FileCheck %s
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -O1 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx -mcpu=sm_20 -O1 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -O1 | %ptxas-verify %}
+
+define i32 @test1(i32 %n, i32 %m) {
+;
+; CHECK-LABEL: test1(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test1_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test1_param_1];
+; CHECK-NEXT: mad.lo.s32 %r3, %r2, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0+0], %r3;
+; CHECK-NEXT: ret;
+ %add = add i32 %n, 1
+ %mul = mul i32 %add, %m
+ ret i32 %mul
+}
+
+define i32 @test1_rev(i32 %n, i32 %m) {
+;
+; CHECK-LABEL: test1_rev(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test1_rev_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test1_rev_param_1];
+; CHECK-NEXT: mad.lo.s32 %r3, %r2, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0+0], %r3;
+; CHECK-NEXT: ret;
+ %add = add i32 %n, 1
+ %mul = mul i32 %m, %add
+ ret i32 %mul
+}
+
+; Transpose (mul (select)) if it can then be folded to mad
+define i32 @test2(i32 %n, i32 %m, i32 %s) {
+;
+; CHECK-LABEL: test2(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<2>;
+; CHECK-NEXT: .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test2_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test2_param_1];
+; CHECK-NEXT: ld.param.u32 %r3, [test2_param_2];
+; CHECK-NEXT: setp.lt.s32 %p1, %r3, 1;
+; CHECK-NEXT: mad.lo.s32 %r4, %r2, %r1, %r2;
+; CHECK-NEXT: selp.b32 %r5, %r2, %r4, %p1;
+; CHECK-NEXT: st.param.b32 [func_retval0+0], %r5;
+; CHECK-NEXT: ret;
+ %add = add i32 %n, 1
+ %cond = icmp slt i32 %s, 1
+ %sel = select i1 %cond, i32 1, i32 %add
+ %mul = mul i32 %sel, %m
+ ret i32 %mul
+}
+
+;; Transpose (mul (select)) if it can then be folded to mad
+define i32 @test2_rev1(i32 %n, i32 %m, i32 %s) {
+;
+; CHECK-LABEL: test2_rev1(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<2>;
+; CHECK-NEXT: .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test2_rev1_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test2_rev1_param_1];
+; CHECK-NEXT: ld.param.u32 %r3, [test2_rev1_param_2];
+; CHECK-NEXT: setp.lt.s32 %p1, %r3, 1;
+; CHECK-NEXT: mad.lo.s32 %r4, %r2, %r1, %r2;
+; CHECK-NEXT: selp.b32 %r5, %r4, %r2, %p1;
+; CHECK-NEXT: st.param.b32 [func_retval0+0], %r5;
+; CHECK-NEXT: ret;
+ %add = add i32 %n, 1
+ %cond = icmp slt i32 %s, 1
+ %sel = select i1 %cond, i32 %add, i32 1
+ %mul = mul i32 %sel, %m
+ ret i32 %mul
+}
+
+;; Transpose (mul (select)) if it can then be folded to mad
+define i32 @test2_rev2(i32 %n, i32 %m, i32 %s) {
+;
+; CHECK-LABEL: test2_rev2(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<2>;
+; CHECK-NEXT: .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test2_rev2_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test2_rev2_param_1];
+; CHECK-NEXT: ld.param.u32 %r3, [test2_rev2_param_2];
+; CHECK-NEXT: setp.lt.s32 %p1, %r3, 1;
+; CHECK-NEXT: mad.lo.s32 %r4, %r2, %r1, %r2;
+; CHECK-NEXT: selp.b32 %r5, %r4, %r2, %p1;
+; CHECK-NEXT: st.param.b32 [func_retval0+0], %r5;
+; CHECK-NEXT: ret;
+ %add = add i32 %n, 1
+ %cond = icmp slt i32 %s, 1
+ %sel = select i1 %cond, i32 %add, i32 1
+ %mul = mul i32 %m, %sel
+ ret i32 %mul
+}
+
+;; Leave (mul (select)) intact if transposing it is not profitable
+define i32 @test3(i32 %n, i32 %m, i32 %s) {
+;
+; CHECK-LABEL: test3(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<2>;
+; CHECK-NEXT: .reg .b32 %r<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test3_param_0];
+; CHECK-NEXT: add.s32 %r2, %r1, 3;
+; CHECK-NEXT: ld.param.u32 %r3, [test3_param_1];
+; CHECK-NEXT: ld.param.u32 %r4, [test3_param_2];
+; CHECK-NEXT: setp.lt.s32 %p1, %r4, 1;
+; CHECK-NEXT: selp.b32 %r5, 1, %r2, %p1;
+; CHECK-NEXT: mul.lo.s32 %r6, %r5, %r3;
+; CHECK-NEXT: st.param.b32 [func_retval0+0], %r6;
+; CHECK-NEXT: ret;
+ %add = add i32 %n, 3
+ %cond = icmp slt i32 %s, 1
+ %sel = select i1 %cond, i32 1, i32 %add
+ %mul = mul i32 %sel, %m
+ ret i32 %mul
+}
More information about the llvm-commits
mailing list