[llvm] 4a59694 - [AArch64][SVE] Combine FADD and FMUL aarch64 intrinsics to FMLA

Matt Devereau via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 8 04:23:03 PST 2021


Author: Matt
Date: 2021-11-08T12:22:38Z
New Revision: 4a59694ba148ec551bb4ad85cf3fcabe4ddaeaa8

URL: https://github.com/llvm/llvm-project/commit/4a59694ba148ec551bb4ad85cf3fcabe4ddaeaa8
DIFF: https://github.com/llvm/llvm-project/commit/4a59694ba148ec551bb4ad85cf3fcabe4ddaeaa8.diff

LOG: [AArch64][SVE] Combine FADD and FMUL aarch64 intrinsics to FMLA

This is a refinement to the work in
https://reviews.llvm.org/D111638

Fold (fadd p a (fmul p b c)) into (fma p a b c)

Differential Revision: https://reviews.llvm.org/D113095

Added: 
    llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll

Modified: 
    llvm/include/llvm/IR/Operator.h
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
index 0370e51ff85c..b83d83f0d0ab 100644
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -247,6 +247,9 @@ class FastMathFlags {
   void operator|=(const FastMathFlags &OtherFlags) {
     Flags |= OtherFlags.Flags;
   }
+  bool operator!=(const FastMathFlags &OtherFlags) const {
+    return Flags != OtherFlags.Flags;
+  }
 };
 
 /// Utility class for floating point operations which can have

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ae4b7e596a0b..5182f4d20408 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -695,6 +695,36 @@ static Optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
   return None;
 }
 
+static Optional<Instruction *> instCombineSVEVectorFMLA(InstCombiner &IC,
+                                                        IntrinsicInst &II) {
+  // fold (fadd p a (fmul p b c)) -> (fma p a b c)
+  Value *P = II.getOperand(0);
+  Value *A = II.getOperand(1);
+  auto FMul = II.getOperand(2);
+  Value *B, *C;
+  if (!match(FMul, m_Intrinsic<Intrinsic::aarch64_sve_fmul>(
+                       m_Specific(P), m_Value(B), m_Value(C))))
+    return None;
+
+  if (!FMul->hasOneUse())
+    return None;
+
+  llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
+  // Stop the combine when the flags on the inputs differ in case dropping flags
+  // would lead to us missing out on more beneficial optimizations.
+  if (FAddFlags != cast<CallInst>(FMul)->getFastMathFlags())
+    return None;
+  if (!FAddFlags.allowContract())
+    return None;
+
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+  auto FMLA = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_fmla,
+                                      {II.getType()}, {P, A, B, C}, &II);
+  FMLA->setFastMathFlags(FAddFlags);
+  return IC.replaceInstUsesWith(II, FMLA);
+}
+
 static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
   switch (Intrinsic) {
   case Intrinsic::aarch64_sve_fmul:
@@ -724,6 +754,13 @@ static Optional<Instruction *> instCombineSVEVectorBinOp(InstCombiner &IC,
   return IC.replaceInstUsesWith(II, BinOp);
 }
 
+static Optional<Instruction *> instCombineSVEVectorFAdd(InstCombiner &IC,
+                                                        IntrinsicInst &II) {
+  if (auto FMLA = instCombineSVEVectorFMLA(IC, II))
+    return FMLA;
+  return instCombineSVEVectorBinOp(IC, II);
+}
+
 static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
                                                        IntrinsicInst &II) {
   auto *OpPredicate = II.getOperand(0);
@@ -969,6 +1006,7 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
   case Intrinsic::aarch64_sve_fmul:
     return instCombineSVEVectorMul(IC, II);
   case Intrinsic::aarch64_sve_fadd:
+    return instCombineSVEVectorFAdd(IC, II);
   case Intrinsic::aarch64_sve_fsub:
     return instCombineSVEVectorBinOp(IC, II);
   case Intrinsic::aarch64_sve_tbl:

diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll
new file mode 100644
index 000000000000..56219c2d9a9b
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladd.ll
@@ -0,0 +1,108 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -instcombine < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define dso_local <vscale x 8 x half> @combine_fmla(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @combine_fmla(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP2]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  ret <vscale x 8 x half> %3
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_mul_first_operand(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_mul_first_operand(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[TMP2]], <vscale x 8 x half> [[A:%.*]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP3]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %2, <vscale x 8 x half> %a)
+  ret <vscale x 8 x half> %3
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_contract_flag_only(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_contract_flag_only(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call contract <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP2]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
+  %3 = tail call contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  ret <vscale x 8 x half> %3
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_no_flags(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_no_flags(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP2]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP3]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
+  %3 = tail call <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  ret <vscale x 8 x half> %3
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_neq_pred(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_neq_pred(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    [[TMP5:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP3]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP4]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP5]]
+;
+; ret <vscale x 8 x half> %9
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 5)
+  %3 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
+  %4 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
+  %5 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %3, <vscale x 8 x half> %a, <vscale x 8 x half> %4)
+  ret <vscale x 8 x half> %5
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_two_fmul_uses(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_two_fmul_uses(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[TMP3]], <vscale x 8 x half> [[TMP2]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP4]]
+;
+; ret <vscale x 8 x half> %8
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
+  %3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  %4 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %3, <vscale x 8 x half> %2)
+  ret <vscale x 8 x half> %4
+}
+
+define dso_local <vscale x 8 x half> @neg_combine_fmla_neq_flags(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) local_unnamed_addr #0 {
+; CHECK-LABEL: @neg_combine_fmla_neq_flags(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[TMP2]])
+; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP3]]
+;
+; ret <vscale x 8 x half> %7
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
+  %2 = tail call reassoc nnan contract <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
+  %3 = tail call reassoc contract <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %2)
+  ret <vscale x 8 x half> %3
+}
+
+declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)
+attributes #0 = { "target-features"="+sve" }


        


More information about the llvm-commits mailing list