[llvm] d916856 - [AArch64] Allow strict opcodes in faddp patterns

John Brawn via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 17 05:12:34 PST 2022


Author: John Brawn
Date: 2022-02-17T13:11:55Z
New Revision: d916856bee1165aa78ca342cdd43523c33333736

URL: https://github.com/llvm/llvm-project/commit/d916856bee1165aa78ca342cdd43523c33333736
DIFF: https://github.com/llvm/llvm-project/commit/d916856bee1165aa78ca342cdd43523c33333736.diff

LOG: [AArch64] Allow strict opcodes in faddp patterns

This also requires adjusting performExtractVectorEltCombine in
AArch64ISelLowering.cpp so that vector_extract is distributed over
strict_fadd as well as fadd.

Differential Revision: https://reviews.llvm.org/D118489

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64InstrInfo.td
    llvm/test/CodeGen/AArch64/faddp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d06fd2b27341..6e763202ce91 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14228,6 +14228,7 @@ static SDValue performANDCombine(SDNode *N,
 
 static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
   switch (Opcode) {
+  case ISD::STRICT_FADD:
   case ISD::FADD:
     return (FullFP16 && VT == MVT::f16) || VT == MVT::f32 || VT == MVT::f64;
   case ISD::ADD:
@@ -14244,6 +14245,7 @@ static SDValue performExtractVectorEltCombine(SDNode *N, SelectionDAG &DAG) {
   EVT VT = N->getValueType(0);
   const bool FullFP16 =
       static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasFullFP16();
+  bool IsStrict = N0->isStrictFPOpcode();
 
   // Rewrite for pairwise fadd pattern
   //   (f32 (extract_vector_elt
@@ -14252,11 +14254,14 @@ static SDValue performExtractVectorEltCombine(SDNode *N, SelectionDAG &DAG) {
   // ->
   //   (f32 (fadd (extract_vector_elt (vXf32 Other) 0)
   //              (extract_vector_elt (vXf32 Other) 1))
+  // For strict_fadd we need to make sure the old strict_fadd can be deleted, so
+  // we can only do this when it's used only by the extract_vector_elt.
   if (ConstantN1 && ConstantN1->getZExtValue() == 0 &&
-      hasPairwiseAdd(N0->getOpcode(), VT, FullFP16)) {
+      hasPairwiseAdd(N0->getOpcode(), VT, FullFP16) &&
+      (!IsStrict || N0.hasOneUse())) {
     SDLoc DL(N0);
-    SDValue N00 = N0->getOperand(0);
-    SDValue N01 = N0->getOperand(1);
+    SDValue N00 = N0->getOperand(IsStrict ? 1 : 0);
+    SDValue N01 = N0->getOperand(IsStrict ? 2 : 1);
 
     ShuffleVectorSDNode *Shuffle = dyn_cast<ShuffleVectorSDNode>(N01);
     SDValue Other = N00;
@@ -14269,11 +14274,23 @@ static SDValue performExtractVectorEltCombine(SDNode *N, SelectionDAG &DAG) {
 
     if (Shuffle && Shuffle->getMaskElt(0) == 1 &&
         Other == Shuffle->getOperand(0)) {
-      return DAG.getNode(N0->getOpcode(), DL, VT,
-                         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
-                                     DAG.getConstant(0, DL, MVT::i64)),
-                         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
-                                     DAG.getConstant(1, DL, MVT::i64)));
+      SDValue Extract1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
+                                     DAG.getConstant(0, DL, MVT::i64));
+      SDValue Extract2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
+                                     DAG.getConstant(1, DL, MVT::i64));
+      if (!IsStrict)
+        return DAG.getNode(N0->getOpcode(), DL, VT, Extract1, Extract2);
+
+      // For strict_fadd we need uses of the final extract_vector to be replaced
+      // with the strict_fadd, but we also need uses of the chain output of the
+      // original strict_fadd to use the chain output of the new strict_fadd as
+      // otherwise it may not be deleted.
+      SDValue Ret = DAG.getNode(N0->getOpcode(), DL,
+                                {VT, MVT::Other},
+                                {N0->getOperand(0), Extract1, Extract2});
+      DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Ret);
+      DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Ret.getValue(1));
+      return SDValue(N, 0);
     }
   }
 

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 53a06c2b9e8e..664f670d741c 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -8100,17 +8100,17 @@ defm : InsertSubvectorUndef<i64>;
 def : Pat<(i64 (add (vector_extract (v2i64 FPR128:$Rn), (i64 0)),
                     (vector_extract (v2i64 FPR128:$Rn), (i64 1)))),
            (i64 (ADDPv2i64p (v2i64 FPR128:$Rn)))>;
-def : Pat<(f64 (fadd (vector_extract (v2f64 FPR128:$Rn), (i64 0)),
-                     (vector_extract (v2f64 FPR128:$Rn), (i64 1)))),
+def : Pat<(f64 (any_fadd (vector_extract (v2f64 FPR128:$Rn), (i64 0)),
+                         (vector_extract (v2f64 FPR128:$Rn), (i64 1)))),
            (f64 (FADDPv2i64p (v2f64 FPR128:$Rn)))>;
     // vector_extract on 64-bit vectors gets promoted to a 128 bit vector,
     // so we match on v4f32 here, not v2f32. This will also catch adding
     // the low two lanes of a true v4f32 vector.
-def : Pat<(fadd (vector_extract (v4f32 FPR128:$Rn), (i64 0)),
-                (vector_extract (v4f32 FPR128:$Rn), (i64 1))),
+def : Pat<(any_fadd (vector_extract (v4f32 FPR128:$Rn), (i64 0)),
+                    (vector_extract (v4f32 FPR128:$Rn), (i64 1))),
           (f32 (FADDPv2i32p (EXTRACT_SUBREG FPR128:$Rn, dsub)))>;
-def : Pat<(fadd (vector_extract (v8f16 FPR128:$Rn), (i64 0)),
-                (vector_extract (v8f16 FPR128:$Rn), (i64 1))),
+def : Pat<(any_fadd (vector_extract (v8f16 FPR128:$Rn), (i64 0)),
+                    (vector_extract (v8f16 FPR128:$Rn), (i64 1))),
           (f16 (FADDPv2i16p (EXTRACT_SUBREG FPR128:$Rn, dsub)))>;
 
 // Scalar 64-bit shifts in FPR64 registers.

diff  --git a/llvm/test/CodeGen/AArch64/faddp.ll b/llvm/test/CodeGen/AArch64/faddp.ll
index 06e976136c37..1476f7bcda5e 100644
--- a/llvm/test/CodeGen/AArch64/faddp.ll
+++ b/llvm/test/CodeGen/AArch64/faddp.ll
@@ -100,3 +100,83 @@ entry:
   %1 = extractelement <2 x i64> %0, i32 0
   ret i64 %1
 }
+
+define float @faddp_2xfloat_strict(<2 x float> %a) #0 {
+; CHECK-LABEL: faddp_2xfloat_strict:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    faddp s0, v0.2s
+; CHECK-NEXT:    ret
+entry:
+  %shift = shufflevector <2 x float> %a, <2 x float> undef, <2 x i32> <i32 1, i32 undef>
+  %0 = call <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> %a, <2 x float> %shift, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+  %1 = extractelement <2 x float> %0, i32 0
+  ret float %1
+}
+
+define float @faddp_4xfloat_strict(<4 x float> %a) #0 {
+; CHECK-LABEL: faddp_4xfloat_strict:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    faddp s0, v0.2s
+; CHECK-NEXT:    ret
+entry:
+  %shift = shufflevector <4 x float> %a, <4 x float> undef, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
+  %0 = call <4 x float> @llvm.experimental.constrained.fadd.v4f32(<4 x float> %a, <4 x float> %shift, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+  %1 = extractelement <4 x float> %0, i32 0
+  ret float %1
+}
+
+define float @faddp_4xfloat_commute_strict(<4 x float> %a) #0 {
+; CHECK-LABEL: faddp_4xfloat_commute_strict:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    faddp s0, v0.2s
+; CHECK-NEXT:    ret
+entry:
+  %shift = shufflevector <4 x float> %a, <4 x float> undef, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
+  %0 = call <4 x float> @llvm.experimental.constrained.fadd.v4f32(<4 x float> %shift, <4 x float> %a, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+  %1 = extractelement <4 x float> %0, i32 0
+  ret float %1
+}
+
+define float @faddp_2xfloat_commute_strict(<2 x float> %a) #0 {
+; CHECK-LABEL: faddp_2xfloat_commute_strict:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    faddp s0, v0.2s
+; CHECK-NEXT:    ret
+entry:
+  %shift = shufflevector <2 x float> %a, <2 x float> undef, <2 x i32> <i32 1, i32 undef>
+  %0 = call <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> %shift, <2 x float> %a, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+  %1 = extractelement <2 x float> %0, i32 0
+  ret float %1
+}
+
+define double @faddp_2xdouble_strict(<2 x double> %a) #0 {
+; CHECK-LABEL: faddp_2xdouble_strict:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    faddp d0, v0.2d
+; CHECK-NEXT:    ret
+entry:
+  %shift = shufflevector <2 x double> %a, <2 x double> undef, <2 x i32> <i32 1, i32 undef>
+  %0 = call <2 x double> @llvm.experimental.constrained.fadd.v2f64(<2 x double> %a, <2 x double> %shift, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+  %1 = extractelement <2 x double> %0, i32 0
+  ret double %1
+}
+
+define double @faddp_2xdouble_commute_strict(<2 x double> %a) #0 {
+; CHECK-LABEL: faddp_2xdouble_commute_strict:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    faddp d0, v0.2d
+; CHECK-NEXT:    ret
+entry:
+  %shift = shufflevector <2 x double> %a, <2 x double> undef, <2 x i32> <i32 1, i32 undef>
+  %0 = call <2 x double> @llvm.experimental.constrained.fadd.v2f64(<2 x double> %shift, <2 x double> %a, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+  %1 = extractelement <2 x double> %0, i32 0
+  ret double %1
+}
+
+attributes #0 = { strictfp }
+
+declare <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float>, <2 x float>, metadata, metadata)
+declare <4 x float> @llvm.experimental.constrained.fadd.v4f32(<4 x float>, <4 x float>, metadata, metadata)
+declare <2 x double> @llvm.experimental.constrained.fadd.v2f64(<2 x double>, <2 x double>, metadata, metadata)


        


More information about the llvm-commits mailing list