[llvm] 3f8e714 - [ARM,MVE] Add intrinsics and isel for MVE fused multiply-add.

Simon Tatham via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 12 04:14:28 PDT 2020


Author: Simon Tatham
Date: 2020-03-12T11:13:50Z
New Revision: 3f8e714e2f9f2dc3367d2f3fc569abfaf28f314c

URL: https://github.com/llvm/llvm-project/commit/3f8e714e2f9f2dc3367d2f3fc569abfaf28f314c
DIFF: https://github.com/llvm/llvm-project/commit/3f8e714e2f9f2dc3367d2f3fc569abfaf28f314c.diff

LOG: [ARM,MVE] Add intrinsics and isel for MVE fused multiply-add.

Summary:
This adds the ACLE intrinsic family for the VFMA and VFMS
instructions, which perform fused multiply-add and fused
multiply-subtract on vectors of floats.

I've represented the unpredicated versions in IR using the
target-independent `@llvm.fma` intrinsic. We already had isel rules
to convert one of those into a vector VFMA in the simplest possible
way, but we didn't have rules to detect a negated argument and turn
it into VFMS, or to detect a splat argument and turn it into one of
the two vector/scalar forms of the instruction. Now we have all of
those.
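
By way of illustration (a minimal sketch, not part of this patch,
assuming `arm_mve.h` from a Clang with this change and a
thumbv8.1m.main target with +mve.fp), here is C code whose IR
exercises the new isel rules; the expected instructions are taken
from the tests added below:

    #include <arm_mve.h>

    /* fneg feeding @llvm.fma: now selected as a single VFMS
       (vfms.f32 q0, q1, q2) instead of a negate plus VFMA. */
    float32x4_t negated(float32x4_t a, float32x4_t b, float32x4_t c) {
      return vfmsq_f32(a, b, c);
    }

    /* Splatted multiplicand feeding @llvm.fma: now selected as the
       vector/scalar form (vfma.f32 q0, q1, r0) instead of a vdup
       plus a vector VFMA. */
    float32x4_t splatted(float32x4_t a, float32x4_t b, float32_t c) {
      return vfmaq_n_f32(a, b, c);
    }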

The predicated form uses a target-specific intrinsic as usual, but
I've stuck to just one, `@llvm.arm.mve.fma.predicated`. The subtract
and splat versions are code-generated by passing an fneg or a splat
as one of its operands, in the same way as the unpredicated
versions.
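
For example (another hedged sketch under the same assumptions as
above; the expected codegen again comes from the tests below), the
predicated subtract variant is just that one intrinsic with a
negated operand:

    #include <arm_mve.h>

    /* Clang emits an fneg of c feeding a single call to
       @llvm.arm.mve.fma.predicated, which isel then turns into a
       predicated VFMS (vpst; vfmst.f32 q0, q1, q2). */
    float32x4_t predicated_fms(float32x4_t a, float32x4_t b,
                               float32x4_t c, mve_pred16_t p) {
      return vfmsq_m_f32(a, b, c, p);
    }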

In arm_mve_defs.td, I've had to introduce a tiny extra piece of
infrastructure: a record `id`, for use in codegen dags, which
implements the identity function. (This is needed because you can't
declare a TableGen value of type `dag` which is //only// a
`$varname`: you have to wrap it in something. Now I can write
`(id $varname)` to get the same effect, as the `m2_cg` definition in
the new FMA multiclass does.)

Reviewers: dmgreen, MarkMurrayARM, miyuki, ostannard

Reviewed By: dmgreen

Subscribers: kristof.beyls, hiraditya, danielkiss, cfe-commits, llvm-commits

Tags: #clang, #llvm

Differential Revision: https://reviews.llvm.org/D75998

Added: 
    clang/test/CodeGen/arm-mve-intrinsics/ternary.c
    llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll

Modified: 
    clang/include/clang/Basic/arm_mve.td
    clang/include/clang/Basic/arm_mve_defs.td
    llvm/include/llvm/IR/IntrinsicsARM.td
    llvm/lib/Target/ARM/ARMInstrMVE.td
    llvm/test/CodeGen/Thumb2/mve-fmas.ll

Removed: 
    


################################################################################
diff --git a/clang/include/clang/Basic/arm_mve.td b/clang/include/clang/Basic/arm_mve.td
index d9a2035e8a0e..d2203d650301 100644
--- a/clang/include/clang/Basic/arm_mve.td
+++ b/clang/include/clang/Basic/arm_mve.td
@@ -162,6 +162,46 @@ let pnt = PNT_NType in {
 }
 }
 
+multiclass FMA<bit add> {
+  // FMS instructions are defined in the ArmARM as if they negate the
+  // second multiply input.
+  defvar m2_cg = !if(add, (id $m2), (fneg $m2));
+
+  defvar unpred_cg = (IRIntBase<"fma", [Vector]> $m1, m2_cg, $addend);
+  defvar pred_cg   = (IRInt<"fma_predicated", [Vector, Predicate]>
+                          $m1, m2_cg, $addend, $pred);
+
+  def q: Intrinsic<Vector, (args Vector:$addend, Vector:$m1, Vector:$m2),
+                   unpred_cg>;
+
+  def q_m: Intrinsic<Vector, (args Vector:$addend, Vector:$m1, Vector:$m2,
+                                   Predicate:$pred), pred_cg>;
+
+  // Only FMA has the vector/scalar variants, not FMS
+  if add then let pnt = PNT_NType in {
+
+    def q_n: Intrinsic<Vector, (args Vector:$addend, Vector:$m1,
+                                     unpromoted<Scalar>:$m2_s),
+                     (seq (splat $m2_s):$m2, unpred_cg)>;
+    def sq_n: Intrinsic<Vector, (args Vector:$m1, Vector:$m2,
+                                      unpromoted<Scalar>:$addend_s),
+                        (seq (splat $addend_s):$addend, unpred_cg)>;
+    def q_m_n: Intrinsic<Vector, (args Vector:$addend, Vector:$m1,
+                                       unpromoted<Scalar>:$m2_s,
+                                       Predicate:$pred),
+                         (seq (splat $m2_s):$m2, pred_cg)>;
+    def sq_m_n: Intrinsic<Vector, (args Vector:$m1, Vector:$m2,
+                                        unpromoted<Scalar>:$addend_s,
+                                        Predicate:$pred),
+                          (seq (splat $addend_s):$addend, pred_cg)>;
+  }
+}
+
+let params = T.Float in {
+  defm vfma: FMA<1>;
+  defm vfms: FMA<0>;
+}
+
 let params = !listconcat(T.Int16, T.Int32) in {
   let pnt = PNT_None in {
     def vmvnq_n: Intrinsic<Vector, (args imm_simd_vmvn:$imm),

diff --git a/clang/include/clang/Basic/arm_mve_defs.td b/clang/include/clang/Basic/arm_mve_defs.td
index f1424f2ea594..4038a18027f8 100644
--- a/clang/include/clang/Basic/arm_mve_defs.td
+++ b/clang/include/clang/Basic/arm_mve_defs.td
@@ -133,6 +133,18 @@ def unzip: CGHelperFn<"VectorUnzip"> {
 }
 def zip: CGHelperFn<"VectorZip">;
 
+// Trivial 'codegen' function that just returns its argument. Useful
+// for wrapping up a variable name like $foo into a thing you can pass
+// around as type 'dag'.
+def id: IRBuilderBase {
+  // All the other cases of IRBuilderBase use 'prefix' to specify a function
+  // call, including the open parenthesis. MveEmitter puts the closing paren on
+  // the end. So if we _just_ specify an open paren with no function name
+  // before it, then the generated C++ code will simply wrap the input value in
+  // parentheses, returning it unchanged.
+  let prefix = "(";
+}
+
 // Helper for making boolean flags in IR
 def i1: IRBuilderBase {
   let prefix = "llvm::ConstantInt::get(Builder.getInt1Ty(), ";

diff --git a/clang/test/CodeGen/arm-mve-intrinsics/ternary.c b/clang/test/CodeGen/arm-mve-intrinsics/ternary.c
new file mode 100644
index 000000000000..ab1cb14c3aed
--- /dev/null
+++ b/clang/test/CodeGen/arm-mve-intrinsics/ternary.c
@@ -0,0 +1,261 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py
+// RUN: %clang_cc1 -triple thumbv8.1m.main-arm-none-eabi -target-feature +mve.fp -mfloat-abi hard -fallow-half-arguments-and-returns -O0 -disable-O0-optnone -S -emit-llvm -o - %s | opt -S -sroa | FileCheck %s
+// RUN: %clang_cc1 -triple thumbv8.1m.main-arm-none-eabi -target-feature +mve.fp -mfloat-abi hard -fallow-half-arguments-and-returns -O0 -disable-O0-optnone -DPOLYMORPHIC -S -emit-llvm -o - %s | opt -S -sroa | FileCheck %s
+
+#include <arm_mve.h>
+
+// CHECK-LABEL: @test_vfmaq_f16(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[TMP0:%.*]] = call <8 x half> @llvm.fma.v8f16(<8 x half> [[B:%.*]], <8 x half> [[C:%.*]], <8 x half> [[A:%.*]])
+// CHECK-NEXT:    ret <8 x half> [[TMP0]]
+//
+float16x8_t test_vfmaq_f16(float16x8_t a, float16x8_t b, float16x8_t c) {
+#ifdef POLYMORPHIC
+  return vfmaq(a, b, c);
+#else /* POLYMORPHIC */
+  return vfmaq_f16(a, b, c);
+#endif /* POLYMORPHIC */
+}
+
+// CHECK-LABEL: @test_vfmaq_f32(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[TMP0:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[B:%.*]], <4 x float> [[C:%.*]], <4 x float> [[A:%.*]])
+// CHECK-NEXT:    ret <4 x float> [[TMP0]]
+//
+float32x4_t test_vfmaq_f32(float32x4_t a, float32x4_t b, float32x4_t c) {
+#ifdef POLYMORPHIC
+  return vfmaq(a, b, c);
+#else /* POLYMORPHIC */
+  return vfmaq_f32(a, b, c);
+#endif /* POLYMORPHIC */
+}
+
+// CHECK-LABEL: @test_vfmaq_n_f16(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[TMP0:%.*]] = bitcast float [[C_COERCE:%.*]] to i32
+// CHECK-NEXT:    [[TMP_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[TMP0]] to i16
+// CHECK-NEXT:    [[TMP1:%.*]] = bitcast i16 [[TMP_0_EXTRACT_TRUNC]] to half
+// CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <8 x half> undef, half [[TMP1]], i32 0
+// CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <8 x half> [[DOTSPLATINSERT]], <8 x half> undef, <8 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP2:%.*]] = call <8 x half> @llvm.fma.v8f16(<8 x half> [[B:%.*]], <8 x half> [[DOTSPLAT]], <8 x half> [[A:%.*]])
+// CHECK-NEXT:    ret <8 x half> [[TMP2]]
+//
+float16x8_t test_vfmaq_n_f16(float16x8_t a, float16x8_t b, float16_t c) {
+#ifdef POLYMORPHIC
+  return vfmaq(a, b, c);
+#else /* POLYMORPHIC */
+  return vfmaq_n_f16(a, b, c);
+#endif /* POLYMORPHIC */
+}
+
+// CHECK-LABEL: @test_vfmaq_n_f32(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x float> undef, float [[C:%.*]], i32 0
+// CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x float> [[DOTSPLATINSERT]], <4 x float> undef, <4 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[B:%.*]], <4 x float> [[DOTSPLAT]], <4 x float> [[A:%.*]])
+// CHECK-NEXT:    ret <4 x float> [[TMP0]]
+//
+float32x4_t test_vfmaq_n_f32(float32x4_t a, float32x4_t b, float32_t c) {
+#ifdef POLYMORPHIC
+  return vfmaq(a, b, c);
+#else /* POLYMORPHIC */
+  return vfmaq_n_f32(a, b, c);
+#endif /* POLYMORPHIC */
+}
+
+// CHECK-LABEL: @test_vfmasq_n_f16(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[TMP0:%.*]] = bitcast float [[C_COERCE:%.*]] to i32
+// CHECK-NEXT:    [[TMP_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[TMP0]] to i16
+// CHECK-NEXT:    [[TMP1:%.*]] = bitcast i16 [[TMP_0_EXTRACT_TRUNC]] to half
+// CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <8 x half> undef, half [[TMP1]], i32 0
+// CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <8 x half> [[DOTSPLATINSERT]], <8 x half> undef, <8 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP2:%.*]] = call <8 x half> @llvm.fma.v8f16(<8 x half> [[A:%.*]], <8 x half> [[B:%.*]], <8 x half> [[DOTSPLAT]])
+// CHECK-NEXT:    ret <8 x half> [[TMP2]]
+//
+float16x8_t test_vfmasq_n_f16(float16x8_t a, float16x8_t b, float16_t c) {
+#ifdef POLYMORPHIC
+  return vfmasq(a, b, c);
+#else /* POLYMORPHIC */
+  return vfmasq_n_f16(a, b, c);
+#endif /* POLYMORPHIC */
+}
+
+// CHECK-LABEL: @test_vfmasq_n_f32(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x float> undef, float [[C:%.*]], i32 0
+// CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x float> [[DOTSPLATINSERT]], <4 x float> undef, <4 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]], <4 x float> [[DOTSPLAT]])
+// CHECK-NEXT:    ret <4 x float> [[TMP0]]
+//
+float32x4_t test_vfmasq_n_f32(float32x4_t a, float32x4_t b, float32_t c) {
+#ifdef POLYMORPHIC
+  return vfmasq(a, b, c);
+#else /* POLYMORPHIC */
+  return vfmasq_n_f32(a, b, c);
+#endif /* POLYMORPHIC */
+}
+
+// CHECK-LABEL: @test_vfmsq_f16(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[TMP0:%.*]] = fneg <8 x half> [[C:%.*]]
+// CHECK-NEXT:    [[TMP1:%.*]] = call <8 x half> @llvm.fma.v8f16(<8 x half> [[B:%.*]], <8 x half> [[TMP0]], <8 x half> [[A:%.*]])
+// CHECK-NEXT:    ret <8 x half> [[TMP1]]
+//
+float16x8_t test_vfmsq_f16(float16x8_t a, float16x8_t b, float16x8_t c) {
+#ifdef POLYMORPHIC
+  return vfmsq(a, b, c);
+#else /* POLYMORPHIC */
+  return vfmsq_f16(a, b, c);
+#endif /* POLYMORPHIC */
+}
+
+// CHECK-LABEL: @test_vfmsq_f32(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[TMP0:%.*]] = fneg <4 x float> [[C:%.*]]
+// CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[B:%.*]], <4 x float> [[TMP0]], <4 x float> [[A:%.*]])
+// CHECK-NEXT:    ret <4 x float> [[TMP1]]
+//
+float32x4_t test_vfmsq_f32(float32x4_t a, float32x4_t b, float32x4_t c) {
+#ifdef POLYMORPHIC
+  return vfmsq(a, b, c);
+#else /* POLYMORPHIC */
+  return vfmsq_f32(a, b, c);
+#endif /* POLYMORPHIC */
+}
+
+// CHECK-LABEL: @test_vfmaq_m_f16(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[TMP0:%.*]] = zext i16 [[P:%.*]] to i32
+// CHECK-NEXT:    [[TMP1:%.*]] = call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 [[TMP0]])
+// CHECK-NEXT:    [[TMP2:%.*]] = call <8 x half> @llvm.arm.mve.fma.predicated.v8f16.v8i1(<8 x half> [[B:%.*]], <8 x half> [[C:%.*]], <8 x half> [[A:%.*]], <8 x i1> [[TMP1]])
+// CHECK-NEXT:    ret <8 x half> [[TMP2]]
+//
+float16x8_t test_vfmaq_m_f16(float16x8_t a, float16x8_t b, float16x8_t c, mve_pred16_t p) {
+#ifdef POLYMORPHIC
+  return vfmaq_m(a, b, c, p);
+#else /* POLYMORPHIC */
+  return vfmaq_m_f16(a, b, c, p);
+#endif /* POLYMORPHIC */
+}
+
+// CHECK-LABEL: @test_vfmaq_m_f32(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[TMP0:%.*]] = zext i16 [[P:%.*]] to i32
+// CHECK-NEXT:    [[TMP1:%.*]] = call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 [[TMP0]])
+// CHECK-NEXT:    [[TMP2:%.*]] = call <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> [[B:%.*]], <4 x float> [[C:%.*]], <4 x float> [[A:%.*]], <4 x i1> [[TMP1]])
+// CHECK-NEXT:    ret <4 x float> [[TMP2]]
+//
+float32x4_t test_vfmaq_m_f32(float32x4_t a, float32x4_t b, float32x4_t c, mve_pred16_t p) {
+#ifdef POLYMORPHIC
+  return vfmaq_m(a, b, c, p);
+#else /* POLYMORPHIC */
+  return vfmaq_m_f32(a, b, c, p);
+#endif /* POLYMORPHIC */
+}
+
+// CHECK-LABEL: @test_vfmaq_m_n_f16(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[TMP0:%.*]] = bitcast float [[C_COERCE:%.*]] to i32
+// CHECK-NEXT:    [[TMP_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[TMP0]] to i16
+// CHECK-NEXT:    [[TMP1:%.*]] = bitcast i16 [[TMP_0_EXTRACT_TRUNC]] to half
+// CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <8 x half> undef, half [[TMP1]], i32 0
+// CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <8 x half> [[DOTSPLATINSERT]], <8 x half> undef, <8 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP2:%.*]] = zext i16 [[P:%.*]] to i32
+// CHECK-NEXT:    [[TMP3:%.*]] = call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 [[TMP2]])
+// CHECK-NEXT:    [[TMP4:%.*]] = call <8 x half> @llvm.arm.mve.fma.predicated.v8f16.v8i1(<8 x half> [[B:%.*]], <8 x half> [[DOTSPLAT]], <8 x half> [[A:%.*]], <8 x i1> [[TMP3]])
+// CHECK-NEXT:    ret <8 x half> [[TMP4]]
+//
+float16x8_t test_vfmaq_m_n_f16(float16x8_t a, float16x8_t b, float16_t c, mve_pred16_t p) {
+#ifdef POLYMORPHIC
+  return vfmaq_m(a, b, c, p);
+#else /* POLYMORPHIC */
+  return vfmaq_m_n_f16(a, b, c, p);
+#endif /* POLYMORPHIC */
+}
+
+// CHECK-LABEL: @test_vfmaq_m_n_f32(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x float> undef, float [[C:%.*]], i32 0
+// CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x float> [[DOTSPLATINSERT]], <4 x float> undef, <4 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = zext i16 [[P:%.*]] to i32
+// CHECK-NEXT:    [[TMP1:%.*]] = call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 [[TMP0]])
+// CHECK-NEXT:    [[TMP2:%.*]] = call <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> [[B:%.*]], <4 x float> [[DOTSPLAT]], <4 x float> [[A:%.*]], <4 x i1> [[TMP1]])
+// CHECK-NEXT:    ret <4 x float> [[TMP2]]
+//
+float32x4_t test_vfmaq_m_n_f32(float32x4_t a, float32x4_t b, float32_t c, mve_pred16_t p) {
+#ifdef POLYMORPHIC
+  return vfmaq_m(a, b, c, p);
+#else /* POLYMORPHIC */
+  return vfmaq_m_n_f32(a, b, c, p);
+#endif /* POLYMORPHIC */
+}
+
+// CHECK-LABEL: @test_vfmasq_m_n_f16(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[TMP0:%.*]] = bitcast float [[C_COERCE:%.*]] to i32
+// CHECK-NEXT:    [[TMP_0_EXTRACT_TRUNC:%.*]] = trunc i32 [[TMP0]] to i16
+// CHECK-NEXT:    [[TMP1:%.*]] = bitcast i16 [[TMP_0_EXTRACT_TRUNC]] to half
+// CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <8 x half> undef, half [[TMP1]], i32 0
+// CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <8 x half> [[DOTSPLATINSERT]], <8 x half> undef, <8 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP2:%.*]] = zext i16 [[P:%.*]] to i32
+// CHECK-NEXT:    [[TMP3:%.*]] = call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 [[TMP2]])
+// CHECK-NEXT:    [[TMP4:%.*]] = call <8 x half> @llvm.arm.mve.fma.predicated.v8f16.v8i1(<8 x half> [[A:%.*]], <8 x half> [[B:%.*]], <8 x half> [[DOTSPLAT]], <8 x i1> [[TMP3]])
+// CHECK-NEXT:    ret <8 x half> [[TMP4]]
+//
+float16x8_t test_vfmasq_m_n_f16(float16x8_t a, float16x8_t b, float16_t c, mve_pred16_t p) {
+#ifdef POLYMORPHIC
+  return vfmasq_m(a, b, c, p);
+#else /* POLYMORPHIC */
+  return vfmasq_m_n_f16(a, b, c, p);
+#endif /* POLYMORPHIC */
+}
+
+// CHECK-LABEL: @test_vfmasq_m_n_f32(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x float> undef, float [[C:%.*]], i32 0
+// CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x float> [[DOTSPLATINSERT]], <4 x float> undef, <4 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = zext i16 [[P:%.*]] to i32
+// CHECK-NEXT:    [[TMP1:%.*]] = call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 [[TMP0]])
+// CHECK-NEXT:    [[TMP2:%.*]] = call <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]], <4 x float> [[DOTSPLAT]], <4 x i1> [[TMP1]])
+// CHECK-NEXT:    ret <4 x float> [[TMP2]]
+//
+float32x4_t test_vfmasq_m_n_f32(float32x4_t a, float32x4_t b, float32_t c, mve_pred16_t p) {
+#ifdef POLYMORPHIC
+  return vfmasq_m(a, b, c, p);
+#else /* POLYMORPHIC */
+  return vfmasq_m_n_f32(a, b, c, p);
+#endif /* POLYMORPHIC */
+}
+
+// CHECK-LABEL: @test_vfmsq_m_f16(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[TMP0:%.*]] = fneg <8 x half> [[C:%.*]]
+// CHECK-NEXT:    [[TMP1:%.*]] = zext i16 [[P:%.*]] to i32
+// CHECK-NEXT:    [[TMP2:%.*]] = call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 [[TMP1]])
+// CHECK-NEXT:    [[TMP3:%.*]] = call <8 x half> @llvm.arm.mve.fma.predicated.v8f16.v8i1(<8 x half> [[B:%.*]], <8 x half> [[TMP0]], <8 x half> [[A:%.*]], <8 x i1> [[TMP2]])
+// CHECK-NEXT:    ret <8 x half> [[TMP3]]
+//
+float16x8_t test_vfmsq_m_f16(float16x8_t a, float16x8_t b, float16x8_t c, mve_pred16_t p) {
+#ifdef POLYMORPHIC
+  return vfmsq_m(a, b, c, p);
+#else /* POLYMORPHIC */
+  return vfmsq_m_f16(a, b, c, p);
+#endif /* POLYMORPHIC */
+}
+
+// CHECK-LABEL: @test_vfmsq_m_f32(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[TMP0:%.*]] = fneg <4 x float> [[C:%.*]]
+// CHECK-NEXT:    [[TMP1:%.*]] = zext i16 [[P:%.*]] to i32
+// CHECK-NEXT:    [[TMP2:%.*]] = call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 [[TMP1]])
+// CHECK-NEXT:    [[TMP3:%.*]] = call <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> [[B:%.*]], <4 x float> [[TMP0]], <4 x float> [[A:%.*]], <4 x i1> [[TMP2]])
+// CHECK-NEXT:    ret <4 x float> [[TMP3]]
+//
+float32x4_t test_vfmsq_m_f32(float32x4_t a, float32x4_t b, float32x4_t c, mve_pred16_t p) {
+#ifdef POLYMORPHIC
+  return vfmsq_m(a, b, c, p);
+#else /* POLYMORPHIC */
+  return vfmsq_m_f32(a, b, c, p);
+#endif /* POLYMORPHIC */
+}

diff --git a/llvm/include/llvm/IR/IntrinsicsARM.td b/llvm/include/llvm/IR/IntrinsicsARM.td
index 4968689a4b56..b35e1b7c4473 100644
--- a/llvm/include/llvm/IR/IntrinsicsARM.td
+++ b/llvm/include/llvm/IR/IntrinsicsARM.td
@@ -1243,6 +1243,10 @@ def int_arm_mve_vqmovn_predicated: Intrinsic<[llvm_anyvector_ty],
     llvm_i32_ty /* unsigned output */, llvm_i32_ty /* unsigned input */,
     llvm_i32_ty /* top half */, llvm_anyvector_ty /* pred */], [IntrNoMem]>;
 
+def int_arm_mve_fma_predicated: Intrinsic<[llvm_anyvector_ty],
+   [LLVMMatchType<0> /* mult op #1 */, LLVMMatchType<0> /* mult op #2 */,
+    LLVMMatchType<0> /* addend */, llvm_anyvector_ty /* pred */], [IntrNoMem]>;
+
 // CDE (Custom Datapath Extension)
 
 def int_arm_cde_cx1: Intrinsic<

diff --git a/llvm/lib/Target/ARM/ARMInstrMVE.td b/llvm/lib/Target/ARM/ARMInstrMVE.td
index 7600daae03c4..20856ccdb411 100644
--- a/llvm/lib/Target/ARM/ARMInstrMVE.td
+++ b/llvm/lib/Target/ARM/ARMInstrMVE.td
@@ -3420,27 +3420,37 @@ class MVE_VADDSUBFMA_fp<string iname, string suffix, bit size, bit bit_4,
   let validForTailPredication = 1;
 }
 
-def MVE_VFMAf32 : MVE_VADDSUBFMA_fp<"vfma", "f32", 0b0, 0b1, 0b0, 0b0,
-    (ins MQPR:$Qd_src), vpred_n, "$Qd = $Qd_src">;
-def MVE_VFMAf16 : MVE_VADDSUBFMA_fp<"vfma", "f16", 0b1, 0b1, 0b0, 0b0,
-    (ins MQPR:$Qd_src), vpred_n, "$Qd = $Qd_src">;
-
-def MVE_VFMSf32 : MVE_VADDSUBFMA_fp<"vfms", "f32", 0b0, 0b1, 0b0, 0b1,
-    (ins MQPR:$Qd_src), vpred_n, "$Qd = $Qd_src">;
-def MVE_VFMSf16 : MVE_VADDSUBFMA_fp<"vfms", "f16", 0b1, 0b1, 0b0, 0b1,
-    (ins MQPR:$Qd_src), vpred_n, "$Qd = $Qd_src">;
+multiclass MVE_VFMA_fp_multi<string iname, bit fms, MVEVectorVTInfo VTI> {
+  def "" : MVE_VADDSUBFMA_fp<iname, VTI.Suffix, VTI.Size{0}, 0b1, 0b0, fms,
+                             (ins MQPR:$Qd_src), vpred_n, "$Qd = $Qd_src">;
+  defvar Inst = !cast<Instruction>(NAME);
+  defvar pred_int = int_arm_mve_fma_predicated;
+  defvar m1   = (VTI.Vec MQPR:$m1);
+  defvar m2   = (VTI.Vec MQPR:$m2);
+  defvar add  = (VTI.Vec MQPR:$add);
+  defvar pred = (VTI.Pred VCCR:$pred);
 
-let Predicates = [HasMVEFloat] in {
-  def : Pat<(v8f16 (fma (v8f16 MQPR:$src1), (v8f16 MQPR:$src2), (v8f16 MQPR:$src3))),
-            (v8f16 (MVE_VFMAf16 $src3, $src1, $src2))>;
-  def : Pat<(v4f32 (fma (v4f32 MQPR:$src1), (v4f32 MQPR:$src2), (v4f32 MQPR:$src3))),
-            (v4f32 (MVE_VFMAf32 $src3, $src1, $src2))>;
-  def : Pat<(v8f16 (fma (fneg (v8f16 MQPR:$src1)), (v8f16 MQPR:$src2), (v8f16 MQPR:$src3))),
-            (v8f16 (MVE_VFMSf16 $src3, $src1, $src2))>;
-  def : Pat<(v4f32 (fma (fneg (v4f32 MQPR:$src1)), (v4f32 MQPR:$src2), (v4f32 MQPR:$src3))),
-            (v4f32 (MVE_VFMSf32 $src3, $src1, $src2))>;
+  let Predicates = [HasMVEFloat] in {
+    if fms then {
+      def : Pat<(VTI.Vec (fma (fneg m1), m2, add)), (Inst $add, $m1, $m2)>;
+      def : Pat<(VTI.Vec (fma m1, (fneg m2), add)), (Inst $add, $m1, $m2)>;
+      def : Pat<(VTI.Vec (pred_int (fneg m1), m2, add, pred)),
+                (Inst $add, $m1, $m2, ARMVCCThen, $pred)>;
+      def : Pat<(VTI.Vec (pred_int m1, (fneg m2), add, pred)),
+                (Inst $add, $m1, $m2, ARMVCCThen, $pred)>;
+    } else {
+      def : Pat<(VTI.Vec (fma m1, m2, add)), (Inst $add, $m1, $m2)>;
+      def : Pat<(VTI.Vec (pred_int m1, m2, add, pred)),
+                (Inst $add, $m1, $m2, ARMVCCThen, $pred)>;
+    }
+  }
 }
 
+defm MVE_VFMAf32 : MVE_VFMA_fp_multi<"vfma", 0, MVE_v4f32>;
+defm MVE_VFMAf16 : MVE_VFMA_fp_multi<"vfma", 0, MVE_v8f16>;
+defm MVE_VFMSf32 : MVE_VFMA_fp_multi<"vfms", 1, MVE_v4f32>;
+defm MVE_VFMSf16 : MVE_VFMA_fp_multi<"vfms", 1, MVE_v8f16>;
+
 multiclass MVE_VADDSUB_fp_m<string iname, bit bit_21, MVEVectorVTInfo VTI,
                             SDNode unpred_op, Intrinsic pred_int> {
   def "" : MVE_VADDSUBFMA_fp<iname, VTI.Suffix, VTI.Size{0}, 0, 1, bit_21> {
@@ -5184,11 +5194,39 @@ let Predicates = [HasMVEInt] in {
             (v16i8 (MVE_VMLAS_qr_u8 $src1, $src2, $x))>;
 }
 
+multiclass MVE_VFMA_qr_multi<string iname, MVEVectorVTInfo VTI,
+                             bit scalar_addend> {
+  def "": MVE_VFMAMLA_qr<iname, VTI.Suffix, VTI.Size{0}, 0b11, scalar_addend>;
+  defvar Inst = !cast<Instruction>(NAME);
+  defvar pred_int = int_arm_mve_fma_predicated;
+  defvar v1   = (VTI.Vec MQPR:$v1);
+  defvar v2   = (VTI.Vec MQPR:$v2);
+  defvar s    = !if(VTI.Size{0}, (f16 HPR:$s), (f32 SPR:$s));
+  defvar vs   = (VTI.Vec (ARMvdup s));
+  defvar is   = (i32 (COPY_TO_REGCLASS s, rGPR));
+  defvar pred = (VTI.Pred VCCR:$pred);
+
+  let Predicates = [HasMVEFloat] in {
+    if scalar_addend then {
+      def : Pat<(VTI.Vec (fma v1, v2, vs)), (VTI.Vec (Inst v1, v2, is))>;
+      def : Pat<(VTI.Vec (pred_int v1, v2, vs, pred)),
+                (VTI.Vec (Inst v1, v2, is, ARMVCCThen, pred))>;
+    } else {
+      def : Pat<(VTI.Vec (fma v1, vs, v2)), (VTI.Vec (Inst v2, v1, is))>;
+      def : Pat<(VTI.Vec (fma vs, v1, v2)), (VTI.Vec (Inst v2, v1, is))>;
+      def : Pat<(VTI.Vec (pred_int v1, vs, v2, pred)),
+                (VTI.Vec (Inst v2, v1, is, ARMVCCThen, pred))>;
+      def : Pat<(VTI.Vec (pred_int vs, v1, v2, pred)),
+                (VTI.Vec (Inst v2, v1, is, ARMVCCThen, pred))>;
+    }
+  }
+}
+
 let Predicates = [HasMVEFloat] in {
-  def MVE_VFMA_qr_f16  : MVE_VFMAMLA_qr<"vfma",  "f16", 0b1, 0b11, 0b0>;
-  def MVE_VFMA_qr_f32  : MVE_VFMAMLA_qr<"vfma",  "f32", 0b0, 0b11, 0b0>;
-  def MVE_VFMA_qr_Sf16 : MVE_VFMAMLA_qr<"vfmas", "f16", 0b1, 0b11, 0b1>;
-  def MVE_VFMA_qr_Sf32 : MVE_VFMAMLA_qr<"vfmas", "f32", 0b0, 0b11, 0b1>;
+  defm MVE_VFMA_qr_f16  : MVE_VFMA_qr_multi<"vfma",  MVE_v8f16, 0>;
+  defm MVE_VFMA_qr_f32  : MVE_VFMA_qr_multi<"vfma",  MVE_v4f32, 0>;
+  defm MVE_VFMA_qr_Sf16 : MVE_VFMA_qr_multi<"vfmas", MVE_v8f16, 1>;
+  defm MVE_VFMA_qr_Sf32 : MVE_VFMA_qr_multi<"vfmas", MVE_v4f32, 1>;
 }
 
 class MVE_VQDMLAH_qr<string iname, string suffix, bit U, bits<2> size,

diff --git a/llvm/test/CodeGen/Thumb2/mve-fmas.ll b/llvm/test/CodeGen/Thumb2/mve-fmas.ll
index 76a00545063e..a65b663b3311 100644
--- a/llvm/test/CodeGen/Thumb2/mve-fmas.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-fmas.ll
@@ -208,8 +208,7 @@ define arm_aapcs_vfpcc <8 x half> @vfmar16(<8 x half> %src1, <8 x half> %src2, f
 ; CHECK-MVE-VMLA:       @ %bb.0: @ %entry
 ; CHECK-MVE-VMLA-NEXT:    vcvtb.f16.f32 s8, s8
 ; CHECK-MVE-VMLA-NEXT:    vmov r0, s8
-; CHECK-MVE-VMLA-NEXT:    vdup.16 q2, r0
-; CHECK-MVE-VMLA-NEXT:    vfma.f16 q0, q1, q2
+; CHECK-MVE-VMLA-NEXT:    vfma.f16 q0, q1, r0
 ; CHECK-MVE-VMLA-NEXT:    bx lr
 ;
 ; CHECK-MVE-LABEL: vfmar16:
@@ -275,9 +274,7 @@ define arm_aapcs_vfpcc <8 x half> @vfma16(<8 x half> %src1, <8 x half> %src2, fl
 ; CHECK-MVE-VMLA:       @ %bb.0: @ %entry
 ; CHECK-MVE-VMLA-NEXT:    vcvtb.f16.f32 s8, s8
 ; CHECK-MVE-VMLA-NEXT:    vmov r0, s8
-; CHECK-MVE-VMLA-NEXT:    vdup.16 q2, r0
-; CHECK-MVE-VMLA-NEXT:    vfma.f16 q2, q0, q1
-; CHECK-MVE-VMLA-NEXT:    vmov q0, q2
+; CHECK-MVE-VMLA-NEXT:    vfmas.f16 q0, q1, r0
 ; CHECK-MVE-VMLA-NEXT:    bx lr
 ;
 ; CHECK-MVE-LABEL: vfma16:
@@ -419,8 +416,7 @@ define arm_aapcs_vfpcc <4 x float> @vfmar32(<4 x float> %src1, <4 x float> %src2
 ; CHECK-MVE-VMLA-LABEL: vfmar32:
 ; CHECK-MVE-VMLA:       @ %bb.0: @ %entry
 ; CHECK-MVE-VMLA-NEXT:    vmov r0, s8
-; CHECK-MVE-VMLA-NEXT:    vdup.32 q2, r0
-; CHECK-MVE-VMLA-NEXT:    vfma.f32 q0, q1, q2
+; CHECK-MVE-VMLA-NEXT:    vfma.f32 q0, q1, r0
 ; CHECK-MVE-VMLA-NEXT:    bx lr
 ;
 ; CHECK-MVE-LABEL: vfmar32:
@@ -449,9 +445,7 @@ define arm_aapcs_vfpcc <4 x float> @vfmas32(<4 x float> %src1, <4 x float> %src2
 ; CHECK-MVE-VMLA-LABEL: vfmas32:
 ; CHECK-MVE-VMLA:       @ %bb.0: @ %entry
 ; CHECK-MVE-VMLA-NEXT:    vmov r0, s8
-; CHECK-MVE-VMLA-NEXT:    vdup.32 q2, r0
-; CHECK-MVE-VMLA-NEXT:    vfma.f32 q2, q0, q1
-; CHECK-MVE-VMLA-NEXT:    vmov q0, q2
+; CHECK-MVE-VMLA-NEXT:    vfmas.f32 q0, q1, r0
 ; CHECK-MVE-VMLA-NEXT:    bx lr
 ;
 ; CHECK-MVE-LABEL: vfmas32:

diff --git a/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll b/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll
new file mode 100644
index 000000000000..fcdb29f7e2f2
--- /dev/null
+++ b/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll
@@ -0,0 +1,242 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=thumbv8.1m.main -mattr=+mve.fp -verify-machineinstrs -o - %s | FileCheck %s
+
+define arm_aapcs_vfpcc <8 x half> @test_vfmaq_f16(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
+; CHECK-LABEL: test_vfmaq_f16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vfma.f16 q0, q1, q2
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = tail call <8 x half> @llvm.fma.v8f16(<8 x half> %b, <8 x half> %c, <8 x half> %a)
+  ret <8 x half> %0
+}
+
+define arm_aapcs_vfpcc <4 x float> @test_vfmaq_f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
+; CHECK-LABEL: test_vfmaq_f32:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vfma.f32 q0, q1, q2
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %b, <4 x float> %c, <4 x float> %a)
+  ret <4 x float> %0
+}
+
+define arm_aapcs_vfpcc <8 x half> @test_vfmaq_n_f16(<8 x half> %a, <8 x half> %b, float %c.coerce) {
+; CHECK-LABEL: test_vfmaq_n_f16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmov r0, s8
+; CHECK-NEXT:    vfma.f16 q0, q1, r0
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = bitcast float %c.coerce to i32
+  %tmp.0.extract.trunc = trunc i32 %0 to i16
+  %1 = bitcast i16 %tmp.0.extract.trunc to half
+  %.splatinsert = insertelement <8 x half> undef, half %1, i32 0
+  %.splat = shufflevector <8 x half> %.splatinsert, <8 x half> undef, <8 x i32> zeroinitializer
+  %2 = tail call <8 x half> @llvm.fma.v8f16(<8 x half> %b, <8 x half> %.splat, <8 x half> %a)
+  ret <8 x half> %2
+}
+
+define arm_aapcs_vfpcc <4 x float> @test_vfmaq_n_f32(<4 x float> %a, <4 x float> %b, float %c) {
+; CHECK-LABEL: test_vfmaq_n_f32:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmov r0, s8
+; CHECK-NEXT:    vfma.f32 q0, q1, r0
+; CHECK-NEXT:    bx lr
+entry:
+  %.splatinsert = insertelement <4 x float> undef, float %c, i32 0
+  %.splat = shufflevector <4 x float> %.splatinsert, <4 x float> undef, <4 x i32> zeroinitializer
+  %0 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %b, <4 x float> %.splat, <4 x float> %a)
+  ret <4 x float> %0
+}
+
+define arm_aapcs_vfpcc <8 x half> @test_vfmasq_n_f16(<8 x half> %a, <8 x half> %b, float %c.coerce) {
+; CHECK-LABEL: test_vfmasq_n_f16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmov r0, s8
+; CHECK-NEXT:    vfmas.f16 q0, q1, r0
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = bitcast float %c.coerce to i32
+  %tmp.0.extract.trunc = trunc i32 %0 to i16
+  %1 = bitcast i16 %tmp.0.extract.trunc to half
+  %.splatinsert = insertelement <8 x half> undef, half %1, i32 0
+  %.splat = shufflevector <8 x half> %.splatinsert, <8 x half> undef, <8 x i32> zeroinitializer
+  %2 = tail call <8 x half> @llvm.fma.v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %.splat)
+  ret <8 x half> %2
+}
+
+define arm_aapcs_vfpcc <4 x float> @test_vfmasq_n_f32(<4 x float> %a, <4 x float> %b, float %c) {
+; CHECK-LABEL: test_vfmasq_n_f32:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmov r0, s8
+; CHECK-NEXT:    vfmas.f32 q0, q1, r0
+; CHECK-NEXT:    bx lr
+entry:
+  %.splatinsert = insertelement <4 x float> undef, float %c, i32 0
+  %.splat = shufflevector <4 x float> %.splatinsert, <4 x float> undef, <4 x i32> zeroinitializer
+  %0 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %.splat)
+  ret <4 x float> %0
+}
+
+define arm_aapcs_vfpcc <8 x half> @test_vfmsq_f16(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
+; CHECK-LABEL: test_vfmsq_f16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vfms.f16 q0, q1, q2
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = fneg <8 x half> %c
+  %1 = tail call <8 x half> @llvm.fma.v8f16(<8 x half> %b, <8 x half> %0, <8 x half> %a)
+  ret <8 x half> %1
+}
+
+define arm_aapcs_vfpcc <4 x float> @test_vfmsq_f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
+; CHECK-LABEL: test_vfmsq_f32:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vfms.f32 q0, q1, q2
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = fneg <4 x float> %c
+  %1 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %b, <4 x float> %0, <4 x float> %a)
+  ret <4 x float> %1
+}
+
+define arm_aapcs_vfpcc <8 x half> @test_vfmaq_m_f16(<8 x half> %a, <8 x half> %b, <8 x half> %c, i16 zeroext %p) {
+; CHECK-LABEL: test_vfmaq_m_f16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmsr p0, r0
+; CHECK-NEXT:    vpst
+; CHECK-NEXT:    vfmat.f16 q0, q1, q2
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = zext i16 %p to i32
+  %1 = tail call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 %0)
+  %2 = tail call <8 x half> @llvm.arm.mve.fma.predicated.v8f16.v8i1(<8 x half> %b, <8 x half> %c, <8 x half> %a, <8 x i1> %1)
+  ret <8 x half> %2
+}
+
+define arm_aapcs_vfpcc <4 x float> @test_vfmaq_m_f32(<4 x float> %a, <4 x float> %b, <4 x float> %c, i16 zeroext %p) {
+; CHECK-LABEL: test_vfmaq_m_f32:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmsr p0, r0
+; CHECK-NEXT:    vpst
+; CHECK-NEXT:    vfmat.f32 q0, q1, q2
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = zext i16 %p to i32
+  %1 = tail call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 %0)
+  %2 = tail call <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> %b, <4 x float> %c, <4 x float> %a, <4 x i1> %1)
+  ret <4 x float> %2
+}
+
+define arm_aapcs_vfpcc <8 x half> @test_vfmaq_m_n_f16(<8 x half> %a, <8 x half> %b, float %c.coerce, i16 zeroext %p) {
+; CHECK-LABEL: test_vfmaq_m_n_f16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmov r1, s8
+; CHECK-NEXT:    vmsr p0, r0
+; CHECK-NEXT:    vpst
+; CHECK-NEXT:    vfmat.f16 q0, q1, r1
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = bitcast float %c.coerce to i32
+  %tmp.0.extract.trunc = trunc i32 %0 to i16
+  %1 = bitcast i16 %tmp.0.extract.trunc to half
+  %.splatinsert = insertelement <8 x half> undef, half %1, i32 0
+  %.splat = shufflevector <8 x half> %.splatinsert, <8 x half> undef, <8 x i32> zeroinitializer
+  %2 = zext i16 %p to i32
+  %3 = tail call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 %2)
+  %4 = tail call <8 x half> @llvm.arm.mve.fma.predicated.v8f16.v8i1(<8 x half> %b, <8 x half> %.splat, <8 x half> %a, <8 x i1> %3)
+  ret <8 x half> %4
+}
+
+define arm_aapcs_vfpcc <4 x float> @test_vfmaq_m_n_f32(<4 x float> %a, <4 x float> %b, float %c, i16 zeroext %p) {
+; CHECK-LABEL: test_vfmaq_m_n_f32:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmsr p0, r0
+; CHECK-NEXT:    vmov r0, s8
+; CHECK-NEXT:    vpst
+; CHECK-NEXT:    vfmat.f32 q0, q1, r0
+; CHECK-NEXT:    bx lr
+entry:
+  %.splatinsert = insertelement <4 x float> undef, float %c, i32 0
+  %.splat = shufflevector <4 x float> %.splatinsert, <4 x float> undef, <4 x i32> zeroinitializer
+  %0 = zext i16 %p to i32
+  %1 = tail call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 %0)
+  %2 = tail call <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> %b, <4 x float> %.splat, <4 x float> %a, <4 x i1> %1)
+  ret <4 x float> %2
+}
+
+define arm_aapcs_vfpcc <8 x half> @test_vfmasq_m_n_f16(<8 x half> %a, <8 x half> %b, float %c.coerce, i16 zeroext %p) {
+; CHECK-LABEL: test_vfmasq_m_n_f16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmov r1, s8
+; CHECK-NEXT:    vmsr p0, r0
+; CHECK-NEXT:    vpst
+; CHECK-NEXT:    vfmast.f16 q0, q1, r1
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = bitcast float %c.coerce to i32
+  %tmp.0.extract.trunc = trunc i32 %0 to i16
+  %1 = bitcast i16 %tmp.0.extract.trunc to half
+  %.splatinsert = insertelement <8 x half> undef, half %1, i32 0
+  %.splat = shufflevector <8 x half> %.splatinsert, <8 x half> undef, <8 x i32> zeroinitializer
+  %2 = zext i16 %p to i32
+  %3 = tail call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 %2)
+  %4 = tail call <8 x half> @llvm.arm.mve.fma.predicated.v8f16.v8i1(<8 x half> %a, <8 x half> %b, <8 x half> %.splat, <8 x i1> %3)
+  ret <8 x half> %4
+}
+
+define arm_aapcs_vfpcc <4 x float> @test_vfmasq_m_n_f32(<4 x float> %a, <4 x float> %b, float %c, i16 zeroext %p) {
+; CHECK-LABEL: test_vfmasq_m_n_f32:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmsr p0, r0
+; CHECK-NEXT:    vmov r0, s8
+; CHECK-NEXT:    vpst
+; CHECK-NEXT:    vfmast.f32 q0, q1, r0
+; CHECK-NEXT:    bx lr
+entry:
+  %.splatinsert = insertelement <4 x float> undef, float %c, i32 0
+  %.splat = shufflevector <4 x float> %.splatinsert, <4 x float> undef, <4 x i32> zeroinitializer
+  %0 = zext i16 %p to i32
+  %1 = tail call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 %0)
+  %2 = tail call <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> %a, <4 x float> %b, <4 x float> %.splat, <4 x i1> %1)
+  ret <4 x float> %2
+}
+
+define arm_aapcs_vfpcc <8 x half> @test_vfmsq_m_f16(<8 x half> %a, <8 x half> %b, <8 x half> %c, i16 zeroext %p) {
+; CHECK-LABEL: test_vfmsq_m_f16:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmsr p0, r0
+; CHECK-NEXT:    vpst
+; CHECK-NEXT:    vfmst.f16 q0, q1, q2
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = fneg <8 x half> %c
+  %1 = zext i16 %p to i32
+  %2 = tail call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 %1)
+  %3 = tail call <8 x half> @llvm.arm.mve.fma.predicated.v8f16.v8i1(<8 x half> %b, <8 x half> %0, <8 x half> %a, <8 x i1> %2)
+  ret <8 x half> %3
+}
+
+define arm_aapcs_vfpcc <4 x float> @test_vfmsq_m_f32(<4 x float> %a, <4 x float> %b, <4 x float> %c, i16 zeroext %p) {
+; CHECK-LABEL: test_vfmsq_m_f32:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmsr p0, r0
+; CHECK-NEXT:    vpst
+; CHECK-NEXT:    vfmst.f32 q0, q1, q2
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = fneg <4 x float> %c
+  %1 = zext i16 %p to i32
+  %2 = tail call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 %1)
+  %3 = tail call <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> %b, <4 x float> %0, <4 x float> %a, <4 x i1> %2)
+  ret <4 x float> %3
+}
+
+declare <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32)
+declare <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32)
+
+declare <8 x half> @llvm.fma.v8f16(<8 x half>, <8 x half>, <8 x half>)
+declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>)
+declare <8 x half> @llvm.arm.mve.fma.predicated.v8f16.v8i1(<8 x half>, <8 x half>, <8 x half>, <8 x i1>)
+declare <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float>, <4 x float>, <4 x float>, <4 x i1>)


        

