[llvm] 945a306 - [AArch64] Change aarch64_neon_pmull{,64} intrinsic ISel through a new SDNode.

Mingming Liu via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 19 13:18:08 PDT 2022


Author: Mingming Liu
Date: 2022-08-19T13:17:13-07:00
New Revision: 945a3065015a62cf3b4ceabaa0755500fc7ddd71

URL: https://github.com/llvm/llvm-project/commit/945a3065015a62cf3b4ceabaa0755500fc7ddd71
DIFF: https://github.com/llvm/llvm-project/commit/945a3065015a62cf3b4ceabaa0755500fc7ddd71.diff

LOG: [AArch64] Change aarch64_neon_pmull{,64} intrinsic ISel through a new SDNode.

How:
1) Add an AArch64ISD::PMULL SDNode, and extend the aarch64_neon_pmull
   intrinsic's tablegen patterns to use this SDNode.
2) For aarch64_neon_pmull64, canonicalize i64 operands to v1i64 vectors
   during legalization (see the sketch after this list).
3) For {aarch64_neon_pmull, aarch64_neon_pmull64}, combine the intrinsic
   to the SDNode.
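
For illustration, here is a minimal IR sketch of the case step 2) targets,
with both operands plain i64 values (function and value names are made up):

  define <16 x i8> @pmull64_scalar(i64 %a, ptr %p) {
    ; The second multiplicand is loaded from memory.
    %b = load i64, ptr %p
    %r = tail call <16 x i8> @llvm.aarch64.neon.pmull64(i64 %a, i64 %b)
    ret <16 x i8> %r
  }
  declare <16 x i8> @llvm.aarch64.neon.pmull64(i64, i64)

With %b canonicalized to v1i64, ISel should be able to load it directly into
a d register and select "pmull v0.1q, v0.1d, v1.1d", rather than emitting a
GPR load followed by a fmov.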

Why:
1) Adding the SDNode makes it easier to canonicalize i64 inputs (required by
   aarch64_neon_pmull64) to vector inputs. Vector inputs carry lane
   information, which helps the dag-combiner combine nodes (e.g., rewrite to
   a better node to prepare for instruction selection) and helps instruction
   selection emit instructions that use higher-half inputs in place
   (i.e., without moving lane 1 content to lane 0).
2) Using the SDNode for aarch64_neon_pmull64 is NFC; without it, however, we
   would have to move the definitions of {PMULLv1i64, PMULLv2i64} out of
   their current group of records for no gain.

The test cases in `aarch64-pmull2.ll` and `pmull-ldr-merge.ll` under
`llvm/test/CodeGen/AArch64` are commented with what each one is testing.
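
As one concrete example, test4 in pmull-ldr-merge.ll exercises roughly the
following pattern (names simplified here):

  ; One operand is the high half of a v2i64, the other a plain i64.
  define void @pmull64_high(ptr %dst, <2 x i64> %a, i64 %b) {
    %hi = extractelement <2 x i64> %a, i64 1
    %r = tail call <16 x i8> @llvm.aarch64.neon.pmull64(i64 %hi, i64 %b)
    store <16 x i8> %r, ptr %dst
    ret void
  }
  declare <16 x i8> @llvm.aarch64.neon.pmull64(i64, i64)

With this patch the scalar operand is dup'ed into a vector and the pair
selects to "dup v1.2d, x1; pmull2 v0.1q, v0.2d, v1.2d", instead of first
moving lane 1 of %a down to lane 0 via "mov x8, v0.d[1]; fmov d1, x8".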

Differential Revision: https://reviews.llvm.org/D131047

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64InstrFormats.td
    llvm/lib/Target/AArch64/AArch64InstrInfo.td
    llvm/test/CodeGen/AArch64/aarch64-pmull2.ll
    llvm/test/CodeGen/AArch64/pmull-ldr-merge.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fb945a3c6abe..0cdfbbee0325 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2257,6 +2257,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::ST4LANEpost)
     MAKE_CASE(AArch64ISD::SMULL)
     MAKE_CASE(AArch64ISD::UMULL)
+    MAKE_CASE(AArch64ISD::PMULL)
     MAKE_CASE(AArch64ISD::FRECPE)
     MAKE_CASE(AArch64ISD::FRECPS)
     MAKE_CASE(AArch64ISD::FRSQRTE)
@@ -4203,22 +4204,19 @@ static SDValue addRequiredExtensionForVectorMULL(SDValue N, SelectionDAG &DAG,
   return DAG.getNode(ExtOpcode, SDLoc(N), NewVT, N);
 }
 
-static bool isOperandOfHigherHalf(SDValue &Op) {
+// Returns lane if Op extracts from a two-element vector and lane is constant
+// (i.e., extractelt(<2 x Ty> %v, ConstantLane)), and None otherwise.
+static Optional<uint64_t> getConstantLaneNumOfExtractHalfOperand(SDValue &Op) {
   SDNode *OpNode = Op.getNode();
   if (OpNode->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
-    return false;
-
-  ConstantSDNode *C = dyn_cast<ConstantSDNode>(OpNode->getOperand(1));
-  if (!C || C->getZExtValue() != 1)
-    return false;
+    return None;
 
   EVT VT = OpNode->getOperand(0).getValueType();
+  ConstantSDNode *C = dyn_cast<ConstantSDNode>(OpNode->getOperand(1));
+  if (!VT.isFixedLengthVector() || VT.getVectorNumElements() != 2 || !C)
+    return None;
 
-  return VT.isFixedLengthVector() && VT.getVectorNumElements() == 2;
-}
-
-static bool areOperandsOfHigherHalf(SDValue &Op1, SDValue &Op2) {
-  return isOperandOfHigherHalf(Op1) && isOperandOfHigherHalf(Op2);
+  return C->getZExtValue();
 }
 
 static bool isExtendedBUILD_VECTOR(SDNode *N, SelectionDAG &DAG,
@@ -4562,27 +4560,59 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     }
   }
   case Intrinsic::aarch64_neon_pmull64: {
-    SDValue Op1 = Op.getOperand(1);
-    SDValue Op2 = Op.getOperand(2);
+    SDValue LHS = Op.getOperand(1);
+    SDValue RHS = Op.getOperand(2);
+
+    Optional<uint64_t> LHSLane = getConstantLaneNumOfExtractHalfOperand(LHS);
+    Optional<uint64_t> RHSLane = getConstantLaneNumOfExtractHalfOperand(RHS);
+
+    assert((!LHSLane || *LHSLane < 2) && "Expect lane to be None or 0 or 1");
+    assert((!RHSLane || *RHSLane < 2) && "Expect lane to be None or 0 or 1");
+
+    // 'aarch64_neon_pmull64' takes i64 parameters; while pmull/pmull2
+    // instructions execute on SIMD registers. So canonicalize i64 to v1i64,
+    // which ISel recognizes better. For example, generate a ldr into d*
+    // registers as opposed to a GPR load followed by a fmov.
+    auto TryVectorizeOperand =
+        [](SDValue N, Optional<uint64_t> NLane, Optional<uint64_t> OtherLane,
+           const SDLoc &dl, SelectionDAG &DAG) -> SDValue {
+      // If the operand is an higher half itself, rewrite it to
+      // extract_high_v2i64; this way aarch64_neon_pmull64 could
+      // re-use the dag-combiner function with aarch64_neon_{pmull,smull,umull}.
+      if (NLane && *NLane == 1)
+        return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v1i64,
+                           N.getOperand(0), DAG.getConstant(1, dl, MVT::i64));
+
+      // Operand N is not a higher half but the other operand is.
+      if (OtherLane && *OtherLane == 1) {
+        // If this operand is a lower half, rewrite it to
+        // extract_high_v2i64(duplane(<2 x Ty>, 0)). This saves a roundtrip to
+        // align lanes of two operands. A roundtrip sequence (to move from lane
+        // 1 to lane 0) is like this:
+        //   mov x8, v0.d[1]
+        //   fmov d0, x8
+        if (NLane && *NLane == 0)
+          return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v1i64,
+                             DAG.getNode(AArch64ISD::DUPLANE64, dl, MVT::v2i64,
+                                         N.getOperand(0),
+                                         DAG.getConstant(0, dl, MVT::i64)),
+                             DAG.getConstant(1, dl, MVT::i64));
+
+        // Otherwise just dup from main to all lanes.
+        return DAG.getNode(AArch64ISD::DUP, dl, MVT::v1i64, N);
+      }
 
-    // If both operands are higher half of two source SIMD & FP registers,
-    // ISel could make use of tablegen patterns to emit PMULL2. So do not
-    // legalize i64 to v1i64.
-    if (areOperandsOfHigherHalf(Op1, Op2))
-      return SDValue();
+      // Neither operand is an extract of higher half, so codegen may just use
+      // the non-high version of PMULL instruction. Use v1i64 to represent i64.
+      assert(N.getValueType() == MVT::i64 &&
+             "Intrinsic aarch64_neon_pmull64 requires i64 parameters");
+      return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v1i64, N);
+    };
 
-    // As a general convention, use "v1" types to represent scalar integer
-    // operations in vector registers. This helps ISel to make use of
-    // tablegen patterns and generate a load into SIMD & FP registers directly.
-    if (Op1.getValueType() == MVT::i64)
-      Op1 = DAG.getNode(ISD::BITCAST, dl, MVT::v1i64, Op1);
-    if (Op2.getValueType() == MVT::i64)
-      Op2 = DAG.getNode(ISD::BITCAST, dl, MVT::v1i64, Op2);
+    LHS = TryVectorizeOperand(LHS, LHSLane, RHSLane, dl, DAG);
+    RHS = TryVectorizeOperand(RHS, RHSLane, LHSLane, dl, DAG);
 
-    return DAG.getNode(
-        ISD::INTRINSIC_WO_CHAIN, dl, Op.getValueType(),
-        DAG.getConstant(Intrinsic::aarch64_neon_pmull64, dl, MVT::i32), Op1,
-        Op2);
+    return DAG.getNode(AArch64ISD::PMULL, dl, Op.getValueType(), LHS, RHS);
   }
   case Intrinsic::aarch64_neon_smax:
     return DAG.getNode(ISD::SMAX, dl, Op.getValueType(),
@@ -16661,6 +16691,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
     return DAG.getNode(AArch64ISD::UMULL, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2));
   case Intrinsic::aarch64_neon_pmull:
+    return DAG.getNode(AArch64ISD::PMULL, SDLoc(N), N->getValueType(0),
+                       N->getOperand(1), N->getOperand(2));
   case Intrinsic::aarch64_neon_sqdmull:
     return tryCombineLongOpWithDup(IID, N, DCI, DAG);
   case Intrinsic::aarch64_neon_sqshl:
@@ -19800,6 +19832,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performUADDVCombine(N, DAG);
   case AArch64ISD::SMULL:
   case AArch64ISD::UMULL:
+  case AArch64ISD::PMULL:
     return tryCombineLongOpWithDup(Intrinsic::not_intrinsic, N, DCI, DAG);
   case ISD::INTRINSIC_VOID:
   case ISD::INTRINSIC_W_CHAIN:

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index aabb9abe5fdf..004f4a520736 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -291,6 +291,8 @@ enum NodeType : unsigned {
   SMULL,
   UMULL,
 
+  PMULL,
+
   // Reciprocal estimates and steps.
   FRECPE,
   FRECPS,

diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index f908460df1fd..f703bf36866c 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -117,6 +117,8 @@ def extract_high_v8i16 :
     ComplexPattern<v4i16, 1, "SelectExtractHigh", [extract_subvector, bitconvert]>;
 def extract_high_v4i32 :
     ComplexPattern<v2i32, 1, "SelectExtractHigh", [extract_subvector, bitconvert]>;
+def extract_high_v2i64 :
+    ComplexPattern<v1i64, 1, "SelectExtractHigh", [extract_subvector, bitconvert]>;
 
 def extract_high_dup_v8i16 :
    BinOpFrag<(extract_subvector (v8i16 (AArch64duplane16 (v8i16 node:$LHS), node:$RHS)), (i64 4))>;
@@ -6502,24 +6504,27 @@ multiclass SIMDNarrowThreeVectorBHS<bit U, bits<4> opc, string asm,
 }
 
 multiclass SIMDDifferentThreeVectorBD<bit U, bits<4> opc, string asm,
-                                      Intrinsic IntOp> {
+                                      SDPatternOperator OpNode> {
   def v8i8   : BaseSIMDDifferentThreeVector<U, 0b000, opc,
                                             V128, V64, V64,
                                             asm, ".8h", ".8b", ".8b",
-      [(set (v8i16 V128:$Rd), (IntOp (v8i8 V64:$Rn), (v8i8 V64:$Rm)))]>;
+      [(set (v8i16 V128:$Rd), (OpNode (v8i8 V64:$Rn), (v8i8 V64:$Rm)))]>;
   def v16i8  : BaseSIMDDifferentThreeVector<U, 0b001, opc,
                                             V128, V128, V128,
                                             asm#"2", ".8h", ".16b", ".16b", []>;
   let Predicates = [HasAES] in {
     def v1i64  : BaseSIMDDifferentThreeVector<U, 0b110, opc,
                                               V128, V64, V64,
-                                              asm, ".1q", ".1d", ".1d", []>;
+                                              asm, ".1q", ".1d", ".1d",
+        [(set (v16i8 V128:$Rd), (OpNode (v1i64 V64:$Rn), (v1i64 V64:$Rm)))]>;
     def v2i64  : BaseSIMDDifferentThreeVector<U, 0b111, opc,
                                               V128, V128, V128,
-                                              asm#"2", ".1q", ".2d", ".2d", []>;
+                                              asm#"2", ".1q", ".2d", ".2d",
+        [(set (v16i8 V128:$Rd), (OpNode (extract_high_v2i64 (v2i64 V128:$Rn)),
+                                        (extract_high_v2i64 (v2i64 V128:$Rm))))]>;
   }
 
-  def : Pat<(v8i16 (IntOp (v8i8 (extract_high_v16i8 (v16i8 V128:$Rn))),
+  def : Pat<(v8i16 (OpNode (v8i8 (extract_high_v16i8 (v16i8 V128:$Rn))),
                           (v8i8 (extract_high_v16i8 (v16i8 V128:$Rm))))),
       (!cast<Instruction>(NAME#"v16i8") V128:$Rn, V128:$Rm)>;
 }

diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 4ea6e39f40b3..58c23a1a813f 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -671,6 +671,8 @@ def AArch64NvCast : SDNode<"AArch64ISD::NVCAST", SDTUnaryOp>;
 
 def SDT_AArch64mull : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisInt<1>,
                                     SDTCisSameAs<1, 2>]>;
+def AArch64pmull    : SDNode<"AArch64ISD::PMULL", SDT_AArch64mull,
+                             [SDNPCommutative]>;
 def AArch64smull    : SDNode<"AArch64ISD::SMULL", SDT_AArch64mull,
                              [SDNPCommutative]>;
 def AArch64umull    : SDNode<"AArch64ISD::UMULL", SDT_AArch64mull,
@@ -5226,7 +5228,7 @@ defm ADDHN  : SIMDNarrowThreeVectorBHS<0,0b0100,"addhn", int_aarch64_neon_addhn>
 defm SUBHN  : SIMDNarrowThreeVectorBHS<0,0b0110,"subhn", int_aarch64_neon_subhn>;
 defm RADDHN : SIMDNarrowThreeVectorBHS<1,0b0100,"raddhn",int_aarch64_neon_raddhn>;
 defm RSUBHN : SIMDNarrowThreeVectorBHS<1,0b0110,"rsubhn",int_aarch64_neon_rsubhn>;
-defm PMULL  : SIMDDifferentThreeVectorBD<0,0b1110,"pmull",int_aarch64_neon_pmull>;
+defm PMULL  : SIMDDifferentThreeVectorBD<0,0b1110,"pmull", AArch64pmull>;
 defm SABAL  : SIMDLongThreeVectorTiedBHSabal<0,0b0101,"sabal",
                                              AArch64sabd>;
 defm SABDL   : SIMDLongThreeVectorBHSabdl<0, 0b0111, "sabdl",
@@ -5304,13 +5306,6 @@ defm : Neon_mul_acc_widen_patterns<sub, AArch64umull,
 defm : Neon_mul_acc_widen_patterns<sub, AArch64smull,
      SMLSLv8i8_v8i16, SMLSLv4i16_v4i32, SMLSLv2i32_v2i64>;
 
-// Patterns for 64-bit pmull
-def : Pat<(int_aarch64_neon_pmull64 V64:$Rn, V64:$Rm),
-          (PMULLv1i64 V64:$Rn, V64:$Rm)>;
-def : Pat<(int_aarch64_neon_pmull64 (extractelt (v2i64 V128:$Rn), (i64 1)),
-                                    (extractelt (v2i64 V128:$Rm), (i64 1))),
-          (PMULLv2i64 V128:$Rn, V128:$Rm)>;
-
 // CodeGen patterns for addhn and subhn instructions, which can actually be
 // written in LLVM IR without too much difficulty.
 

diff --git a/llvm/test/CodeGen/AArch64/aarch64-pmull2.ll b/llvm/test/CodeGen/AArch64/aarch64-pmull2.ll
index 86112fd934b6..e28041edbd89 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-pmull2.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-pmull2.ll
@@ -8,23 +8,19 @@
 define void @test1(ptr %0, ptr %1) {
 ; CHECK-LABEL: test1:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldp q0, q1, [x1]
-; CHECK-NEXT:    mov w8, #56824
 ; CHECK-NEXT:    mov w9, #61186
-; CHECK-NEXT:    movk w8, #40522, lsl #16
+; CHECK-NEXT:    mov w8, #56824
 ; CHECK-NEXT:    movk w9, #29710, lsl #16
-; CHECK-NEXT:    mov x10, v0.d[1]
-; CHECK-NEXT:    fmov d2, x9
-; CHECK-NEXT:    mov x11, v1.d[1]
-; CHECK-NEXT:    fmov d3, x8
-; CHECK-NEXT:    fmov d4, x10
-; CHECK-NEXT:    pmull v0.1q, v0.1d, v2.1d
-; CHECK-NEXT:    fmov d5, x11
-; CHECK-NEXT:    pmull v1.1q, v1.1d, v2.1d
-; CHECK-NEXT:    pmull v2.1q, v4.1d, v3.1d
-; CHECK-NEXT:    pmull v3.1q, v5.1d, v3.1d
-; CHECK-NEXT:    eor v0.16b, v0.16b, v2.16b
-; CHECK-NEXT:    eor v1.16b, v1.16b, v3.16b
+; CHECK-NEXT:    movk w8, #40522, lsl #16
+; CHECK-NEXT:    ldp q0, q1, [x1]
+; CHECK-NEXT:    fmov d3, x9
+; CHECK-NEXT:    dup v2.2d, x8
+; CHECK-NEXT:    pmull2 v4.1q, v0.2d, v2.2d
+; CHECK-NEXT:    pmull v0.1q, v0.1d, v3.1d
+; CHECK-NEXT:    pmull2 v2.1q, v1.2d, v2.2d
+; CHECK-NEXT:    pmull v1.1q, v1.1d, v3.1d
+; CHECK-NEXT:    eor v0.16b, v0.16b, v4.16b
+; CHECK-NEXT:    eor v1.16b, v1.16b, v2.16b
 ; CHECK-NEXT:    stp q0, q1, [x1]
 ; CHECK-NEXT:    ret
   %3 = load <2 x i64>, ptr %1
@@ -53,9 +49,8 @@ define void @test1(ptr %0, ptr %1) {
 define void @test2(ptr %0, <2 x i64> %1, <2 x i64> %2) {
 ; CHECK-LABEL: test2:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov x8, v0.d[1]
-; CHECK-NEXT:    fmov d0, x8
-; CHECK-NEXT:    pmull v0.1q, v0.1d, v1.1d
+; CHECK-NEXT:    dup v1.2d, v1.d[0]
+; CHECK-NEXT:    pmull2 v0.1q, v0.2d, v1.2d
 ; CHECK-NEXT:    str q0, [x0]
 ; CHECK-NEXT:    ret
   %4 = extractelement <2 x i64> %1, i64 1

diff --git a/llvm/test/CodeGen/AArch64/pmull-ldr-merge.ll b/llvm/test/CodeGen/AArch64/pmull-ldr-merge.ll
index a8127c300047..8e8f0c1d21ff 100644
--- a/llvm/test/CodeGen/AArch64/pmull-ldr-merge.ll
+++ b/llvm/test/CodeGen/AArch64/pmull-ldr-merge.ll
@@ -28,11 +28,10 @@ define void @test1(ptr %0, i64 %1, i64 %2) {
 define void @test2(ptr %0, i64 %1, i64 %2, <2 x i64> %3) {
 ; CHECK-LABEL: test2:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov x9, v0.d[1]
 ; CHECK-NEXT:    add x8, x0, x1, lsl #4
-; CHECK-NEXT:    ldr d0, [x8, #8]
-; CHECK-NEXT:    fmov d1, x9
-; CHECK-NEXT:    pmull v0.1q, v1.1d, v0.1d
+; CHECK-NEXT:    add x9, x8, #8
+; CHECK-NEXT:    ld1r { v1.2d }, [x9]
+; CHECK-NEXT:    pmull2 v0.1q, v0.2d, v1.2d
 ; CHECK-NEXT:    str q0, [x8]
 ; CHECK-NEXT:    ret
   %5 = getelementptr inbounds <2 x i64>, ptr %0, i64 %1
@@ -68,10 +67,8 @@ define void @test3(ptr %0, i64 %1, i64 %2, i64 %3) {
 define void @test4(ptr %0, <2 x i64> %1, i64 %2) {
 ; CHECK-LABEL: test4:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov x8, v0.d[1]
-; CHECK-NEXT:    fmov d0, x1
-; CHECK-NEXT:    fmov d1, x8
-; CHECK-NEXT:    pmull v0.1q, v1.1d, v0.1d
+; CHECK-NEXT:    dup v1.2d, x1
+; CHECK-NEXT:    pmull2 v0.1q, v0.2d, v1.2d
 ; CHECK-NEXT:    str q0, [x0]
 ; CHECK-NEXT:    ret
   %4 = extractelement <2 x i64> %1, i64 1


        

