[llvm] 9b10512 - [DAG] Use SDValue for PatFrag checks (#137519)
via llvm-commits
llvm-commits at lists.llvm.org
Thu May 1 00:59:03 PDT 2025
Author: David Green
Date: 2025-05-01T08:58:59+01:00
New Revision: 9b1051281e439fcf6f6ccf03766c5bcf04ceec4b
URL: https://github.com/llvm/llvm-project/commit/9b1051281e439fcf6f6ccf03766c5bcf04ceec4b
DIFF: https://github.com/llvm/llvm-project/commit/9b1051281e439fcf6f6ccf03766c5bcf04ceec4b.diff
LOG: [DAG] Use SDValue for PatFrag checks (#137519)
If only the SDNode is used, a predicate can pick up the wrong result
number, for example looking at the known bits of the first result when
it should be looking at the second. The SDValue is already available,
because the SelectCodeCommon checks move from parent to child, so pass
it through to CheckNodePredicate as Op so that predicates can use it
when necessary. SDNode *N is still generated, so most PatFrags remain
unchanged.
Fixes #137274
Added:
Modified:
llvm/include/llvm/CodeGen/SelectionDAGISel.h
llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
llvm/lib/Target/AArch64/AArch64InstrInfo.td
llvm/lib/Target/AMDGPU/SIInstrInfo.td
llvm/lib/Target/AMDGPU/SIInstructions.td
llvm/lib/Target/ARM/ARMInstrInfo.td
llvm/lib/Target/RISCV/RISCVInstrInfo.td
llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td
llvm/lib/Target/X86/X86InstrSSE.td
llvm/test/CodeGen/AArch64/aarch64-mull-masks.ll
llvm/test/TableGen/HasNoUse.td
llvm/test/TableGen/address-space-patfrags.td
llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp
llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGISel.h b/llvm/include/llvm/CodeGen/SelectionDAGISel.h
index 55f8f19d437a0..7a41e09b6aeaf 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGISel.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGISel.h
@@ -426,7 +426,7 @@ class SelectionDAGISel {
/// It runs node predicate number PredNo and returns true if it succeeds or
/// false if it fails. The number is a private implementation
/// detail to the code tblgen produces.
- virtual bool CheckNodePredicate(SDNode *N, unsigned PredNo) const {
+ virtual bool CheckNodePredicate(SDValue Op, unsigned PredNo) const {
llvm_unreachable("Tblgen should generate the implementation of this!");
}
@@ -436,7 +436,7 @@ class SelectionDAGISel {
/// false if it fails. The number is a private implementation detail to the
/// code tblgen produces.
virtual bool CheckNodePredicateWithOperands(
- SDNode *N, unsigned PredNo,
+ SDValue Op, unsigned PredNo,
const SmallVectorImpl<SDValue> &Operands) const {
llvm_unreachable("Tblgen should generate the implementation of this!");
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
index 81f5dd2ed2571..1bc30336a02bf 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -2897,11 +2897,11 @@ CheckPatternPredicate(unsigned Opcode, const unsigned char *MatcherTable,
LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
CheckNodePredicate(unsigned Opcode, const unsigned char *MatcherTable,
unsigned &MatcherIndex, const SelectionDAGISel &SDISel,
- SDNode *N) {
+ SDValue Op) {
unsigned PredNo = Opcode == SelectionDAGISel::OPC_CheckPredicate
? MatcherTable[MatcherIndex++]
: Opcode - SelectionDAGISel::OPC_CheckPredicate0;
- return SDISel.CheckNodePredicate(N, PredNo);
+ return SDISel.CheckNodePredicate(Op, PredNo);
}
LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
@@ -3062,7 +3062,7 @@ static unsigned IsPredicateKnownToFail(const unsigned char *Table,
case SelectionDAGISel::OPC_CheckPredicate5:
case SelectionDAGISel::OPC_CheckPredicate6:
case SelectionDAGISel::OPC_CheckPredicate7:
- Result = !::CheckNodePredicate(Opcode, Table, Index, SDISel, N.getNode());
+ Result = !::CheckNodePredicate(Opcode, Table, Index, SDISel, N);
return Index;
case SelectionDAGISel::OPC_CheckOpcode:
Result = !::CheckOpcode(Table, Index, N.getNode());
@@ -3574,8 +3574,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
case SelectionDAGISel::OPC_CheckPredicate6:
case SelectionDAGISel::OPC_CheckPredicate7:
case OPC_CheckPredicate:
- if (!::CheckNodePredicate(Opcode, MatcherTable, MatcherIndex, *this,
- N.getNode()))
+ if (!::CheckNodePredicate(Opcode, MatcherTable, MatcherIndex, *this, N))
break;
continue;
case OPC_CheckPredicateWithOperands: {
@@ -3586,7 +3585,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
Operands.push_back(RecordedNodes[MatcherTable[MatcherIndex++]].first);
unsigned PredNo = MatcherTable[MatcherIndex++];
- if (!CheckNodePredicateWithOperands(N.getNode(), PredNo, Operands))
+ if (!CheckNodePredicateWithOperands(N, PredNo, Operands))
break;
continue;
}
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index f7b13092821d6..bee86aa86ec37 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -683,24 +683,24 @@ defm trunc_masked_scatter_i32 : masked_gather_scatter<trunc_masked_scatter_i32>;
// top16Zero - answer true if the upper 16 bits of $src are 0, false otherwise
def top16Zero: PatLeaf<(i32 GPR32:$src), [{
- return SDValue(N,0)->getValueType(0) == MVT::i32 &&
- CurDAG->MaskedValueIsZero(SDValue(N,0), APInt::getHighBitsSet(32, 16));
+ return Op.getValueType() == MVT::i32 &&
+ CurDAG->MaskedValueIsZero(Op, APInt::getHighBitsSet(32, 16));
}]>;
// top32Zero - answer true if the upper 32 bits of $src are 0, false otherwise
def top32Zero: PatLeaf<(i64 GPR64:$src), [{
- return SDValue(N,0)->getValueType(0) == MVT::i64 &&
- CurDAG->MaskedValueIsZero(SDValue(N,0), APInt::getHighBitsSet(64, 32));
+ return Op.getValueType() == MVT::i64 &&
+ CurDAG->MaskedValueIsZero(Op, APInt::getHighBitsSet(64, 32));
}]>;
// topbitsallzero - Return true if all bits except the lowest bit are known zero
def topbitsallzero32: PatLeaf<(i32 GPR32:$src), [{
- return SDValue(N,0)->getValueType(0) == MVT::i32 &&
- CurDAG->MaskedValueIsZero(SDValue(N,0), APInt::getHighBitsSet(32, 31));
+ return Op.getValueType() == MVT::i32 &&
+ CurDAG->MaskedValueIsZero(Op, APInt::getHighBitsSet(32, 31));
}]>;
def topbitsallzero64: PatLeaf<(i64 GPR64:$src), [{
- return SDValue(N,0)->getValueType(0) == MVT::i64 &&
- CurDAG->MaskedValueIsZero(SDValue(N,0), APInt::getHighBitsSet(64, 63));
+ return Op.getValueType() == MVT::i64 &&
+ CurDAG->MaskedValueIsZero(Op, APInt::getHighBitsSet(64, 63));
}]>;
// Node definitions.
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
index 5d837d853ac98..adc7cd0b14af6 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -969,7 +969,7 @@ def MFMALdScaleXForm : SDNodeXForm<timm, [{
def is_canonicalized : PatLeaf<(fAny srcvalue:$src), [{
const SITargetLowering &Lowering =
*static_cast<const SITargetLowering *>(getTargetLowering());
- return Lowering.isCanonicalized(*CurDAG, SDValue(N, 0));
+ return Lowering.isCanonicalized(*CurDAG, Op);
}]> {
let GISelPredicateCode = [{
const SITargetLowering *TLI = static_cast<const SITargetLowering *>(
diff --git a/llvm/lib/Target/AMDGPU/SIInstructions.td b/llvm/lib/Target/AMDGPU/SIInstructions.td
index a144ae2104da6..f5c6d47369781 100644
--- a/llvm/lib/Target/AMDGPU/SIInstructions.td
+++ b/llvm/lib/Target/AMDGPU/SIInstructions.td
@@ -3861,7 +3861,7 @@ def : AMDGPUPat <
>;
def uint5Bits : PatLeaf<(i32 VGPR_32:$width), [{
- return CurDAG->computeKnownBits(SDValue(N, 0)).countMaxActiveBits() <= 5;
+ return CurDAG->computeKnownBits(Op).countMaxActiveBits() <= 5;
}]>;
// x & (-1 >> (bitwidth - y))
diff --git a/llvm/lib/Target/ARM/ARMInstrInfo.td b/llvm/lib/Target/ARM/ARMInstrInfo.td
index c682f597401ec..1f5ba998970fc 100644
--- a/llvm/lib/Target/ARM/ARMInstrInfo.td
+++ b/llvm/lib/Target/ARM/ARMInstrInfo.td
@@ -421,7 +421,7 @@ def imm16_31 : ImmLeaf<i32, [{
// sext_16_node predicate - True if the SDNode is sign-extended 16 or more bits.
def sext_16_node : PatLeaf<(i32 GPR:$a), [{
- return CurDAG->ComputeNumSignBits(SDValue(N,0)) >= 17;
+ return CurDAG->ComputeNumSignBits(Op) >= 17;
}]>;
def sext_bottom_16 : PatFrag<(ops node:$a),
@@ -451,14 +451,14 @@ def lo16AllZero : PatLeaf<(i32 imm), [{
// top16Zero - answer true if the upper 16 bits of $src are 0, false otherwise
def top16Zero: PatLeaf<(i32 GPR:$src), [{
- return !SDValue(N,0)->getValueType(0).isVector() &&
- CurDAG->MaskedValueIsZero(SDValue(N,0), APInt::getHighBitsSet(32, 16));
+ return !Op.getValueType().isVector() &&
+ CurDAG->MaskedValueIsZero(Op, APInt::getHighBitsSet(32, 16));
}]>;
// topbitsallzero - Return true if all bits except the lowest bit are known zero
def topbitsallzero32 : PatLeaf<(i32 GPRwithZR:$src), [{
- return SDValue(N,0)->getValueType(0) == MVT::i32 &&
- CurDAG->MaskedValueIsZero(SDValue(N,0), APInt::getHighBitsSet(32, 31));
+ return Op.getValueType() == MVT::i32 &&
+ CurDAG->MaskedValueIsZero(Op, APInt::getHighBitsSet(32, 31));
}]>;
class BinOpFrag<dag res> : PatFrag<(ops node:$LHS, node:$RHS), res>;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
index 7cd36aa46efbe..4a4290483e94b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
@@ -1337,7 +1337,7 @@ def ext_oneuse : unop_oneuse<ext>;
def fpext_oneuse : unop_oneuse<any_fpextend>;
def 33signbits_node : PatLeaf<(i64 GPR:$src), [{
- return CurDAG->ComputeNumSignBits(SDValue(N, 0)) > 32;
+ return CurDAG->ComputeNumSignBits(Op) > 32;
}]>;
class immop_oneuse<ImmLeaf leaf> : PatLeaf<(leaf), [{
@@ -1977,7 +1977,7 @@ def : Pat<(i64 (shl (and GPR:$rs1, 0xffffffff), uimm5:$shamt)),
class binop_allhusers<SDPatternOperator operator>
: PatFrag<(ops node:$lhs, node:$rhs),
(XLenVT (operator node:$lhs, node:$rhs)), [{
- return hasAllHUsers(Node);
+ return hasAllHUsers(N);
}]> {
let GISelPredicateCode = [{ return hasAllHUsers(MI); }];
}
@@ -1987,14 +1987,14 @@ class binop_allhusers<SDPatternOperator operator>
class binop_allwusers<SDPatternOperator operator>
: PatFrag<(ops node:$lhs, node:$rhs), (i64 (operator node:$lhs, node:$rhs)),
[{
- return hasAllWUsers(Node);
+ return hasAllWUsers(N);
}]> {
let GISelPredicateCode = [{ return hasAllWUsers(MI); }];
}
def sexti32_allwusers : PatFrag<(ops node:$src),
(sext_inreg node:$src, i32), [{
- return hasAllWUsers(Node);
+ return hasAllWUsers(N);
}]>;
def ImmSExt32 : SDNodeXForm<imm, [{
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td
index 415e802951a94..b5e723e2a48d3 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td
@@ -239,7 +239,7 @@ def TypeIndex : Operand<i32>;
// TODO: Find more places to use this.
def bool_node : PatLeaf<(i32 I32:$cond), [{
- return CurDAG->computeKnownBits(SDValue(N, 0)).countMinLeadingZeros() == 31;
+ return CurDAG->computeKnownBits(Op).countMinLeadingZeros() == 31;
}]>;
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/X86/X86InstrSSE.td b/llvm/lib/Target/X86/X86InstrSSE.td
index 49a62fd3422d0..1acc0cd8da205 100644
--- a/llvm/lib/Target/X86/X86InstrSSE.td
+++ b/llvm/lib/Target/X86/X86InstrSSE.td
@@ -5705,7 +5705,7 @@ let Predicates = [UseSSE41, OptForSize] in {
// commuting would change which operand is inverted.
def X86ptest_commutable : PatFrag<(ops node:$src1, node:$src2),
(X86ptest node:$src1, node:$src2), [{
- return onlyUsesZeroFlag(SDValue(Node, 0));
+ return onlyUsesZeroFlag(SDValue(N, 0));
}]>;
// ptest instruction we'll lower to this in X86ISelLowering primarily from
@@ -5772,7 +5772,7 @@ multiclass avx_bittest<bits<8> opc, string OpcodeStr, RegisterClass RC,
// used, commuting would change which operand is inverted.
def X86testp_commutable : PatFrag<(ops node:$src1, node:$src2),
(X86testp node:$src1, node:$src2), [{
- return onlyUsesZeroFlag(SDValue(Node, 0));
+ return onlyUsesZeroFlag(SDValue(N, 0));
}]>;
let Defs = [EFLAGS], Predicates = [HasAVX] in {
diff --git a/llvm/test/CodeGen/AArch64/aarch64-mull-masks.ll b/llvm/test/CodeGen/AArch64/aarch64-mull-masks.ll
index 3c6e4a1d2e130..8de1fc5762c15 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-mull-masks.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-mull-masks.ll
@@ -2098,3 +2098,19 @@ B:
%t = icmp eq i64 0, %3
br i1 %t, label %A, label %B
}
+
+define i64 @pr137274(ptr %ptr) {
+; CHECK-LABEL: pr137274:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ldr x8, [x0]
+; CHECK-NEXT: ldr w9, [x8, #8]!
+; CHECK-NEXT: mul x0, x8, x9
+; CHECK-NEXT: ret
+ %l0 = load i64, ptr %ptr, align 8
+ %add = add i64 %l0, 8
+ %i1 = inttoptr i64 %add to ptr
+ %l2 = load i32, ptr %i1, align 4
+ %conv = zext i32 %l2 to i64
+ %mul = mul i64 %add, %conv
+ ret i64 %mul
+}
diff --git a/llvm/test/TableGen/HasNoUse.td b/llvm/test/TableGen/HasNoUse.td
index 0947be11caa4c..d51fa9ef07230 100644
--- a/llvm/test/TableGen/HasNoUse.td
+++ b/llvm/test/TableGen/HasNoUse.td
@@ -10,7 +10,7 @@ def NO_RET_ATOMIC_ADD : I<(outs), (ins GPR32Op:$src0, GPR32Op:$src1), []>;
// SDAG: case 0: {
// SDAG-NEXT: // Predicate_atomic_load_add_no_ret_i32
-// SDAG-NEXT: SDNode *N = Node;
+// SDAG-NEXT: SDNode *N = Op.getNode();
// SDAG-NEXT: (void)N;
// SDAG-NEXT: if (cast<MemSDNode>(N)->getMemoryVT() != MVT::i32) return false;
// SDAG-NEXT: if (N->hasAnyUseOfValue(0)) return false;
diff --git a/llvm/test/TableGen/address-space-patfrags.td b/llvm/test/TableGen/address-space-patfrags.td
index a2611df048b06..2aaa3451bdee1 100644
--- a/llvm/test/TableGen/address-space-patfrags.td
+++ b/llvm/test/TableGen/address-space-patfrags.td
@@ -49,7 +49,7 @@ def inst_d : Instruction {
// SDAG: case 0: {
// SDAG-NEXT: // Predicate_pat_frag_b
// SDAG-NEXT: // Predicate_truncstorei16_addrspace
-// SDAG-NEXT: SDNode *N = Node;
+// SDAG-NEXT: SDNode *N = Op.getNode();
// SDAG-NEXT: (void)N;
// SDAG-NEXT: unsigned AddrSpace = cast<MemSDNode>(N)->getAddressSpace();
// SDAG-NEXT: if (AddrSpace != 123 && AddrSpace != 455)
@@ -71,7 +71,7 @@ def : Pat <
// SDAG: case 4: {
// SDAG: // Predicate_pat_frag_a
-// SDAG-NEXT: SDNode *N = Node;
+// SDAG-NEXT: SDNode *N = Op.getNode();
// SDAG-NEXT: (void)N;
// SDAG-NEXT: if (cast<MemSDNode>(N)->getAlign() < Align(2))
// SDAG-NEXT: return false;
diff --git a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp
index 20b313d4428db..febcb1fd662f5 100644
--- a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp
+++ b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp
@@ -1375,11 +1375,11 @@ std::string TreePredicateFn::getCodeToRunOnSDNode() const {
std::string Result = (" " + getImmType() + " Imm = ").str();
if (immCodeUsesAPFloat())
- Result += "cast<ConstantFPSDNode>(Node)->getValueAPF();\n";
+ Result += "cast<ConstantFPSDNode>(Op.getNode())->getValueAPF();\n";
else if (immCodeUsesAPInt())
- Result += "Node->getAsAPIntVal();\n";
+ Result += "Op->getAsAPIntVal();\n";
else
- Result += "cast<ConstantSDNode>(Node)->getSExtValue();\n";
+ Result += "cast<ConstantSDNode>(Op.getNode())->getSExtValue();\n";
return Result + ImmCode;
}
@@ -1410,9 +1410,9 @@ std::string TreePredicateFn::getCodeToRunOnSDNode() const {
std::string Result;
if (ClassName == "SDNode")
- Result = " SDNode *N = Node;\n";
+ Result = " SDNode *N = Op.getNode();\n";
else
- Result = " auto *N = cast<" + ClassName.str() + ">(Node);\n";
+ Result = " auto *N = cast<" + ClassName.str() + ">(Op.getNode());\n";
return (Twine(Result) + " (void)N;\n" + getPredCode()).str();
}
diff --git a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
index 57997a6b0e4e0..8b0f48aca259c 100644
--- a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
+++ b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
@@ -1149,11 +1149,11 @@ void MatcherTableEmitter::EmitPredicateFunctions(raw_ostream &OS) {
// Emit Node predicates.
EmitNodePredicatesFunction(
- NodePredicates, "CheckNodePredicate(SDNode *Node, unsigned PredNo) const",
+ NodePredicates, "CheckNodePredicate(SDValue Op, unsigned PredNo) const",
OS);
EmitNodePredicatesFunction(
NodePredicatesWithOperands,
- "CheckNodePredicateWithOperands(SDNode *Node, unsigned PredNo, "
+ "CheckNodePredicateWithOperands(SDValue Op, unsigned PredNo, "
"const SmallVectorImpl<SDValue> &Operands) const",
OS);
More information about the llvm-commits
mailing list