[llvm] a107cf0 - [AArch64][InstCombine] Fuse ADD+MUL and SUB+MUL AArch64 intrinsics

Matt Devereau via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 22 01:56:41 PST 2022


Author: Matt Devereau
Date: 2022-12-22T09:46:40Z
New Revision: a107cf0c407f418fafce60539b42114fe68fd737

URL: https://github.com/llvm/llvm-project/commit/a107cf0c407f418fafce60539b42114fe68fd737
DIFF: https://github.com/llvm/llvm-project/commit/a107cf0c407f418fafce60539b42114fe68fd737.diff

LOG: [AArch64][InstCombine] Fuse ADD+MUL and SUB+MUL AArch64 intrinsics

Fold (ADD p (MUL p a b) c) into (MAD p a b c)
Fold (FADD p (FMUL p a b) c) into (FMAD p a b c)
Fold (FSUB p (FMUL p a b) c) into (FNMSB p a b c)

Fold (ADD p c (MUL p a b)) into (MLA p c a b)
Fold (FADD p c (FMUL p a b)) into (FMLA p c a b)
Fold (SUB p c (MUL p a b)) into (MLS p c a b)
Fold (FSUB p c (FMUL p a b)) into (FMLS p c a b)
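
For example, the FMLA fold from the second group rewrites a predicated fmul+fadd pair into a single fmla call. This is a minimal IR sketch mirroring the combine_fmla test added below; %p, %a, %b and %c are placeholder values:

  ; before
  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %c, <vscale x 8 x half> %2)

  ; after
  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
  %2 = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %c, <vscale x 8 x half> %a, <vscale x 8 x half> %b)

The fold only fires when the mul/fmul has a single use and, for the floating-point variants, when both calls carry identical fast-math flags that allow contraction (see the neg_combine_* tests).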

Differential Revision: https://reviews.llvm.org/D140200

Added: 
    llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladdsub.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Removed: 
    llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ea2d1eb6f9405..ae12ae951d756 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1045,34 +1045,51 @@ static std::optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
   return std::nullopt;
 }
 
+template <Intrinsic::ID MulOpc, typename Intrinsic::ID FuseOpc>
 static std::optional<Instruction *>
-instCombineSVEVectorFMLA(InstCombiner &IC, IntrinsicInst &II) {
-  // fold (fadd p a (fmul p b c)) -> (fma p a b c)
+instCombineSVEVectorFuseMulAddSub(InstCombiner &IC, IntrinsicInst &II,
+                                  bool MergeIntoAddendOp) {
   Value *P = II.getOperand(0);
-  Value *A = II.getOperand(1);
-  auto FMul = II.getOperand(2);
-  Value *B, *C;
-  if (!match(FMul, m_Intrinsic<Intrinsic::aarch64_sve_fmul>(
-                       m_Specific(P), m_Value(B), m_Value(C))))
-    return std::nullopt;
+  Value *MulOp0, *MulOp1, *AddendOp, *Mul;
+  if (MergeIntoAddendOp) {
+    AddendOp = II.getOperand(1);
+    Mul = II.getOperand(2);
+  } else {
+    AddendOp = II.getOperand(2);
+    Mul = II.getOperand(1);
+  }
 
-  if (!FMul->hasOneUse())
+  if (!match(Mul, m_Intrinsic<MulOpc>(m_Specific(P), m_Value(MulOp0),
+                                      m_Value(MulOp1))))
     return std::nullopt;
 
-  llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
-  // Stop the combine when the flags on the inputs differ in case dropping flags
-  // would lead to us missing out on more beneficial optimizations.
-  if (FAddFlags != cast<CallInst>(FMul)->getFastMathFlags())
-    return std::nullopt;
-  if (!FAddFlags.allowContract())
+  if (!Mul->hasOneUse())
     return std::nullopt;
 
+  Instruction *FMFSource = nullptr;
+  if (II.getType()->isFPOrFPVectorTy()) {
+    llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
+    // Stop the combine when the flags on the inputs differ in case dropping
+    // flags would lead to us missing out on more beneficial optimizations.
+    if (FAddFlags != cast<CallInst>(Mul)->getFastMathFlags())
+      return std::nullopt;
+    if (!FAddFlags.allowContract())
+      return std::nullopt;
+    FMFSource = &II;
+  }
+
   IRBuilder<> Builder(II.getContext());
   Builder.SetInsertPoint(&II);
-  auto FMLA = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_fmla,
-                                      {II.getType()}, {P, A, B, C}, &II);
-  FMLA->setFastMathFlags(FAddFlags);
-  return IC.replaceInstUsesWith(II, FMLA);
+
+  CallInst *Res;
+  if (MergeIntoAddendOp)
+    Res = Builder.CreateIntrinsic(FuseOpc, {II.getType()},
+                                  {P, AddendOp, MulOp0, MulOp1}, FMFSource);
+  else
+    Res = Builder.CreateIntrinsic(FuseOpc, {II.getType()},
+                                  {P, MulOp0, MulOp1, AddendOp}, FMFSource);
+
+  return IC.replaceInstUsesWith(II, Res);
 }
 
 static bool isAllActivePredicate(Value *Pred) {
@@ -1166,10 +1183,45 @@ instCombineSVEVectorBinOp(InstCombiner &IC, IntrinsicInst &II) {
   return IC.replaceInstUsesWith(II, BinOp);
 }
 
-static std::optional<Instruction *>
-instCombineSVEVectorFAdd(InstCombiner &IC, IntrinsicInst &II) {
-  if (auto FMLA = instCombineSVEVectorFMLA(IC, II))
+static std::optional<Instruction *> instCombineSVEVectorAdd(InstCombiner &IC,
+                                                            IntrinsicInst &II) {
+  if (auto FMLA =
+          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
+                                            Intrinsic::aarch64_sve_fmla>(IC, II,
+                                                                         true))
     return FMLA;
+  if (auto MLA = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
+                                                   Intrinsic::aarch64_sve_mla>(
+          IC, II, true))
+    return MLA;
+  if (auto FMAD =
+          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
+                                            Intrinsic::aarch64_sve_fmad>(IC, II,
+                                                                         false))
+    return FMAD;
+  if (auto MAD = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
+                                                   Intrinsic::aarch64_sve_mad>(
+          IC, II, false))
+    return MAD;
+  return instCombineSVEVectorBinOp(IC, II);
+}
+
+static std::optional<Instruction *> instCombineSVEVectorSub(InstCombiner &IC,
+                                                            IntrinsicInst &II) {
+  if (auto FMLS =
+          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
+                                            Intrinsic::aarch64_sve_fmls>(IC, II,
+                                                                         true))
+    return FMLS;
+  if (auto MLS = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
+                                                   Intrinsic::aarch64_sve_mls>(
+          IC, II, true))
+    return MLS;
+  if (auto FMSB =
+          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
+                                            Intrinsic::aarch64_sve_fnmsb>(
+              IC, II, false))
+    return FMSB;
   return instCombineSVEVectorBinOp(IC, II);
 }
 
@@ -1470,9 +1522,11 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
   case Intrinsic::aarch64_sve_fmul:
     return instCombineSVEVectorMul(IC, II);
   case Intrinsic::aarch64_sve_fadd:
-    return instCombineSVEVectorFAdd(IC, II);
+  case Intrinsic::aarch64_sve_add:
+    return instCombineSVEVectorAdd(IC, II);
   case Intrinsic::aarch64_sve_fsub:
-    return instCombineSVEVectorBinOp(IC, II);
+  case Intrinsic::aarch64_sve_sub:
+    return instCombineSVEVectorSub(IC, II);
   case Intrinsic::aarch64_sve_tbl:
     return instCombineSVETBL(IC, II);
   case Intrinsic::aarch64_sve_uunpkhi:

diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladdsub.ll
similarity index 50%
rename from llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll
rename to llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladdsub.ll
index 026080b576339..1352eaa798b68 100644
--- a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladdsub.ll
@@ -6,50 +6,115 @@ target triple = "aarch64-unknown-linux-gnu"
 define dso_local <vscale x 8 x half> @combine_fmla(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
 ; CHECK-LABEL: @combine_fmla(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
-; CHECK-NEXT:    [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[C:%.*]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]])
 ; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP2]]
 ;
   %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
-  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
-  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %c, <vscale x 8 x half> %2)
   ret <vscale x 8 x half> %3
 }
 
-define dso_local <vscale x 8 x half> @neg_combine_fmla_mul_first_operand(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
-; CHECK-LABEL: @neg_combine_fmla_mul_first_operand(
+define dso_local <vscale x 16 x i8> @combine_mla_i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_mla_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 16 x i8> @llvm.aarch64.sve.mla.nxv16i8(<vscale x 16 x i1> [[P:%.*]], <vscale x 16 x i8> [[C:%.*]], <vscale x 16 x i8> [[A:%.*]], <vscale x 16 x i8> [[B:%.*]])
+; CHECK-NEXT:    ret <vscale x 16 x i8> [[TMP1]]
+;
+  %1 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+  %2 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.add.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %c, <vscale x 16 x i8> %1)
+  ret <vscale x 16 x i8> %2
+}
+
+define dso_local <vscale x 8 x half> @combine_fmad(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmad(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
-; CHECK-NEXT:    [[TMP3:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[TMP2]], <vscale x 8 x half> [[A:%.*]])
-; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP3]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmad.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP2]]
 ;
   %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
-  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
-  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %a)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %c)
   ret <vscale x 8 x half> %3
 }
 
-define dso_local <vscale x 8 x half> @neg_combine_fmla_contract_flag_only(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
-; CHECK-LABEL: @neg_combine_fmla_contract_flag_only(
+define dso_local <vscale x 16 x i8> @combine_mad_i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_mad_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 16 x i8> @llvm.aarch64.sve.mad.nxv16i8(<vscale x 16 x i1> [[P:%.*]], <vscale x 16 x i8> [[A:%.*]], <vscale x 16 x i8> [[B:%.*]], <vscale x 16 x i8> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 16 x i8> [[TMP1]]
+;
+  %1 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+  %2 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.add.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %1, <vscale x 16 x i8> %c)
+  ret <vscale x 16 x i8> %2
+}
+
+define dso_local <vscale x 8 x half> @combine_fmls(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmls(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
-; CHECK-NEXT:    [[TMP2:%.*]] = call contract <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmls.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[C:%.*]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]])
 ; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP2]]
 ;
   %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
-  %2 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
-  %3 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fsub.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %c, <vscale x 8 x half> %2)
+  ret <vscale x 8 x half> %3
+}
+
+define dso_local <vscale x 16 x i8> @combine_mls_i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_mls_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 16 x i8> @llvm.aarch64.sve.mls.nxv16i8(<vscale x 16 x i1> [[P:%.*]], <vscale x 16 x i8> [[C:%.*]], <vscale x 16 x i8> [[A:%.*]], <vscale x 16 x i8> [[B:%.*]])
+; CHECK-NEXT:    ret <vscale x 16 x i8> [[TMP1]]
+;
+  %1 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+  %2 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.sub.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %c, <vscale x 16 x i8> %1)
+  ret <vscale x 16 x i8> %2
+}
+
+define dso_local <vscale x 8 x half> @combine_fnmsb(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fnmsb(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fnmsb.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP2]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fsub.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %c)
+  ret <vscale x 8 x half> %3
+}
+
+; No integer variant of fnmsb exists; Do not combine
+define dso_local <vscale x 16 x i8> @neg_combine_nmsb_i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_nmsb_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1> [[P:%.*]], <vscale x 16 x i8> [[A:%.*]], <vscale x 16 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.sub.nxv16i8(<vscale x 16 x i1> [[P]], <vscale x 16 x i8> [[TMP1]], <vscale x 16 x i8> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 16 x i8> [[TMP2]]
+;
+  %1 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+  %2 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.sub.nxv16i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %1, <vscale x 16 x i8> %c)
+  ret <vscale x 16 x i8> %2
+}
+
+define dso_local <vscale x 8 x half> @combine_fmla_contract_flag_only(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmla_contract_flag_only(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call contract <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[C:%.*]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP2]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %c, <vscale x 8 x half> %2)
   ret <vscale x 8 x half> %3
 }
 
 define dso_local <vscale x 8 x half> @neg_combine_fmla_no_flags(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
 ; CHECK-LABEL: @neg_combine_fmla_no_flags(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
-; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP2]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[C:%.*]], <vscale x 8 x half> [[TMP2]])
 ; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP3]]
 ;
   %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
-  %2 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
-  %3 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  %2 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %c, <vscale x 8 x half> %2)
   ret <vscale x 8 x half> %3
 }
 
@@ -58,31 +123,31 @@ define dso_local <vscale x 8 x half> @neg_combine_fmla_neq_pred(<vscale x 16 x i
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
 ; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
 ; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP2]])
-; CHECK-NEXT:    [[TMP4:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
-; CHECK-NEXT:    [[TMP5:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP3]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP4]])
+; CHECK-NEXT:    [[TMP4:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]])
+; CHECK-NEXT:    [[TMP5:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP3]], <vscale x 8 x half> [[C:%.*]], <vscale x 8 x half> [[TMP4]])
 ; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP5]]
 ;
 ; ret <vscale x 8 x half> %9
   %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
   %2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
   %3 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
-  %4 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
-  %5 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %3, <vscale x 8 x half> %a, <vscale x 8 x half> %4)
+  %4 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %5 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %3, <vscale x 8 x half> %c, <vscale x 8 x half> %4)
   ret <vscale x 8 x half> %5
 }
 
 define dso_local <vscale x 8 x half> @neg_combine_fmla_two_fmul_uses(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
 ; CHECK-LABEL: @neg_combine_fmla_two_fmul_uses(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
-; CHECK-NEXT:    [[TMP3:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP2]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[C:%.*]], <vscale x 8 x half> [[TMP2]])
 ; CHECK-NEXT:    [[TMP4:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[TMP3]], <vscale x 8 x half> [[TMP2]])
 ; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP4]]
 ;
 ; ret <vscale x 8 x half> %8
   %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
-  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
-  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %c, <vscale x 8 x half> %2)
   %4 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %3, <vscale x 8 x half> %2)
   ret <vscale x 8 x half> %4
 }
@@ -101,8 +166,15 @@ define dso_local <vscale x 8 x half> @neg_combine_fmla_neq_flags(<vscale x 16 x
   ret <vscale x 8 x half> %3
 }
 
+declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)
 declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1>)
+declare <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1>)
 declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
 declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
-declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fsub.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+declare <vscale x 16 x i8> @llvm.aarch64.sve.add.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+declare <vscale x 16 x i8> @llvm.aarch64.sve.sub.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+
 attributes #0 = { "target-features"="+sve" }


        

