[llvm] [NVPTX][SelectionDAG] Add IMAD combine rules + infra to disable default SelectionDAG rules for testing (PR #121724)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Jan 5 19:09:55 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag
Author: None (peterbell10)
<details>
<summary>Changes</summary>
I noticed that NVPTX will sometimes emit `mad.lo` to multiply by 1, e.g. in https://gcc.godbolt.org/z/45W3Wcnxz.
This happens when DAGCombiner operates on the add before the mul, so the imad contraction happens regardless of whether the mul could have been simplified.
To fix this, I add some combiner patterns for IMAD. In particular, this PR adds:
```
mad x 1 y => add x y
mad x -1 y => sub y x
mad x 0 y => y
mad x y 0 => mul x y
mad c0 c1 z => add z (c0 * c1)
```
Another option might be to remove `NVPTXISD::IMAD` and only combine to mad during selection. This would allow the default DAGCombiner patterns to simplify the graph without any NVPTX-specific intervention. However, it also risks DAGCombiner breaking up the mul-add patterns, which is why I haven't done it that way.
I found testing this change to be quite tricky: there is no mad intrinsic, so we have to write `add (mul x y) z` in LLVM IR, which is simplified by `SelectionDAG::getNode` before we even get to the `DAGCombiner`. So, I've also added a couple of debug flags to disable the default simplifications:
```
--selectiondag-simplify-nodes=false # disables simplifications in SelectionDAG::getNode
--combiner-generic-combines=false # disables generic DAGCombiner patterns
```
For the `SelectionDAG` flag there are a significant number of lines changed, but most of them remove code that had been copy-pasted into all of the `getNode` implementations and factor it out into a private `getNodeImpl`, so I think this is okay. I'd be happy to split this into a separate PR if required.
---
Full diff: https://github.com/llvm/llvm-project/pull/121724.diff
7 Files Affected:
- (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+5)
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+9-1)
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+51-138)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+49)
- (added) llvm/test/CodeGen/NVPTX/combine-mad-only.ll (+87)
- (modified) llvm/test/CodeGen/NVPTX/combine-mad.ll (+20)
- (modified) llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll (+1-1)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index ff7caec41855fd..3a015c8df2066a 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2460,6 +2460,11 @@ class SelectionDAG {
SDNode *FindNodeOrInsertPos(const FoldingSetNodeID &ID, const SDLoc &DL,
void *&InsertPos);
+  /// Build or CSE a node as-is (no getNode folding); glue nodes aren't CSE'd.
+  SDValue getNodeImpl(unsigned Opcode, const SDLoc &DL, EVT VT,
+                      ArrayRef<SDValue> Ops, SDNodeFlags Flags);
+  SDValue getNodeImpl(unsigned Opcode, const SDLoc &DL, SDVTList VTs,
+                      ArrayRef<SDValue> Ops, SDNodeFlags Flags);
/// Maps to auto-CSE operations.
std::vector<CondCodeSDNode*> CondCodeNodes;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6b2501591c81a3..6d75809cdaf69f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -153,6 +153,13 @@ static cl::opt<bool> EnableVectorFCopySignExtendRound(
"combiner-vector-fcopysign-extend-round", cl::Hidden, cl::init(false),
cl::desc(
"Enable merging extends and rounds into FCOPYSIGN on vector types"));
+
+static cl::opt<bool>
+    EnableGenericCombines("combiner-generic-combines", cl::Hidden,
+                          cl::init(true), // false => DisableGenericCombines
+                          cl::desc("Enable generic DAGCombine patterns. Useful "
+                                   "for testing target-specific combines."));
+
namespace {
class DAGCombiner {
@@ -251,7 +258,8 @@ namespace {
: DAG(D), TLI(D.getTargetLoweringInfo()),
STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL), AA(AA) {
ForCodeSize = DAG.shouldOptForSize();
- DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel);
+ DisableGenericCombines = !EnableGenericCombines ||
+ (STI && STI->disableGenericCombines(OptLevel));
MaximumLegalStoreInBits = 0;
// We use the minimum store size here, since that's all we can guarantee
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 10e8ba93359fbd..6a3799e02edd94 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -115,6 +115,10 @@ static cl::opt<unsigned>
MaxSteps("has-predecessor-max-steps", cl::Hidden, cl::init(8192),
cl::desc("DAG combiner limit number of steps when searching DAG "
"for predecessor nodes"));
+static cl::opt<bool> EnableSimplifyNodes( // false: getNode skips its folds
+    "selectiondag-simplify-nodes", cl::Hidden, cl::init(true),
+    cl::desc("Enable SelectionDAG::getNode simplifications. Useful for testing "
+             "DAG combines."));
static void NewSDValueDbgMsg(SDValue V, StringRef Msg, SelectionDAG *G) {
LLVM_DEBUG(dbgs() << Msg; V.getNode()->dump(G););
@@ -6157,23 +6161,46 @@ static SDValue foldCONCAT_VECTORS(const SDLoc &DL, EVT VT,
}
/// Gets or creates the specified node.
-SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT) {
+SDValue SelectionDAG::getNodeImpl(unsigned Opcode, const SDLoc &DL, EVT VT,
+ ArrayRef<SDValue> Ops,
+ const SDNodeFlags Flags) {
SDVTList VTs = getVTList(VT);
- FoldingSetNodeID ID;
- AddNodeIDNode(ID, Opcode, VTs, {});
- void *IP = nullptr;
- if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP))
- return SDValue(E, 0);
+ return getNodeImpl(Opcode, DL, VTs, Ops, Flags);
+}
- auto *N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- CSEMap.InsertNode(N, IP);
+SDValue SelectionDAG::getNodeImpl(unsigned Opcode, const SDLoc &DL,
+                                  SDVTList VTs, ArrayRef<SDValue> Ops,
+                                  const SDNodeFlags Flags) {
+  SDNode *N;
+  // Only CSE nodes whose last result type is not MVT::Glue.
+  if (VTs.VTs[VTs.NumVTs - 1] != MVT::Glue) {
+    FoldingSetNodeID ID;
+    AddNodeIDNode(ID, Opcode, VTs, Ops);
+    void *IP = nullptr;
+    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
+      E->intersectFlagsWith(Flags); // Reuse the existing node.
+      return SDValue(E, 0);
+    }
+    // No match: build the node and register it in the CSE map at IP.
+    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
+    createOperands(N, Ops);
+    CSEMap.InsertNode(N, IP);
+  } else {
+    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
+    createOperands(N, Ops);
+  }
+  N->setFlags(Flags); // Applies on both the CSE and the glue path.
InsertNode(N);
SDValue V = SDValue(N, 0);
NewSDValueDbgMsg(V, "Creating new node: ", this);
return V;
}
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT) {
+  return getNodeImpl(Opcode, DL, VT, /*Ops=*/{}, /*Flags=*/SDNodeFlags());
+}
+
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
SDValue N1) {
SDNodeFlags Flags;
@@ -6185,6 +6212,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
SDValue N1, const SDNodeFlags Flags) {
assert(N1.getOpcode() != ISD::DELETED_NODE && "Operand is DELETED_NODE!");
+ if (!EnableSimplifyNodes)
+ return getNodeImpl(Opcode, DL, VT, {N1}, Flags);
// Constant fold unary operations with a vector integer or float operand.
switch (Opcode) {
@@ -6501,31 +6530,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
break;
}
- SDNode *N;
- SDVTList VTs = getVTList(VT);
- SDValue Ops[] = {N1};
- if (VT != MVT::Glue) { // Don't CSE glue producing nodes
- FoldingSetNodeID ID;
- AddNodeIDNode(ID, Opcode, VTs, Ops);
- void *IP = nullptr;
- if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
- E->intersectFlagsWith(Flags);
- return SDValue(E, 0);
- }
-
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- N->setFlags(Flags);
- createOperands(N, Ops);
- CSEMap.InsertNode(N, IP);
- } else {
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- createOperands(N, Ops);
- }
-
- InsertNode(N);
- SDValue V = SDValue(N, 0);
- NewSDValueDbgMsg(V, "Creating new node: ", this);
- return V;
+ return getNodeImpl(Opcode, DL, VT, {N1}, Flags);
}
static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
@@ -7219,6 +7224,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
assert(N1.getOpcode() != ISD::DELETED_NODE &&
N2.getOpcode() != ISD::DELETED_NODE &&
"Operand is DELETED_NODE!");
+ if (!EnableSimplifyNodes)
+ return getNodeImpl(Opcode, DL, VT, {N1, N2}, Flags);
canonicalizeCommutativeBinop(Opcode, N1, N2);
@@ -7665,32 +7672,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
}
}
- // Memoize this node if possible.
- SDNode *N;
- SDVTList VTs = getVTList(VT);
- SDValue Ops[] = {N1, N2};
- if (VT != MVT::Glue) {
- FoldingSetNodeID ID;
- AddNodeIDNode(ID, Opcode, VTs, Ops);
- void *IP = nullptr;
- if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
- E->intersectFlagsWith(Flags);
- return SDValue(E, 0);
- }
-
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- N->setFlags(Flags);
- createOperands(N, Ops);
- CSEMap.InsertNode(N, IP);
- } else {
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- createOperands(N, Ops);
- }
-
- InsertNode(N);
- SDValue V = SDValue(N, 0);
- NewSDValueDbgMsg(V, "Creating new node: ", this);
- return V;
+ return getNodeImpl(Opcode, DL, VT, {N1, N2}, Flags);
}
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
@@ -7708,6 +7690,9 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
N2.getOpcode() != ISD::DELETED_NODE &&
N3.getOpcode() != ISD::DELETED_NODE &&
"Operand is DELETED_NODE!");
+ if (!EnableSimplifyNodes)
+ return getNodeImpl(Opcode, DL, VT, {N1, N2, N3}, Flags);
+
// Perform various simplifications.
switch (Opcode) {
case ISD::FMA:
@@ -7862,33 +7847,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
break;
}
}
-
- // Memoize node if it doesn't produce a glue result.
- SDNode *N;
- SDVTList VTs = getVTList(VT);
- SDValue Ops[] = {N1, N2, N3};
- if (VT != MVT::Glue) {
- FoldingSetNodeID ID;
- AddNodeIDNode(ID, Opcode, VTs, Ops);
- void *IP = nullptr;
- if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
- E->intersectFlagsWith(Flags);
- return SDValue(E, 0);
- }
-
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- N->setFlags(Flags);
- createOperands(N, Ops);
- CSEMap.InsertNode(N, IP);
- } else {
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- createOperands(N, Ops);
- }
-
- InsertNode(N);
- SDValue V = SDValue(N, 0);
- NewSDValueDbgMsg(V, "Creating new node: ", this);
- return V;
+ return getNodeImpl(Opcode, DL, VT, {N1, N2, N3}, Flags);
}
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
@@ -10343,6 +10302,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
assert(Op.getOpcode() != ISD::DELETED_NODE &&
"Operand is DELETED_NODE!");
#endif
+ if (!EnableSimplifyNodes)
+ return getNodeImpl(Opcode, DL, VT, Ops, Flags);
switch (Opcode) {
default: break;
@@ -10411,34 +10372,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
break;
}
- // Memoize nodes.
- SDNode *N;
- SDVTList VTs = getVTList(VT);
-
- if (VT != MVT::Glue) {
- FoldingSetNodeID ID;
- AddNodeIDNode(ID, Opcode, VTs, Ops);
- void *IP = nullptr;
-
- if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
- E->intersectFlagsWith(Flags);
- return SDValue(E, 0);
- }
-
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- createOperands(N, Ops);
-
- CSEMap.InsertNode(N, IP);
- } else {
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
- createOperands(N, Ops);
- }
-
- N->setFlags(Flags);
- InsertNode(N);
- SDValue V(N, 0);
- NewSDValueDbgMsg(V, "Creating new node: ", this);
- return V;
+ return getNodeImpl(Opcode, DL, VT, Ops, Flags);
}
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
@@ -10458,6 +10392,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
ArrayRef<SDValue> Ops, const SDNodeFlags Flags) {
if (VTList.NumVTs == 1)
return getNode(Opcode, DL, VTList.VTs[0], Ops, Flags);
+ if (!EnableSimplifyNodes)
+ return getNodeImpl(Opcode, DL, VTList, Ops, Flags);
#ifndef NDEBUG
for (const auto &Op : Ops)
@@ -10622,30 +10558,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
#endif
}
- // Memoize the node unless it returns a glue result.
- SDNode *N;
- if (VTList.VTs[VTList.NumVTs-1] != MVT::Glue) {
- FoldingSetNodeID ID;
- AddNodeIDNode(ID, Opcode, VTList, Ops);
- void *IP = nullptr;
- if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
- E->intersectFlagsWith(Flags);
- return SDValue(E, 0);
- }
-
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTList);
- createOperands(N, Ops);
- CSEMap.InsertNode(N, IP);
- } else {
- N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTList);
- createOperands(N, Ops);
- }
-
- N->setFlags(Flags);
- InsertNode(N);
- SDValue V(N, 0);
- NewSDValueDbgMsg(V, "Creating new node: ", this);
- return V;
+ return getNodeImpl(Opcode, DL, VTList, Ops, Flags);
}
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 5c1f717694a4c7..c4529c9151bc2b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5164,6 +5164,53 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
}
+/// Fold IMAD (N0 * N1 + N2) when N1, N2, or both N0 and N1 are constants.
+static SDValue
+PerformIMADCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, SDValue N2,
+                               TargetLowering::DAGCombinerInfo &DCI) {
+  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
+  ConstantSDNode *N2C = dyn_cast<ConstantSDNode>(N2);
+  // Take the type from the IMAD itself: N0 may be a non-zero result of a
+  // multi-result node, so N0->getValueType(0) is not necessarily its type.
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+  SDNodeFlags Flags = N->getFlags();
+
+  // mad x 1 y => add x y
+  if (N1C && N1C->isOne())
+    return DCI.DAG.getNode(ISD::ADD, DL, VT, N0, N2, Flags);
+  // mad x -1 y => sub y x; nuw does not carry over to the subtraction.
+  if (N1C && N1C->isAllOnes()) {
+    Flags.setNoUnsignedWrap(false);
+    return DCI.DAG.getNode(ISD::SUB, DL, VT, N2, N0, Flags);
+  }
+  // mad x 0 y => y
+  if (N1C && N1C->isZero())
+    return N2;
+  // mad x y 0 => mul x y
+  if (N2C && N2C->isZero())
+    return DCI.DAG.getNode(ISD::MUL, DL, VT, N0, N1, Flags);
+  // mad c0 c1 x => add x (c0*c1)
+  if (SDValue C =
+          DCI.DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1}, Flags))
+    return DCI.DAG.getNode(ISD::ADD, DL, VT, N2, C, Flags);
+
+  // No pattern matched; the caller may retry with N0/N1 swapped.
+  return {};
+}
+
+/// Combine NVPTXISD::IMAD, trying both orders of the multiplicands because
+/// the helper only matches 1/-1/0 in its second multiplicand.
+static SDValue PerformIMADCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI) {
+  SDValue A = N->getOperand(0), B = N->getOperand(1);
+  SDValue Addend = N->getOperand(2);
+  if (SDValue Folded = PerformIMADCombineWithOperands(N, A, B, Addend, DCI))
+    return Folded;
+  // Multiplication commutes, so retry with the factors swapped.
+  return PerformIMADCombineWithOperands(N, B, A, Addend, DCI);
+}
+
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5198,6 +5245,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformVSELECTCombine(N, DCI);
case ISD::BUILD_VECTOR:
return PerformBUILD_VECTORCombine(N, DCI);
+ case NVPTXISD::IMAD:
+ return PerformIMADCombine(N, DCI);
}
return SDValue();
}
diff --git a/llvm/test/CodeGen/NVPTX/combine-mad-only.ll b/llvm/test/CodeGen/NVPTX/combine-mad-only.ll
new file mode 100644
index 00000000000000..fb4bcc39b5a64d
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/combine-mad-only.ll
@@ -0,0 +1,87 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | FileCheck %s
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | FileCheck %s
+; RUN: %if ptxas && !ptxas-12.0 %{ llc < %s -mtriple=nvptx -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | %ptxas-verify %}
+
+;; mad x 1 y => add x y
+define i32 @test_mad_mul_1(i32 %x, i32 %y) {
+; CHECK-LABEL: test_mad_mul_1(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_mad_mul_1_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test_mad_mul_1_param_1];
+; CHECK-NEXT: add.s32 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %mul = mul i32 %x, 1
+ %add = add i32 %mul, %y
+ ret i32 %add
+}
+
+;; mad x -1 y => sub y x
+define i32 @test_mad_mul_neg_1(i32 %x, i32 %y) {
+; CHECK-LABEL: test_mad_mul_neg_1(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_mad_mul_neg_1_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test_mad_mul_neg_1_param_1];
+; CHECK-NEXT: sub.s32 %r3, %r2, %r1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %mul = mul i32 %x, -1
+ %add = add i32 %mul, %y
+ ret i32 %add
+}
+
+;; mad x 0 y => y
+define i32 @test_mad_mul_0(i32 %x, i32 %y) {
+; CHECK-LABEL: test_mad_mul_0(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_mad_mul_0_param_1];
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %mul = mul i32 %x, 0
+ %add = add i32 %mul, %y
+ ret i32 %add
+}
+
+;; mad x y 0 => mul x y
+define i32 @test_mad_add_0(i32 %x, i32 %y) {
+; CHECK-LABEL: test_mad_add_0(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_mad_add_0_param_0];
+; CHECK-NEXT: ld.param.u32 %r2, [test_mad_add_0_param_1];
+; CHECK-NEXT: mul.lo.s32 %r3, %r1, %r2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT: ret;
+ %mul = mul i32 %x, %y
+ %add = add i32 %mul, 0
+ ret i32 %add
+}
+
+;; mad c0 c1 x => add x (c0*c1)
+define i32 @test_mad_fold_mul(i32 %x) {
+; CHECK-LABEL: test_mad_fold_mul(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_mad_fold_mul_param_0];
+; CHECK-NEXT: add.s32 %r2, %r1, 12;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT: ret;
+ %mul = mul i32 4, 3
+ %add = add i32 %mul, %x
+ ret i32 %add
+}
diff --git a/llvm/test/CodeGen/NVPTX/combine-mad.ll b/llvm/test/CodeGen/NVPTX/combine-mad.ll
index 1b22cfde39725f..7d523a835a1f3f 100644
--- a/llvm/test/CodeGen/NVPTX/combine-mad.ll
+++ b/llvm/test/CodeGen/NVPTX/combine-mad.ll
@@ -183,3 +183,23 @@ define i32 @test4_rev(i32 %a, i32 %b, i32 %c, i1 %p) {
%add = add i32 %c, %sel
ret i32 %add
}
+
+;; This case relies on mad x 1 y => add x y; previously we emitted:
+;; mad.lo.s32 %r3, %r1, 1, %r2;
+define i32 @test_mad_fold(i32 %x) {
+; CHECK-LABEL: test_mad_fold(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u32 %r1, [test_mad_fold_param_0];
+; CHECK-NEXT: mul.hi.s32 %r2, %r1, -2147221471;
+; CHECK-NEXT: add.s32 %r3, %r1, %r2;
+; CHECK-NEXT: shr.u32 %r4, %r3, 31;
+; CHECK-NEXT: shr.s32 %r5, %r3, 12;
+; CHECK-NEXT: add.s32 %r6, %r5, %r4;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r6;
+; CHECK-NEXT: ret;
+ %div = sdiv i32 %x, 8191
+ ret i32 %div
+}
diff --git a/llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll b/llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll
index 27a523b9dd91d2..de19d2983f3435 100644
--- a/llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll
+++ b/llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll
@@ -12,7 +12,7 @@
; CHECK-NOT: __local_depot
; CHECK-32: ld.param.u32 %r[[SIZE:[0-9]]], [test_dynamic_stackalloc_param_0];
-; CHECK-32-NEXT: mad.lo.s32 %r[[SIZE2:[0-9]]], %r[[SIZE]], 1, 7;
+; CHECK-32-NEXT: add.s32 %r[[SIZE2:[0-9]]], %r[[SIZE]], 7;
; CHECK-32-NEXT: and.b32 %r[[SIZE3:[0-9]]], %r[[SIZE2]], -8;
; CHECK-32-NEXT: alloca.u32 %r[[ALLOCA:[0-9]]], %r[[SIZE3]], 16;
; CHECK-32-NEXT: cvta.local.u32 %r[[ALLOCA]], %r[[ALLOCA]];
``````````
</details>
https://github.com/llvm/llvm-project/pull/121724
More information about the llvm-commits
mailing list