[clang] [llvm] [ARM] Fix operand order for MVE predicated VFMAS (PR #115908)
Oliver Stannard via cfe-commits
cfe-commits at lists.llvm.org
Tue Nov 12 09:20:34 PST 2024
https://github.com/ostannard created https://github.com/llvm/llvm-project/pull/115908
For most MVE predicated FMA instructions, disabled lanes will contain the value in the addend operand. However, the VFMAS instruction takes the addend in a GPR, and the output register is shared with the first multiply operand, so disabled lanes will get that value instead. This means that we can't use the same intrinsic as for the other VFMA instructions. Instead, clang can lower the vfmas intrinsic to a regular FMA followed by a select, a pattern from which the backend can already select VFMAS.
From b12203bd46e8e3d0372fe41287fa429b5877b314 Mon Sep 17 00:00:00 2001
From: Oliver Stannard <oliver.stannard at arm.com>
Date: Tue, 12 Nov 2024 15:50:11 +0000
Subject: [PATCH] [ARM] Fix operand order for MVE predicated VFMAS
For most MVE predicated FMA instructions, disabled lanes will contain
the value in the addend operand. However, the VFMAS instruction takes
the addend in a GPR, and the output register is shared with the first
multiply operand, so disabled lanes will get that value instead. This
means that we can't use the same intrinsic as for the other VFMA
instructions. Instead, we can codegen the vfmas intrinsic to a regular
FMA and select in clang, which the backend already has the patterns to
select VFMAS from.
---
clang/include/clang/Basic/arm_mve.td | 2 +-
.../test/CodeGen/arm-mve-intrinsics/ternary.c | 18 +++----
llvm/include/llvm/IR/IntrinsicsARM.td | 1 +
llvm/lib/Target/ARM/ARMInstrMVE.td | 2 -
.../CodeGen/Thumb2/mve-intrinsics/ternary.ll | 47 ++++++++++++++++++-
llvm/test/CodeGen/Thumb2/mve-qrintr.ll | 3 +-
6 files changed, 59 insertions(+), 14 deletions(-)
diff --git a/clang/include/clang/Basic/arm_mve.td b/clang/include/clang/Basic/arm_mve.td
index 52185ca07da41f..1debb94a0a7b81 100644
--- a/clang/include/clang/Basic/arm_mve.td
+++ b/clang/include/clang/Basic/arm_mve.td
@@ -193,7 +193,7 @@ multiclass FMA<bit add> {
def sq_m_n: Intrinsic<Vector, (args Vector:$m1, Vector:$m2,
unpromoted<Scalar>:$addend_s,
Predicate:$pred),
- (seq (splat $addend_s):$addend, pred_cg)>;
+ (select $pred, (seq (splat $addend_s):$addend, unpred_cg), $m1)>;
}
}
diff --git a/clang/test/CodeGen/arm-mve-intrinsics/ternary.c b/clang/test/CodeGen/arm-mve-intrinsics/ternary.c
index 36b2ce063cb188..768d397cb5611f 100644
--- a/clang/test/CodeGen/arm-mve-intrinsics/ternary.c
+++ b/clang/test/CodeGen/arm-mve-intrinsics/ternary.c
@@ -542,12 +542,13 @@ float32x4_t test_vfmaq_m_n_f32(float32x4_t a, float32x4_t b, float32_t c, mve_pr
// CHECK-LABEL: @test_vfmasq_m_n_f16(
// CHECK-NEXT: entry:
-// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <8 x half> poison, half [[C:%.*]], i64 0
-// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <8 x half> [[DOTSPLATINSERT]], <8 x half> poison, <8 x i32> zeroinitializer
// CHECK-NEXT: [[TMP0:%.*]] = zext i16 [[P:%.*]] to i32
// CHECK-NEXT: [[TMP1:%.*]] = call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 [[TMP0]])
-// CHECK-NEXT: [[TMP2:%.*]] = call <8 x half> @llvm.arm.mve.fma.predicated.v8f16.v8i1(<8 x half> [[A:%.*]], <8 x half> [[B:%.*]], <8 x half> [[DOTSPLAT]], <8 x i1> [[TMP1]])
-// CHECK-NEXT: ret <8 x half> [[TMP2]]
+// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <8 x half> poison, half [[C:%.*]], i64 0
+// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <8 x half> [[DOTSPLATINSERT]], <8 x half> poison, <8 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP2:%.*]] = call <8 x half> @llvm.fma.v8f16(<8 x half> [[A:%.*]], <8 x half> [[B:%.*]], <8 x half> [[DOTSPLAT]])
+// CHECK-NEXT: [[TMP3:%.*]] = select <8 x i1> [[TMP1]], <8 x half> [[TMP2]], <8 x half> [[A]]
+// CHECK-NEXT: ret <8 x half> [[TMP3]]
//
float16x8_t test_vfmasq_m_n_f16(float16x8_t a, float16x8_t b, float16_t c, mve_pred16_t p) {
#ifdef POLYMORPHIC
@@ -559,12 +560,13 @@ float16x8_t test_vfmasq_m_n_f16(float16x8_t a, float16x8_t b, float16_t c, mve_p
// CHECK-LABEL: @test_vfmasq_m_n_f32(
// CHECK-NEXT: entry:
-// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <4 x float> poison, float [[C:%.*]], i64 0
-// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <4 x float> [[DOTSPLATINSERT]], <4 x float> poison, <4 x i32> zeroinitializer
// CHECK-NEXT: [[TMP0:%.*]] = zext i16 [[P:%.*]] to i32
// CHECK-NEXT: [[TMP1:%.*]] = call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 [[TMP0]])
-// CHECK-NEXT: [[TMP2:%.*]] = call <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]], <4 x float> [[DOTSPLAT]], <4 x i1> [[TMP1]])
-// CHECK-NEXT: ret <4 x float> [[TMP2]]
+// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <4 x float> poison, float [[C:%.*]], i64 0
+// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <4 x float> [[DOTSPLATINSERT]], <4 x float> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP2:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]], <4 x float> [[DOTSPLAT]])
+// CHECK-NEXT: [[TMP3:%.*]] = select <4 x i1> [[TMP1]], <4 x float> [[TMP2]], <4 x float> [[A]]
+// CHECK-NEXT: ret <4 x float> [[TMP3]]
//
float32x4_t test_vfmasq_m_n_f32(float32x4_t a, float32x4_t b, float32_t c, mve_pred16_t p) {
#ifdef POLYMORPHIC
diff --git a/llvm/include/llvm/IR/IntrinsicsARM.td b/llvm/include/llvm/IR/IntrinsicsARM.td
index 11b9877091a8ed..b18d3fcc9e3f44 100644
--- a/llvm/include/llvm/IR/IntrinsicsARM.td
+++ b/llvm/include/llvm/IR/IntrinsicsARM.td
@@ -1362,6 +1362,7 @@ def int_arm_mve_vqmovn_predicated: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
llvm_i32_ty /* unsigned output */, llvm_i32_ty /* unsigned input */,
llvm_i32_ty /* top half */, llvm_anyvector_ty /* pred */], [IntrNoMem]>;
+// fma_predicated returns the add operand for disabled lanes.
def int_arm_mve_fma_predicated: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
[LLVMMatchType<0> /* mult op #1 */, LLVMMatchType<0> /* mult op #2 */,
LLVMMatchType<0> /* addend */, llvm_anyvector_ty /* pred */], [IntrNoMem]>;
diff --git a/llvm/lib/Target/ARM/ARMInstrMVE.td b/llvm/lib/Target/ARM/ARMInstrMVE.td
index 22af599f4f0859..bdd0d739a05684 100644
--- a/llvm/lib/Target/ARM/ARMInstrMVE.td
+++ b/llvm/lib/Target/ARM/ARMInstrMVE.td
@@ -5614,8 +5614,6 @@ multiclass MVE_VFMA_qr_multi<string iname, MVEVectorVTInfo VTI,
(VTI.Vec (fma v1, v2, vs)),
v1)),
(VTI.Vec (Inst v1, v2, is, ARMVCCThen, $pred, zero_reg))>;
- def : Pat<(VTI.Vec (pred_int v1, v2, vs, pred)),
- (VTI.Vec (Inst v1, v2, is, ARMVCCThen, pred, zero_reg))>;
} else {
def : Pat<(VTI.Vec (fma v1, vs, v2)),
(VTI.Vec (Inst v2, v1, is))>;
diff --git a/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll b/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll
index 4b1b8998ce73eb..265452c18f81a4 100644
--- a/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll
@@ -461,8 +461,10 @@ define arm_aapcs_vfpcc <8 x half> @test_vfmasq_m_n_f16(<8 x half> %a, <8 x half>
; CHECK: @ %bb.0: @ %entry
; CHECK-NEXT: vmov r1, s8
; CHECK-NEXT: vmsr p0, r0
+; CHECK-NEXT: vdup.16 q2, r1
; CHECK-NEXT: vpst
-; CHECK-NEXT: vfmast.f16 q0, q1, r1
+; CHECK-NEXT: vfmat.f16 q2, q0, q1
+; CHECK-NEXT: vmov q0, q2
; CHECK-NEXT: bx lr
entry:
%0 = bitcast float %c.coerce to i32
@@ -476,13 +478,36 @@ entry:
ret <8 x half> %4
}
+define arm_aapcs_vfpcc <8 x half> @test_vfmasq_m_n_f16_select(<8 x half> %a, <8 x half> %b, float %c.coerce, i16 zeroext %p) {
+; CHECK-LABEL: test_vfmasq_m_n_f16_select:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vmov r1, s8
+; CHECK-NEXT: vmsr p0, r0
+; CHECK-NEXT: vpst
+; CHECK-NEXT: vfmast.f16 q0, q1, r1
+; CHECK-NEXT: bx lr
+entry:
+ %0 = bitcast float %c.coerce to i32
+ %tmp.0.extract.trunc = trunc i32 %0 to i16
+ %1 = bitcast i16 %tmp.0.extract.trunc to half
+ %.splatinsert = insertelement <8 x half> undef, half %1, i32 0
+ %.splat = shufflevector <8 x half> %.splatinsert, <8 x half> undef, <8 x i32> zeroinitializer
+ %2 = zext i16 %p to i32
+ %3 = tail call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 %2)
+ %4 = tail call <8 x half> @llvm.fma.v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %.splat)
+ %5 = select <8 x i1> %3, <8 x half> %4, <8 x half> %a
+ ret <8 x half> %5
+}
+
define arm_aapcs_vfpcc <4 x float> @test_vfmasq_m_n_f32(<4 x float> %a, <4 x float> %b, float %c, i16 zeroext %p) {
; CHECK-LABEL: test_vfmasq_m_n_f32:
; CHECK: @ %bb.0: @ %entry
; CHECK-NEXT: vmov r1, s8
; CHECK-NEXT: vmsr p0, r0
+; CHECK-NEXT: vdup.32 q2, r1
; CHECK-NEXT: vpst
-; CHECK-NEXT: vfmast.f32 q0, q1, r1
+; CHECK-NEXT: vfmat.f32 q2, q0, q1
+; CHECK-NEXT: vmov q0, q2
; CHECK-NEXT: bx lr
entry:
%.splatinsert = insertelement <4 x float> undef, float %c, i32 0
@@ -493,6 +518,24 @@ entry:
ret <4 x float> %2
}
+define arm_aapcs_vfpcc <4 x float> @test_vfmasq_m_n_f32_select(<4 x float> %a, <4 x float> %b, float %c, i16 zeroext %p) {
+; CHECK-LABEL: test_vfmasq_m_n_f32_select:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vmov r1, s8
+; CHECK-NEXT: vmsr p0, r0
+; CHECK-NEXT: vpst
+; CHECK-NEXT: vfmast.f32 q0, q1, r1
+; CHECK-NEXT: bx lr
+entry:
+ %.splatinsert = insertelement <4 x float> undef, float %c, i32 0
+ %.splat = shufflevector <4 x float> %.splatinsert, <4 x float> undef, <4 x i32> zeroinitializer
+ %0 = zext i16 %p to i32
+ %1 = tail call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 %0)
+ %2 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %.splat)
+ %3 = select <4 x i1> %1, <4 x float> %2, <4 x float> %a
+ ret <4 x float> %3
+}
+
define arm_aapcs_vfpcc <8 x half> @test_vfmsq_m_f16(<8 x half> %a, <8 x half> %b, <8 x half> %c, i16 zeroext %p) {
; CHECK-LABEL: test_vfmsq_m_f16:
; CHECK: @ %bb.0: @ %entry
diff --git a/llvm/test/CodeGen/Thumb2/mve-qrintr.ll b/llvm/test/CodeGen/Thumb2/mve-qrintr.ll
index 151e51fcf0c93d..87cce804356dee 100644
--- a/llvm/test/CodeGen/Thumb2/mve-qrintr.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-qrintr.ll
@@ -1536,7 +1536,8 @@ while.body: ; preds = %while.body.lr.ph, %
%0 = tail call <4 x i1> @llvm.arm.mve.vctp32(i32 %N.addr.013)
%1 = tail call fast <4 x float> @llvm.masked.load.v4f32.p0(ptr %s1.addr.014, i32 4, <4 x i1> %0, <4 x float> zeroinitializer)
%2 = tail call fast <4 x float> @llvm.masked.load.v4f32.p0(ptr %s2, i32 4, <4 x i1> %0, <4 x float> zeroinitializer)
- %3 = tail call fast <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> %1, <4 x float> %2, <4 x float> %.splat, <4 x i1> %0)
+ %3 = tail call fast <4 x float> @llvm.fma.v4f32(<4 x float> %1, <4 x float> %2, <4 x float> %.splat)
+ %4 = select <4 x i1> %0, <4 x float> %3, <4 x float> %1
tail call void @llvm.masked.store.v4f32.p0(<4 x float> %3, ptr %s1.addr.014, i32 4, <4 x i1> %0)
%add.ptr = getelementptr inbounds float, ptr %s1.addr.014, i32 4
%sub = add nsw i32 %N.addr.013, -4
More information about the cfe-commits
mailing list