[clang] aba5580 - [ARM] Fix operand order for MVE predicated VFMAS (#115908)
via cfe-commits
cfe-commits at lists.llvm.org
Wed Nov 13 04:16:33 PST 2024
Author: Oliver Stannard
Date: 2024-11-13T12:16:28Z
New Revision: aba55809e9af5e0d981f10c7f9b44a1f57b423c2
URL: https://github.com/llvm/llvm-project/commit/aba55809e9af5e0d981f10c7f9b44a1f57b423c2
DIFF: https://github.com/llvm/llvm-project/commit/aba55809e9af5e0d981f10c7f9b44a1f57b423c2.diff
LOG: [ARM] Fix operand order for MVE predicated VFMAS (#115908)
For most MVE predicated FMA instructions, disabled lanes will contain
the value in the addend operand. However, the VFMAS instruction takes
the addend in a GPR, and the output register is shared with the first
multiply operand, so disabled lanes will get that value instead. This
means that we can't use the same intrinsic as for the other VFMA
instructions. Instead, clang can lower the vfmas intrinsic to a regular
FMA followed by a select, a combination for which the backend already has
patterns to select VFMAS.
Added:
Modified:
clang/include/clang/Basic/arm_mve.td
clang/test/CodeGen/arm-mve-intrinsics/ternary.c
llvm/include/llvm/IR/IntrinsicsARM.td
llvm/lib/Target/ARM/ARMInstrMVE.td
llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll
llvm/test/CodeGen/Thumb2/mve-qrintr.ll
Removed:
################################################################################
diff --git a/clang/include/clang/Basic/arm_mve.td b/clang/include/clang/Basic/arm_mve.td
index 52185ca07da41f..1debb94a0a7b81 100644
--- a/clang/include/clang/Basic/arm_mve.td
+++ b/clang/include/clang/Basic/arm_mve.td
@@ -193,7 +193,7 @@ multiclass FMA<bit add> {
def sq_m_n: Intrinsic<Vector, (args Vector:$m1, Vector:$m2,
unpromoted<Scalar>:$addend_s,
Predicate:$pred),
- (seq (splat $addend_s):$addend, pred_cg)>;
+ (select $pred, (seq (splat $addend_s):$addend, unpred_cg), $m1)>;
}
}
diff --git a/clang/test/CodeGen/arm-mve-intrinsics/ternary.c b/clang/test/CodeGen/arm-mve-intrinsics/ternary.c
index 36b2ce063cb188..768d397cb5611f 100644
--- a/clang/test/CodeGen/arm-mve-intrinsics/ternary.c
+++ b/clang/test/CodeGen/arm-mve-intrinsics/ternary.c
@@ -542,12 +542,13 @@ float32x4_t test_vfmaq_m_n_f32(float32x4_t a, float32x4_t b, float32_t c, mve_pr
// CHECK-LABEL: @test_vfmasq_m_n_f16(
// CHECK-NEXT: entry:
-// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <8 x half> poison, half [[C:%.*]], i64 0
-// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <8 x half> [[DOTSPLATINSERT]], <8 x half> poison, <8 x i32> zeroinitializer
// CHECK-NEXT: [[TMP0:%.*]] = zext i16 [[P:%.*]] to i32
// CHECK-NEXT: [[TMP1:%.*]] = call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 [[TMP0]])
-// CHECK-NEXT: [[TMP2:%.*]] = call <8 x half> @llvm.arm.mve.fma.predicated.v8f16.v8i1(<8 x half> [[A:%.*]], <8 x half> [[B:%.*]], <8 x half> [[DOTSPLAT]], <8 x i1> [[TMP1]])
-// CHECK-NEXT: ret <8 x half> [[TMP2]]
+// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <8 x half> poison, half [[C:%.*]], i64 0
+// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <8 x half> [[DOTSPLATINSERT]], <8 x half> poison, <8 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP2:%.*]] = call <8 x half> @llvm.fma.v8f16(<8 x half> [[A:%.*]], <8 x half> [[B:%.*]], <8 x half> [[DOTSPLAT]])
+// CHECK-NEXT: [[TMP3:%.*]] = select <8 x i1> [[TMP1]], <8 x half> [[TMP2]], <8 x half> [[A]]
+// CHECK-NEXT: ret <8 x half> [[TMP3]]
//
float16x8_t test_vfmasq_m_n_f16(float16x8_t a, float16x8_t b, float16_t c, mve_pred16_t p) {
#ifdef POLYMORPHIC
@@ -559,12 +560,13 @@ float16x8_t test_vfmasq_m_n_f16(float16x8_t a, float16x8_t b, float16_t c, mve_p
// CHECK-LABEL: @test_vfmasq_m_n_f32(
// CHECK-NEXT: entry:
-// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <4 x float> poison, float [[C:%.*]], i64 0
-// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <4 x float> [[DOTSPLATINSERT]], <4 x float> poison, <4 x i32> zeroinitializer
// CHECK-NEXT: [[TMP0:%.*]] = zext i16 [[P:%.*]] to i32
// CHECK-NEXT: [[TMP1:%.*]] = call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 [[TMP0]])
-// CHECK-NEXT: [[TMP2:%.*]] = call <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]], <4 x float> [[DOTSPLAT]], <4 x i1> [[TMP1]])
-// CHECK-NEXT: ret <4 x float> [[TMP2]]
+// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <4 x float> poison, float [[C:%.*]], i64 0
+// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <4 x float> [[DOTSPLATINSERT]], <4 x float> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP2:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]], <4 x float> [[DOTSPLAT]])
+// CHECK-NEXT: [[TMP3:%.*]] = select <4 x i1> [[TMP1]], <4 x float> [[TMP2]], <4 x float> [[A]]
+// CHECK-NEXT: ret <4 x float> [[TMP3]]
//
float32x4_t test_vfmasq_m_n_f32(float32x4_t a, float32x4_t b, float32_t c, mve_pred16_t p) {
#ifdef POLYMORPHIC
diff --git a/llvm/include/llvm/IR/IntrinsicsARM.td b/llvm/include/llvm/IR/IntrinsicsARM.td
index 11b9877091a8ed..b18d3fcc9e3f44 100644
--- a/llvm/include/llvm/IR/IntrinsicsARM.td
+++ b/llvm/include/llvm/IR/IntrinsicsARM.td
@@ -1362,6 +1362,7 @@ def int_arm_mve_vqmovn_predicated: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
llvm_i32_ty /* unsigned output */, llvm_i32_ty /* unsigned input */,
llvm_i32_ty /* top half */, llvm_anyvector_ty /* pred */], [IntrNoMem]>;
+// fma_predicated returns the add operand for disabled lanes.
def int_arm_mve_fma_predicated: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
[LLVMMatchType<0> /* mult op #1 */, LLVMMatchType<0> /* mult op #2 */,
LLVMMatchType<0> /* addend */, llvm_anyvector_ty /* pred */], [IntrNoMem]>;
diff --git a/llvm/lib/Target/ARM/ARMInstrMVE.td b/llvm/lib/Target/ARM/ARMInstrMVE.td
index 22af599f4f0859..bdd0d739a05684 100644
--- a/llvm/lib/Target/ARM/ARMInstrMVE.td
+++ b/llvm/lib/Target/ARM/ARMInstrMVE.td
@@ -5614,8 +5614,6 @@ multiclass MVE_VFMA_qr_multi<string iname, MVEVectorVTInfo VTI,
(VTI.Vec (fma v1, v2, vs)),
v1)),
(VTI.Vec (Inst v1, v2, is, ARMVCCThen, $pred, zero_reg))>;
- def : Pat<(VTI.Vec (pred_int v1, v2, vs, pred)),
- (VTI.Vec (Inst v1, v2, is, ARMVCCThen, pred, zero_reg))>;
} else {
def : Pat<(VTI.Vec (fma v1, vs, v2)),
(VTI.Vec (Inst v2, v1, is))>;
diff --git a/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll b/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll
index 4b1b8998ce73eb..265452c18f81a4 100644
--- a/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll
@@ -461,8 +461,10 @@ define arm_aapcs_vfpcc <8 x half> @test_vfmasq_m_n_f16(<8 x half> %a, <8 x half>
; CHECK: @ %bb.0: @ %entry
; CHECK-NEXT: vmov r1, s8
; CHECK-NEXT: vmsr p0, r0
+; CHECK-NEXT: vdup.16 q2, r1
; CHECK-NEXT: vpst
-; CHECK-NEXT: vfmast.f16 q0, q1, r1
+; CHECK-NEXT: vfmat.f16 q2, q0, q1
+; CHECK-NEXT: vmov q0, q2
; CHECK-NEXT: bx lr
entry:
%0 = bitcast float %c.coerce to i32
@@ -476,13 +478,36 @@ entry:
ret <8 x half> %4
}
+define arm_aapcs_vfpcc <8 x half> @test_vfmasq_m_n_f16_select(<8 x half> %a, <8 x half> %b, float %c.coerce, i16 zeroext %p) {
+; CHECK-LABEL: test_vfmasq_m_n_f16_select:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vmov r1, s8
+; CHECK-NEXT: vmsr p0, r0
+; CHECK-NEXT: vpst
+; CHECK-NEXT: vfmast.f16 q0, q1, r1
+; CHECK-NEXT: bx lr
+entry:
+ %0 = bitcast float %c.coerce to i32
+ %tmp.0.extract.trunc = trunc i32 %0 to i16
+ %1 = bitcast i16 %tmp.0.extract.trunc to half
+ %.splatinsert = insertelement <8 x half> undef, half %1, i32 0
+ %.splat = shufflevector <8 x half> %.splatinsert, <8 x half> undef, <8 x i32> zeroinitializer
+ %2 = zext i16 %p to i32
+ %3 = tail call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 %2)
+ %4 = tail call <8 x half> @llvm.fma.v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %.splat)
+ %5 = select <8 x i1> %3, <8 x half> %4, <8 x half> %a
+ ret <8 x half> %5
+}
+
define arm_aapcs_vfpcc <4 x float> @test_vfmasq_m_n_f32(<4 x float> %a, <4 x float> %b, float %c, i16 zeroext %p) {
; CHECK-LABEL: test_vfmasq_m_n_f32:
; CHECK: @ %bb.0: @ %entry
; CHECK-NEXT: vmov r1, s8
; CHECK-NEXT: vmsr p0, r0
+; CHECK-NEXT: vdup.32 q2, r1
; CHECK-NEXT: vpst
-; CHECK-NEXT: vfmast.f32 q0, q1, r1
+; CHECK-NEXT: vfmat.f32 q2, q0, q1
+; CHECK-NEXT: vmov q0, q2
; CHECK-NEXT: bx lr
entry:
%.splatinsert = insertelement <4 x float> undef, float %c, i32 0
@@ -493,6 +518,24 @@ entry:
ret <4 x float> %2
}
+define arm_aapcs_vfpcc <4 x float> @test_vfmasq_m_n_f32_select(<4 x float> %a, <4 x float> %b, float %c, i16 zeroext %p) {
+; CHECK-LABEL: test_vfmasq_m_n_f32_select:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: vmov r1, s8
+; CHECK-NEXT: vmsr p0, r0
+; CHECK-NEXT: vpst
+; CHECK-NEXT: vfmast.f32 q0, q1, r1
+; CHECK-NEXT: bx lr
+entry:
+ %.splatinsert = insertelement <4 x float> undef, float %c, i32 0
+ %.splat = shufflevector <4 x float> %.splatinsert, <4 x float> undef, <4 x i32> zeroinitializer
+ %0 = zext i16 %p to i32
+ %1 = tail call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 %0)
+ %2 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %.splat)
+ %3 = select <4 x i1> %1, <4 x float> %2, <4 x float> %a
+ ret <4 x float> %3
+}
+
define arm_aapcs_vfpcc <8 x half> @test_vfmsq_m_f16(<8 x half> %a, <8 x half> %b, <8 x half> %c, i16 zeroext %p) {
; CHECK-LABEL: test_vfmsq_m_f16:
; CHECK: @ %bb.0: @ %entry
diff --git a/llvm/test/CodeGen/Thumb2/mve-qrintr.ll b/llvm/test/CodeGen/Thumb2/mve-qrintr.ll
index 151e51fcf0c93d..87cce804356dee 100644
--- a/llvm/test/CodeGen/Thumb2/mve-qrintr.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-qrintr.ll
@@ -1536,7 +1536,8 @@ while.body: ; preds = %while.body.lr.ph, %
%0 = tail call <4 x i1> @llvm.arm.mve.vctp32(i32 %N.addr.013)
%1 = tail call fast <4 x float> @llvm.masked.load.v4f32.p0(ptr %s1.addr.014, i32 4, <4 x i1> %0, <4 x float> zeroinitializer)
%2 = tail call fast <4 x float> @llvm.masked.load.v4f32.p0(ptr %s2, i32 4, <4 x i1> %0, <4 x float> zeroinitializer)
- %3 = tail call fast <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> %1, <4 x float> %2, <4 x float> %.splat, <4 x i1> %0)
+ %3 = tail call fast <4 x float> @llvm.fma.v4f32(<4 x float> %1, <4 x float> %2, <4 x float> %.splat)
+ %4 = select <4 x i1> %0, <4 x float> %3, <4 x float> %1
tail call void @llvm.masked.store.v4f32.p0(<4 x float> %3, ptr %s1.addr.014, i32 4, <4 x i1> %0)
%add.ptr = getelementptr inbounds float, ptr %s1.addr.014, i32 4
%sub = add nsw i32 %N.addr.013, -4
More information about the cfe-commits
mailing list