[clang] [llvm] [ARM] Fix operand order for MVE predicated VFMAS (PR #115908)

via cfe-commits cfe-commits at lists.llvm.org
Tue Nov 12 09:20:57 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-arm

Author: Oliver Stannard (ostannard)

<details>
<summary>Changes</summary>

For most MVE predicated FMA instructions, disabled lanes will contain the value in the addend operand. However, the VFMAS instruction takes the addend in a GPR, and the output register is shared with the first multiply operand, so disabled lanes will get that value instead. This means that we can't use the same intrinsic as for the other VFMA instructions. Instead, clang can codegen the vfmas intrinsic to a regular FMA followed by a select, a pattern from which the backend can already select VFMAS.

---
Full diff: https://github.com/llvm/llvm-project/pull/115908.diff


6 Files Affected:

- (modified) clang/include/clang/Basic/arm_mve.td (+1-1) 
- (modified) clang/test/CodeGen/arm-mve-intrinsics/ternary.c (+10-8) 
- (modified) llvm/include/llvm/IR/IntrinsicsARM.td (+1) 
- (modified) llvm/lib/Target/ARM/ARMInstrMVE.td (-2) 
- (modified) llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll (+45-2) 
- (modified) llvm/test/CodeGen/Thumb2/mve-qrintr.ll (+2-1) 


``````````diff
diff --git a/clang/include/clang/Basic/arm_mve.td b/clang/include/clang/Basic/arm_mve.td
index 52185ca07da41f..1debb94a0a7b81 100644
--- a/clang/include/clang/Basic/arm_mve.td
+++ b/clang/include/clang/Basic/arm_mve.td
@@ -193,7 +193,7 @@ multiclass FMA<bit add> {
     def sq_m_n: Intrinsic<Vector, (args Vector:$m1, Vector:$m2,
                                         unpromoted<Scalar>:$addend_s,
                                         Predicate:$pred),
-                          (seq (splat $addend_s):$addend, pred_cg)>;
+                          (select $pred, (seq (splat $addend_s):$addend, unpred_cg), $m1)>;
   }
 }
 
diff --git a/clang/test/CodeGen/arm-mve-intrinsics/ternary.c b/clang/test/CodeGen/arm-mve-intrinsics/ternary.c
index 36b2ce063cb188..768d397cb5611f 100644
--- a/clang/test/CodeGen/arm-mve-intrinsics/ternary.c
+++ b/clang/test/CodeGen/arm-mve-intrinsics/ternary.c
@@ -542,12 +542,13 @@ float32x4_t test_vfmaq_m_n_f32(float32x4_t a, float32x4_t b, float32_t c, mve_pr
 
 // CHECK-LABEL: @test_vfmasq_m_n_f16(
 // CHECK-NEXT:  entry:
-// CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <8 x half> poison, half [[C:%.*]], i64 0
-// CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <8 x half> [[DOTSPLATINSERT]], <8 x half> poison, <8 x i32> zeroinitializer
 // CHECK-NEXT:    [[TMP0:%.*]] = zext i16 [[P:%.*]] to i32
 // CHECK-NEXT:    [[TMP1:%.*]] = call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 [[TMP0]])
-// CHECK-NEXT:    [[TMP2:%.*]] = call <8 x half> @llvm.arm.mve.fma.predicated.v8f16.v8i1(<8 x half> [[A:%.*]], <8 x half> [[B:%.*]], <8 x half> [[DOTSPLAT]], <8 x i1> [[TMP1]])
-// CHECK-NEXT:    ret <8 x half> [[TMP2]]
+// CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <8 x half> poison, half [[C:%.*]], i64 0
+// CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <8 x half> [[DOTSPLATINSERT]], <8 x half> poison, <8 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP2:%.*]] = call <8 x half> @llvm.fma.v8f16(<8 x half> [[A:%.*]], <8 x half> [[B:%.*]], <8 x half> [[DOTSPLAT]])
+// CHECK-NEXT:    [[TMP3:%.*]] = select <8 x i1> [[TMP1]], <8 x half> [[TMP2]], <8 x half> [[A]]
+// CHECK-NEXT:    ret <8 x half> [[TMP3]]
 //
 float16x8_t test_vfmasq_m_n_f16(float16x8_t a, float16x8_t b, float16_t c, mve_pred16_t p) {
 #ifdef POLYMORPHIC
@@ -559,12 +560,13 @@ float16x8_t test_vfmasq_m_n_f16(float16x8_t a, float16x8_t b, float16_t c, mve_p
 
 // CHECK-LABEL: @test_vfmasq_m_n_f32(
 // CHECK-NEXT:  entry:
-// CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x float> poison, float [[C:%.*]], i64 0
-// CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x float> [[DOTSPLATINSERT]], <4 x float> poison, <4 x i32> zeroinitializer
 // CHECK-NEXT:    [[TMP0:%.*]] = zext i16 [[P:%.*]] to i32
 // CHECK-NEXT:    [[TMP1:%.*]] = call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 [[TMP0]])
-// CHECK-NEXT:    [[TMP2:%.*]] = call <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]], <4 x float> [[DOTSPLAT]], <4 x i1> [[TMP1]])
-// CHECK-NEXT:    ret <4 x float> [[TMP2]]
+// CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x float> poison, float [[C:%.*]], i64 0
+// CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x float> [[DOTSPLATINSERT]], <4 x float> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP2:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[A:%.*]], <4 x float> [[B:%.*]], <4 x float> [[DOTSPLAT]])
+// CHECK-NEXT:    [[TMP3:%.*]] = select <4 x i1> [[TMP1]], <4 x float> [[TMP2]], <4 x float> [[A]]
+// CHECK-NEXT:    ret <4 x float> [[TMP3]]
 //
 float32x4_t test_vfmasq_m_n_f32(float32x4_t a, float32x4_t b, float32_t c, mve_pred16_t p) {
 #ifdef POLYMORPHIC
diff --git a/llvm/include/llvm/IR/IntrinsicsARM.td b/llvm/include/llvm/IR/IntrinsicsARM.td
index 11b9877091a8ed..b18d3fcc9e3f44 100644
--- a/llvm/include/llvm/IR/IntrinsicsARM.td
+++ b/llvm/include/llvm/IR/IntrinsicsARM.td
@@ -1362,6 +1362,7 @@ def int_arm_mve_vqmovn_predicated: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
     llvm_i32_ty /* unsigned output */, llvm_i32_ty /* unsigned input */,
     llvm_i32_ty /* top half */, llvm_anyvector_ty /* pred */], [IntrNoMem]>;
 
+// fma_predicated returns the add operand for disabled lanes.
 def int_arm_mve_fma_predicated: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
    [LLVMMatchType<0> /* mult op #1 */, LLVMMatchType<0> /* mult op #2 */,
     LLVMMatchType<0> /* addend */, llvm_anyvector_ty /* pred */], [IntrNoMem]>;
diff --git a/llvm/lib/Target/ARM/ARMInstrMVE.td b/llvm/lib/Target/ARM/ARMInstrMVE.td
index 22af599f4f0859..bdd0d739a05684 100644
--- a/llvm/lib/Target/ARM/ARMInstrMVE.td
+++ b/llvm/lib/Target/ARM/ARMInstrMVE.td
@@ -5614,8 +5614,6 @@ multiclass MVE_VFMA_qr_multi<string iname, MVEVectorVTInfo VTI,
                                   (VTI.Vec (fma v1, v2, vs)),
                                   v1)),
                 (VTI.Vec (Inst v1, v2, is, ARMVCCThen, $pred, zero_reg))>;
-      def : Pat<(VTI.Vec (pred_int v1, v2, vs, pred)),
-                (VTI.Vec (Inst v1, v2, is, ARMVCCThen, pred, zero_reg))>;
     } else {
       def : Pat<(VTI.Vec (fma v1, vs, v2)),
                 (VTI.Vec (Inst v2, v1, is))>;
diff --git a/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll b/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll
index 4b1b8998ce73eb..265452c18f81a4 100644
--- a/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-intrinsics/ternary.ll
@@ -461,8 +461,10 @@ define arm_aapcs_vfpcc <8 x half> @test_vfmasq_m_n_f16(<8 x half> %a, <8 x half>
 ; CHECK:       @ %bb.0: @ %entry
 ; CHECK-NEXT:    vmov r1, s8
 ; CHECK-NEXT:    vmsr p0, r0
+; CHECK-NEXT:    vdup.16 q2, r1
 ; CHECK-NEXT:    vpst
-; CHECK-NEXT:    vfmast.f16 q0, q1, r1
+; CHECK-NEXT:    vfmat.f16 q2, q0, q1
+; CHECK-NEXT:    vmov q0, q2
 ; CHECK-NEXT:    bx lr
 entry:
   %0 = bitcast float %c.coerce to i32
@@ -476,13 +478,36 @@ entry:
   ret <8 x half> %4
 }
 
+define arm_aapcs_vfpcc <8 x half> @test_vfmasq_m_n_f16_select(<8 x half> %a, <8 x half> %b, float %c.coerce, i16 zeroext %p) {
+; CHECK-LABEL: test_vfmasq_m_n_f16_select:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmov r1, s8
+; CHECK-NEXT:    vmsr p0, r0
+; CHECK-NEXT:    vpst
+; CHECK-NEXT:    vfmast.f16 q0, q1, r1
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = bitcast float %c.coerce to i32
+  %tmp.0.extract.trunc = trunc i32 %0 to i16
+  %1 = bitcast i16 %tmp.0.extract.trunc to half
+  %.splatinsert = insertelement <8 x half> undef, half %1, i32 0
+  %.splat = shufflevector <8 x half> %.splatinsert, <8 x half> undef, <8 x i32> zeroinitializer
+  %2 = zext i16 %p to i32
+  %3 = tail call <8 x i1> @llvm.arm.mve.pred.i2v.v8i1(i32 %2)
+  %4 = tail call <8 x half> @llvm.fma.v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %.splat)
+  %5 = select <8 x i1> %3, <8 x half> %4, <8 x half> %a
+  ret <8 x half> %5
+}
+
 define arm_aapcs_vfpcc <4 x float> @test_vfmasq_m_n_f32(<4 x float> %a, <4 x float> %b, float %c, i16 zeroext %p) {
 ; CHECK-LABEL: test_vfmasq_m_n_f32:
 ; CHECK:       @ %bb.0: @ %entry
 ; CHECK-NEXT:    vmov r1, s8
 ; CHECK-NEXT:    vmsr p0, r0
+; CHECK-NEXT:    vdup.32 q2, r1
 ; CHECK-NEXT:    vpst
-; CHECK-NEXT:    vfmast.f32 q0, q1, r1
+; CHECK-NEXT:    vfmat.f32 q2, q0, q1
+; CHECK-NEXT:    vmov q0, q2
 ; CHECK-NEXT:    bx lr
 entry:
   %.splatinsert = insertelement <4 x float> undef, float %c, i32 0
@@ -493,6 +518,24 @@ entry:
   ret <4 x float> %2
 }
 
+define arm_aapcs_vfpcc <4 x float> @test_vfmasq_m_n_f32_select(<4 x float> %a, <4 x float> %b, float %c, i16 zeroext %p) {
+; CHECK-LABEL: test_vfmasq_m_n_f32_select:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vmov r1, s8
+; CHECK-NEXT:    vmsr p0, r0
+; CHECK-NEXT:    vpst
+; CHECK-NEXT:    vfmast.f32 q0, q1, r1
+; CHECK-NEXT:    bx lr
+entry:
+  %.splatinsert = insertelement <4 x float> undef, float %c, i32 0
+  %.splat = shufflevector <4 x float> %.splatinsert, <4 x float> undef, <4 x i32> zeroinitializer
+  %0 = zext i16 %p to i32
+  %1 = tail call <4 x i1> @llvm.arm.mve.pred.i2v.v4i1(i32 %0)
+  %2 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %.splat)
+  %3 = select <4 x i1> %1, <4 x float> %2, <4 x float> %a
+  ret <4 x float> %3
+}
+
 define arm_aapcs_vfpcc <8 x half> @test_vfmsq_m_f16(<8 x half> %a, <8 x half> %b, <8 x half> %c, i16 zeroext %p) {
 ; CHECK-LABEL: test_vfmsq_m_f16:
 ; CHECK:       @ %bb.0: @ %entry
diff --git a/llvm/test/CodeGen/Thumb2/mve-qrintr.ll b/llvm/test/CodeGen/Thumb2/mve-qrintr.ll
index 151e51fcf0c93d..87cce804356dee 100644
--- a/llvm/test/CodeGen/Thumb2/mve-qrintr.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-qrintr.ll
@@ -1536,7 +1536,8 @@ while.body:                                       ; preds = %while.body.lr.ph, %
   %0 = tail call <4 x i1> @llvm.arm.mve.vctp32(i32 %N.addr.013)
   %1 = tail call fast <4 x float> @llvm.masked.load.v4f32.p0(ptr %s1.addr.014, i32 4, <4 x i1> %0, <4 x float> zeroinitializer)
   %2 = tail call fast <4 x float> @llvm.masked.load.v4f32.p0(ptr %s2, i32 4, <4 x i1> %0, <4 x float> zeroinitializer)
-  %3 = tail call fast <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> %1, <4 x float> %2, <4 x float> %.splat, <4 x i1> %0)
+  %3 = tail call fast <4 x float> @llvm.fma.v4f32(<4 x float> %1, <4 x float> %2, <4 x float> %.splat)
+  %4 = select <4 x i1> %0, <4 x float> %3, <4 x float> %1
   tail call void @llvm.masked.store.v4f32.p0(<4 x float> %3, ptr %s1.addr.014, i32 4, <4 x i1> %0)
   %add.ptr = getelementptr inbounds float, ptr %s1.addr.014, i32 4
   %sub = add nsw i32 %N.addr.013, -4

``````````

</details>


https://github.com/llvm/llvm-project/pull/115908


More information about the cfe-commits mailing list