[llvm] fc28a2f - [AArch64][SVE] Combine predicated FMUL/FADD into FMA

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 27 04:43:45 PDT 2021


Author: Matt
Date: 2021-10-27T11:41:23Z
New Revision: fc28a2f8ced4ff3565f2c47fead824a017c25d21

URL: https://github.com/llvm/llvm-project/commit/fc28a2f8ced4ff3565f2c47fead824a017c25d21
DIFF: https://github.com/llvm/llvm-project/commit/fc28a2f8ced4ff3565f2c47fead824a017c25d21.diff

LOG: [AArch64][SVE] Combine predicated FMUL/FADD into FMA

Combine FADD and FMUL intrinsics into FMA when the result of the FMUL is an operand of
the FADD, has only one use, and both intrinsics use the same predicate.

Differential Revision: https://reviews.llvm.org/D111638

Added: 
    llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-fmla.ll

Modified: 
    llvm/include/llvm/IR/Operator.h
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
index d0bce742cc96d..d3a89f94a9eca 100644
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -243,6 +243,9 @@ class FastMathFlags {
   void operator|=(const FastMathFlags &OtherFlags) {
     Flags |= OtherFlags.Flags;
   }
+  bool operator!=(const FastMathFlags &OtherFlags) const {
+    return Flags != OtherFlags.Flags;
+  }
 };
 
 /// Utility class for floating point operations which can have

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 9d7434377b1d1..8718ac2382238 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -695,6 +695,45 @@ static Optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
   return None;
 }
 
+static Optional<Instruction *> instCombineSVEVectorFMLA(InstCombiner &IC,
+                                                        IntrinsicInst &II) {
+  // fold (fadd p a (fmul p b c)) -> (fma p a b c)
+  Value *p, *FMul, *a, *b, *c;
+  auto m_SVEFAdd = [](auto p, auto w, auto x) {
+    return m_CombineOr(m_Intrinsic<Intrinsic::aarch64_sve_fadd>(p, w, x),
+                       m_Intrinsic<Intrinsic::aarch64_sve_fadd>(p, x, w));
+  };
+  auto m_SVEFMul = [](auto p, auto y, auto z) {
+    return m_Intrinsic<Intrinsic::aarch64_sve_fmul>(p, y, z);
+  };
+  if (!match(&II, m_SVEFAdd(m_Value(p), m_Value(a),
+                            m_CombineAnd(m_Value(FMul),
+                                         m_SVEFMul(m_Deferred(p), m_Value(b),
+                                                   m_Value(c))))))
+    return None;
+
+  if (!FMul->hasOneUse())
+    return None;
+
+  llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
+  llvm::FastMathFlags FMulFlags = cast<CallInst>(FMul)->getFastMathFlags();
+  // Don't combine when FMul & Fadd flags differ to prevent the loss of any
+  // additional important flags
+  if (FAddFlags != FMulFlags)
+    return None;
+  bool AllowReassoc = FAddFlags.allowReassoc() && FMulFlags.allowReassoc();
+  bool AllowContract = FAddFlags.allowContract() && FMulFlags.allowContract();
+  if (!AllowReassoc || !AllowContract)
+    return None;
+
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+  auto FMLA = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_fmla,
+                                      {II.getType()}, {p, a, b, c}, &II);
+  FMLA->setFastMathFlags(FAddFlags);
+  return IC.replaceInstUsesWith(II, FMLA);
+}
+
 static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
   switch (Intrinsic) {
   case Intrinsic::aarch64_sve_fmul:
@@ -724,6 +763,14 @@ static Optional<Instruction *> instCombineSVEVectorBinOp(InstCombiner &IC,
   return IC.replaceInstUsesWith(II, BinOp);
 }
 
+static Optional<Instruction *> instCombineSVEVectorFAdd(InstCombiner &IC,
+                                                        IntrinsicInst &II) {
+  auto FMLA = instCombineSVEVectorFMLA(IC, II);
+  if (FMLA)
+    return FMLA;
+  return instCombineSVEVectorBinOp(IC, II);
+}
+
 static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
                                                        IntrinsicInst &II) {
   auto *OpPredicate = II.getOperand(0);
@@ -901,6 +948,7 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
   case Intrinsic::aarch64_sve_fmul:
     return instCombineSVEVectorMul(IC, II);
   case Intrinsic::aarch64_sve_fadd:
+    return instCombineSVEVectorFAdd(IC, II);
   case Intrinsic::aarch64_sve_fsub:
     return instCombineSVEVectorBinOp(IC, II);
   case Intrinsic::aarch64_sve_tbl:

diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-fmla.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-fmla.ll
new file mode 100644
index 0000000000000..13d0f3a55c5ea
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-fmla.ll
@@ -0,0 +1,121 @@
+; RUN: opt -S -instcombine < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define dso_local <vscale x 8 x half> @combine_fmla(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmla
+; CHECK-NEXT:  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT:  %6 = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT:  ret <vscale x 8 x half> %6
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_contract_flag_only(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_contract_flag_only
+; CHECK-NEXT:  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT:  %6 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT:  %7 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+; CHECK-NEXT:  ret <vscale x 8 x half> %7
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_reassoc_flag_only(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_reassoc_flag_only
+; CHECK-NEXT:  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT:  %6 = tail call reassoc <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT:  %7 = tail call reassoc <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+; CHECK-NEXT:  ret <vscale x 8 x half> %7
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call reassoc <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call reassoc <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_min_flags(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_min_flags
+; CHECK-NEXT:  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT:  %6 = call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT:  ret <vscale x 8 x half> %6
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_no_fast_flag(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_no_fast_flag
+; CHECK-NEXT:  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT:  %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT:  %7 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+; CHECK-NEXT:  ret <vscale x 8 x half> %7
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_no_fmul(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_no_fmul
+; CHECK-NEXT:  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT:  %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT:  %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+; CHECK-NEXT:  ret <vscale x 8 x half> %7
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_neq_pred(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_neq_pred
+; CHECK-NEXT:  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT:  %6 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
+; CHECK-NEXT:  %7 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %6)
+; CHECK-NEXT:  %8 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT:  %9 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %7, <vscale x 8 x half> %1, <vscale x 8 x half> %8)
+; ret <vscale x 8 x half> %9
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
+  %7 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %6)
+  %8 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %9 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %7, <vscale x 8 x half> %1, <vscale x 8 x half> %8)
+  ret <vscale x 8 x half> %9
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_two_fmul_uses(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_two_fmul_uses
+; CHECK-NEXT:  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT:  %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT:  %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+; CHECK-NEXT:  %8 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %7, <vscale x 8 x half> %6)
+; ret <vscale x 8 x half> %8
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  %8 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %7, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %8
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_neq_flags(<vscale x 16 x i1> %0, <vscale x 8 x half> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %3) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_neq_flags
+; CHECK-NEXT:  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+; CHECK-NEXT:  %6 = tail call reassoc nnan contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+; CHECK-NEXT:  %7 = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+; ret <vscale x 8 x half> %7
+  %5 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %0)
+  %6 = tail call reassoc nnan contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %2, <vscale x 8 x half> %3)
+  %7 = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %5, <vscale x 8 x half> %1, <vscale x 8 x half> %6)
+  ret <vscale x 8 x half> %7
+}
+
+declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)
+attributes #0 = { "target-features"="+sve" }


        


More information about the llvm-commits mailing list