[llvm] [AArch64] Fix SDNode type mismatches between *.td files and ISel (PR #116523)

Sergei Barannikov via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 19 18:16:52 PST 2024


https://github.com/s-barannikov updated https://github.com/llvm/llvm-project/pull/116523

>From 07c34629204e93d1edce13cc3e18b2dc94e39cdc Mon Sep 17 00:00:00 2001
From: Sergei Barannikov <barannikov88 at gmail.com>
Date: Sun, 17 Nov 2024 06:30:40 +0300
Subject: [PATCH 1/2] [AArch64] Fix SDNode type mismatches between *.td files
 and ISel

The MRS, PTEST and FP comparison nodes were missing their "flags" result
in the *.td type profiles, and ISel sometimes created them with invalid
result types (f32, Glue, Other) instead of i32.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 12 +++++-------
 .../lib/Target/AArch64/AArch64InstrFormats.td | 12 ++++++------
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   | 19 ++++++++++++++-----
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |  6 +++++-
 llvm/lib/Target/AArch64/SVEInstrFormats.td    |  4 ++--
 5 files changed, 32 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2732e495c552a6..79f534cd0924cd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3547,11 +3547,10 @@ static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &dl,
     RHS = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
                       {LHS.getValue(1), RHS});
     Chain = RHS.getValue(1);
-    VT = MVT::f32;
   }
   unsigned Opcode =
       IsSignaling ? AArch64ISD::STRICT_FCMPE : AArch64ISD::STRICT_FCMP;
-  return DAG.getNode(Opcode, dl, {VT, MVT::Other}, {Chain, LHS, RHS});
+  return DAG.getNode(Opcode, dl, {MVT::i32, MVT::Other}, {Chain, LHS, RHS});
 }
 
 static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
@@ -3564,9 +3563,8 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
     if ((VT == MVT::f16 && !FullFP16) || VT == MVT::bf16) {
       LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, LHS);
       RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, RHS);
-      VT = MVT::f32;
     }
-    return DAG.getNode(AArch64ISD::FCMP, dl, VT, LHS, RHS);
+    return DAG.getNode(AArch64ISD::FCMP, dl, MVT::i32, LHS, RHS);
   }
 
   // The CMP instruction is just an alias for SUBS, and representing it as
@@ -21630,7 +21628,7 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
   // Set condition code (CC) flags.
   SDValue Test = DAG.getNode(
       Cond == AArch64CC::ANY_ACTIVE ? AArch64ISD::PTEST_ANY : AArch64ISD::PTEST,
-      DL, MVT::Other, Pg, Op);
+      DL, MVT::i32, Pg, Op);
 
   // Convert CC to integer based on requested condition.
   // NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.
@@ -26436,8 +26434,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
                                                   : AArch64SysReg::RNDRRS);
       SDLoc DL(N);
       SDValue A = DAG.getNode(
-          AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, MVT::Glue, MVT::Other),
-          N->getOperand(0), DAG.getConstant(Register, DL, MVT::i64));
+          AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, MVT::i32, MVT::Other),
+          N->getOperand(0), DAG.getConstant(Register, DL, MVT::i32));
       SDValue B = DAG.getNode(
           AArch64ISD::CSINC, DL, MVT::i32, DAG.getConstant(0, DL, MVT::i32),
           DAG.getConstant(0, DL, MVT::i32),
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 15d4e93b915c14..242aea5fbb0142 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -5911,34 +5911,34 @@ multiclass FPComparison<bit signalAllNans, string asm,
                         SDPatternOperator OpNode = null_frag> {
   let Defs = [NZCV] in {
   def Hrr : BaseTwoOperandFPComparison<signalAllNans, FPR16, asm,
-      [(OpNode (f16 FPR16:$Rn), (f16 FPR16:$Rm))]> {
+      [(set NZCV, (OpNode (f16 FPR16:$Rn), (f16 FPR16:$Rm)))]> {
     let Inst{23-22} = 0b11;
     let Predicates = [HasFullFP16];
   }
 
   def Hri : BaseOneOperandFPComparison<signalAllNans, FPR16, asm,
-      [(OpNode (f16 FPR16:$Rn), fpimm0)]> {
+      [(set NZCV, (OpNode (f16 FPR16:$Rn), fpimm0))]> {
     let Inst{23-22} = 0b11;
     let Predicates = [HasFullFP16];
   }
 
   def Srr : BaseTwoOperandFPComparison<signalAllNans, FPR32, asm,
-      [(OpNode FPR32:$Rn, (f32 FPR32:$Rm))]> {
+      [(set NZCV, (OpNode FPR32:$Rn, (f32 FPR32:$Rm)))]> {
     let Inst{23-22} = 0b00;
   }
 
   def Sri : BaseOneOperandFPComparison<signalAllNans, FPR32, asm,
-      [(OpNode (f32 FPR32:$Rn), fpimm0)]> {
+      [(set NZCV, (OpNode (f32 FPR32:$Rn), fpimm0))]> {
     let Inst{23-22} = 0b00;
   }
 
   def Drr : BaseTwoOperandFPComparison<signalAllNans, FPR64, asm,
-      [(OpNode FPR64:$Rn, (f64 FPR64:$Rm))]> {
+      [(set NZCV, (OpNode FPR64:$Rn, (f64 FPR64:$Rm)))]> {
     let Inst{23-22} = 0b01;
   }
 
   def Dri : BaseOneOperandFPComparison<signalAllNans, FPR64, asm,
-      [(OpNode (f64 FPR64:$Rn), fpimm0)]> {
+      [(set NZCV, (OpNode (f64 FPR64:$Rn), fpimm0))]> {
     let Inst{23-22} = 0b01;
   }
   } // Defs = [NZCV]
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index c8d4291c5f2802..66ab2402e564b3 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -438,9 +438,13 @@ def SDT_AArch64FCCMP : SDTypeProfile<1, 5,
                                       SDTCisInt<3>,
                                       SDTCisInt<4>,
                                       SDTCisVT<5, i32>]>;
-def SDT_AArch64FCmp   : SDTypeProfile<0, 2,
-                                   [SDTCisFP<0>,
-                                    SDTCisSameAs<0, 1>]>;
+
+def SDT_AArch64FCmp   : SDTypeProfile<1, 2, [
+  SDTCisVT<0, i32>,   // out flags
+  SDTCisFP<1>,        // lhs
+  SDTCisSameAs<2, 1>  // rhs
+]>;
+
 def SDT_AArch64Dup   : SDTypeProfile<1, 1, [SDTCisVec<0>]>;
 def SDT_AArch64DupLane   : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisInt<2>]>;
 def SDT_AArch64Insr  : SDTypeProfile<1, 2, [SDTCisVec<0>]>;
@@ -992,8 +996,13 @@ def AArch64probedalloca
              [SDNPHasChain, SDNPMayStore]>;
 
 def AArch64mrs : SDNode<"AArch64ISD::MRS",
-                        SDTypeProfile<1, 1, [SDTCisVT<0, i64>, SDTCisVT<1, i32>]>,
-                        [SDNPHasChain, SDNPOutGlue]>;
+  SDTypeProfile<2, 1, [
+    SDTCisVT<0, i64>, // result
+    SDTCisVT<1, i32>, // out flags
+    SDTCisVT<2, i32>  // system register number
+  ]>,
+  [SDNPHasChain]
+>;
 
 def SD_AArch64rshrnb : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisVec<1>, SDTCisInt<2>]>;
 def AArch64rshrnb : SDNode<"AArch64ISD::RSHRNB_I", SD_AArch64rshrnb>;
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 8791ce6266c86c..e575fbb36dc7b8 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -376,7 +376,11 @@ def AArch64fadda_p : PatFrags<(ops node:$op1, node:$op2, node:$op3),
      (AArch64fadda_p_node (SVEAllActive), node:$op2,
              (vselect node:$op1, node:$op3, (splat_vector (f64 fpimm_minus0))))]>;
 
-def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
+def SDT_AArch64PTest : SDTypeProfile<1, 2, [
+  SDTCisVT<0, i32>,  // out flags
+  SDTCisVec<1>,      // governing predicate
+  SDTCisSameAs<2, 1> // source predicate
+]>;
 def AArch64ptest     : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
 def AArch64ptest_any : SDNode<"AArch64ISD::PTEST_ANY", SDT_AArch64PTest>;
 
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 60705e2b6d4e7d..1ddb913f013f5e 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -836,7 +836,7 @@ class sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op>
 : I<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
   asm, "\t$Pg, $Pn",
   "",
-  [(op (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>, Sched<[]> {
+  [(set NZCV, (op (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn)))]>, Sched<[]> {
   bits<4> Pg;
   bits<4> Pn;
   let Inst{31-24} = 0b00100101;
@@ -860,7 +860,7 @@ multiclass sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op,
 
   let hasNoSchedulingInfo = 1, isCompare = 1, Defs = [NZCV] in {
   def _ANY : Pseudo<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
-                    [(op_any (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>,
+                    [(set NZCV, (op_any (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn)))]>,
              PseudoInstExpansion<(!cast<Instruction>(NAME) PPRAny:$Pg, PPR8:$Pn)>;
   }
 }

>From acce415bc453b6b379ccbfee508029106c659fcf Mon Sep 17 00:00:00 2001
From: Sergei Barannikov <s.barannikov at module.ru>
Date: Wed, 20 Nov 2024 05:13:09 +0300
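Subject: [PATCH 2/2] More fixes

Remove the extra operand passed when creating REV16/REV32/REV64 and
CMGEz nodes, and drop SDNPInGlue from AArch64tlsdesc_callseq.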

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 8 ++++----
 llvm/lib/Target/AArch64/AArch64InstrInfo.td     | 3 +--
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 79f534cd0924cd..203066961f4a12 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -13720,11 +13720,11 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
   unsigned NumElts = VT.getVectorNumElements();
   unsigned EltSize = VT.getScalarSizeInBits();
   if (isREVMask(ShuffleMask, EltSize, NumElts, 64))
-    return DAG.getNode(AArch64ISD::REV64, dl, V1.getValueType(), V1, V2);
+    return DAG.getNode(AArch64ISD::REV64, dl, V1.getValueType(), V1);
   if (isREVMask(ShuffleMask, EltSize, NumElts, 32))
-    return DAG.getNode(AArch64ISD::REV32, dl, V1.getValueType(), V1, V2);
+    return DAG.getNode(AArch64ISD::REV32, dl, V1.getValueType(), V1);
   if (isREVMask(ShuffleMask, EltSize, NumElts, 16))
-    return DAG.getNode(AArch64ISD::REV16, dl, V1.getValueType(), V1, V2);
+    return DAG.getNode(AArch64ISD::REV16, dl, V1.getValueType(), V1);
 
   if (((NumElts == 8 && EltSize == 16) || (NumElts == 16 && EltSize == 8)) &&
       ShuffleVectorInst::isReverseMask(ShuffleMask, ShuffleMask.size())) {
@@ -15741,7 +15741,7 @@ static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
     if (IsZero)
       return DAG.getNode(AArch64ISD::CMGTz, dl, VT, LHS);
     if (IsMinusOne)
-      return DAG.getNode(AArch64ISD::CMGEz, dl, VT, LHS, RHS);
+      return DAG.getNode(AArch64ISD::CMGEz, dl, VT, LHS);
     return DAG.getNode(AArch64ISD::CMGT, dl, VT, LHS, RHS);
   case AArch64CC::LE:
     if (IsZero)
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 66ab2402e564b3..6bffffdcb538dd 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -885,8 +885,7 @@ def AArch64uitof: SDNode<"AArch64ISD::UITOF", SDT_AArch64ITOF>;
 
 def AArch64tlsdesc_callseq : SDNode<"AArch64ISD::TLSDESC_CALLSEQ",
                                     SDT_AArch64TLSDescCallSeq,
-                                    [SDNPInGlue, SDNPOutGlue, SDNPHasChain,
-                                     SDNPVariadic]>;
+                                    [SDNPOutGlue, SDNPHasChain, SDNPVariadic]>;
 
 
 def AArch64WrapperLarge : SDNode<"AArch64ISD::WrapperLarge",



More information about the llvm-commits mailing list