[llvm] Fold SVE mul and mul_u to neg during isel (PR #160828)

via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 26 01:16:41 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Martin Wehking (MartinWehking)

<details>
<summary>Changes</summary>

Replace mul and mul_u ops with a neg operation if their second operand is a splat of -1.

Because mul_u ops are commutative, also apply the optimization when their first operand is a splat of -1.

---
Full diff: https://github.com/llvm/llvm-project/pull/160828.diff


3 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td (+15) 
- (modified) llvm/lib/Target/AArch64/SVEInstrFormats.td (+13) 
- (added) llvm/test/CodeGen/AArch64/sve-int-mul-neg.ll (+213) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 1e30735b7a56a..27f4e8125f067 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -730,6 +730,21 @@ let Predicates = [HasSVE_or_SME] in {
   defm ABS_ZPmZ  : sve_int_un_pred_arit<  0b110, "abs",  AArch64abs_mt>;
   defm NEG_ZPmZ  : sve_int_un_pred_arit<  0b111, "neg",  AArch64neg_mt>;
 
+  // mul x (splat -1) -> neg x
+  let Predicates = [HasSVE_or_SME] in {
+    def : SVE_2_Op_Neg_One_Passthru_Pat<nxv16i8, AArch64mul_m1, nxv16i1, NEG_ZPmZ_B , i32>;
+    def : SVE_2_Op_Neg_One_Passthru_Pat<nxv8i16, AArch64mul_m1, nxv8i1, NEG_ZPmZ_H , i32>;
+    def : SVE_2_Op_Neg_One_Passthru_Pat<nxv4i32, AArch64mul_m1, nxv4i1, NEG_ZPmZ_S , i32>;
+    def : SVE_2_Op_Neg_One_Passthru_Pat<nxv2i64, AArch64mul_m1, nxv2i1, NEG_ZPmZ_D , i64>;
+
+    let AddedComplexity = 5 in {
+    defm : SVE_2_Op_Neg_One_Passthru_Pat_Comm<nxv16i8, AArch64mul_p, nxv16i1, NEG_ZPmZ_B , i32>;
+    defm : SVE_2_Op_Neg_One_Passthru_Pat_Comm<nxv8i16, AArch64mul_p, nxv8i1, NEG_ZPmZ_H , i32>;
+    defm : SVE_2_Op_Neg_One_Passthru_Pat_Comm<nxv4i32, AArch64mul_p, nxv4i1, NEG_ZPmZ_S , i32>;
+    defm : SVE_2_Op_Neg_One_Passthru_Pat_Comm<nxv2i64, AArch64mul_p, nxv2i1, NEG_ZPmZ_D , i64>;
+    }
+  }
+
   defm CLS_ZPmZ  : sve_int_un_pred_arit_bitwise<   0b000, "cls",  AArch64cls_mt>;
   defm CLZ_ZPmZ  : sve_int_un_pred_arit_bitwise<   0b001, "clz",  AArch64clz_mt>;
   defm CNT_ZPmZ  : sve_int_un_pred_arit_bitwise<   0b010, "cnt",  AArch64cnt_mt>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 9a23c35766cac..4fdc80c2d749b 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -723,6 +723,19 @@ class SVE2p1_Cvt_VG2_Pat<string name, SDPatternOperator intrinsic, ValueType out
     : Pat<(out_vt (intrinsic in_vt:$Zn1, in_vt:$Zn2)),
                   (!cast<Instruction>(name) (REG_SEQUENCE ZPR2Mul2, in_vt:$Zn1, zsub0, in_vt:$Zn2, zsub1))>;
 
+class SVE_2_Op_Neg_One_Passthru_Pat<ValueType vt, SDPatternOperator op, ValueType pt,
+                             Instruction inst, ValueType immT>
+: Pat<(vt (op pt:$Op1, vt:$Op2, (vt (splat_vector (immT -1))))),
+      (inst $Op2, $Op1, $Op2)>;
+
+// Same as above, but for commutative ops: also match the splat of -1 as the first multiplicand.
+multiclass SVE_2_Op_Neg_One_Passthru_Pat_Comm<ValueType vt, SDPatternOperator op, ValueType pt,
+                                         Instruction inst, ValueType immT> {
+def : Pat<(vt (op pt:$Op1, vt:$Op2, (vt (splat_vector (immT -1))))),
+      (inst $Op2, $Op1, $Op2)>;
+def : Pat<(vt (op pt:$Op1, (vt (splat_vector (immT -1))), vt:$Op2)),
+      (inst $Op2, $Op1, $Op2)>;
+}
 //===----------------------------------------------------------------------===//
 // SVE pattern match helpers.
 //===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/sve-int-mul-neg.ll b/llvm/test/CodeGen/AArch64/sve-int-mul-neg.ll
new file mode 100644
index 0000000000000..8db9eac027506
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-int-mul-neg.ll
@@ -0,0 +1,213 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-linux-gnu -verify-machineinstrs -mattr=+sve < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+; Muls with (-1) as operand should fold to neg.
+define <vscale x 16 x i8> @mul_neg_fold_i8(<vscale x 16 x i1> %pg, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: mul_neg_fold_i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg z0.b, p0/m, z0.b
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 16 x i8> @llvm.aarch64.sve.dup.x.nxv16i8(i8 -1)
+  %2 = call <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1> %pg, <vscale x 16 x i8> %a, <vscale x 16 x i8> %1)
+  ret <vscale x 16 x i8> %2
+}
+
+define <vscale x 8 x i16> @mul_neg_fold_i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a) {
+; CHECK-LABEL: mul_neg_fold_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg z0.h, p0/m, z0.h
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 -1)
+  %2 = call <vscale x 8 x i16> @llvm.aarch64.sve.mul.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a, <vscale x 8 x i16> %1)
+  ret <vscale x 8 x i16> %2
+}
+
+define <vscale x 4 x i32> @mul_neg_fold_i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a) {
+; CHECK-LABEL: mul_neg_fold_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg z0.s, p0/m, z0.s
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32 -1)
+  %2 = call <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a, <vscale x 4 x i32> %1)
+  ret <vscale x 4 x i32> %2
+}
+
+define <vscale x 2 x i64> @mul_neg_fold_i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a) {
+; CHECK-LABEL: mul_neg_fold_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg z0.d, p0/m, z0.d
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 -1)
+  %2 = call <vscale x 2 x i64> @llvm.aarch64.sve.mul.nxv2i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a, <vscale x 2 x i64> %1)
+  ret <vscale x 2 x i64> %2
+}
+
+define <vscale x 8 x i16> @mul_neg_fold_two_dups(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a) {
+  ; Edge case -- make sure that the case where we're multiplying two dups
+  ; together is sane.
+; CHECK-LABEL: mul_neg_fold_two_dups:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.h, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    neg z0.h, p0/m, z0.h
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 -1)
+  %2 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 -1)
+  %3 = call <vscale x 8 x i16> @llvm.aarch64.sve.mul.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %1, <vscale x 8 x i16> %2)
+  ret <vscale x 8 x i16> %3
+}
+
+define <vscale x 16 x i8> @mul_neg_fold_u_i8(<vscale x 16 x i1> %pg, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: mul_neg_fold_u_i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg z0.b, p0/m, z0.b
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 16 x i8> @llvm.aarch64.sve.dup.x.nxv16i8(i8 -1)
+  %2 = call <vscale x 16 x i8> @llvm.aarch64.sve.mul.u.nxv16i8(<vscale x 16 x i1> %pg, <vscale x 16 x i8> %a, <vscale x 16 x i8> %1)
+  ret <vscale x 16 x i8> %2
+}
+
+define <vscale x 8 x i16> @mul_neg_fold_u_i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a) {
+; CHECK-LABEL: mul_neg_fold_u_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg z0.h, p0/m, z0.h
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 -1)
+  %2 = call <vscale x 8 x i16> @llvm.aarch64.sve.mul.u.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a, <vscale x 8 x i16> %1)
+  ret <vscale x 8 x i16> %2
+}
+
+define <vscale x 4 x i32> @mul_neg_fold_u_i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a) {
+; CHECK-LABEL: mul_neg_fold_u_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg z0.s, p0/m, z0.s
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32 -1)
+  %2 = call <vscale x 4 x i32> @llvm.aarch64.sve.mul.u.nxv4i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a, <vscale x 4 x i32> %1)
+  ret <vscale x 4 x i32> %2
+}
+
+define <vscale x 2 x i64> @mul_neg_fold_u_i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a) {
+; CHECK-LABEL: mul_neg_fold_u_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg z0.d, p0/m, z0.d
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 -1)
+  %2 = call <vscale x 2 x i64> @llvm.aarch64.sve.mul.u.nxv2i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a, <vscale x 2 x i64> %1)
+  ret <vscale x 2 x i64> %2
+}
+
+define <vscale x 8 x i16> @mul_neg_fold_u_two_dups(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a) {
+; CHECK-LABEL: mul_neg_fold_u_two_dups:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.h, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    neg z0.h, p0/m, z0.h
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 -1)
+  %2 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 -1)
+  %3 = call <vscale x 8 x i16> @llvm.aarch64.sve.mul.u.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %1, <vscale x 8 x i16> %2)
+  ret <vscale x 8 x i16> %3
+}
+
+; The undefined (mul_u) variant is commutative, so -1 as the first multiplicand folds to neg too.
+define <vscale x 2 x i64> @mul_neg_fold_u_different_argument_order(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a) {
+; CHECK-LABEL: mul_neg_fold_u_different_argument_order:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg z0.d, p0/m, z0.d
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 -1)
+  %2 = call <vscale x 2 x i64> @llvm.aarch64.sve.mul.u.nxv2i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %1, <vscale x 2 x i64> %a)
+  ret <vscale x 2 x i64> %2
+}
+; Non-foldable muls -- we don't expect these to be folded to neg.
+define <vscale x 8 x i16> @no_mul_neg_fold_i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a) {
+; CHECK-LABEL: no_mul_neg_fold_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z1.h, #-2 // =0xfffffffffffffffe
+; CHECK-NEXT:    mul z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 -2)
+  %2 = call <vscale x 8 x i16> @llvm.aarch64.sve.mul.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a, <vscale x 8 x i16> %1)
+  ret <vscale x 8 x i16> %2
+}
+
+define <vscale x 4 x i32> @no_mul_neg_fold_i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a) {
+; CHECK-LABEL: no_mul_neg_fold_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z1.s, #-2 // =0xfffffffffffffffe
+; CHECK-NEXT:    mul z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32 -2)
+  %2 = call <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a, <vscale x 4 x i32> %1)
+  ret <vscale x 4 x i32> %2
+}
+
+define <vscale x 2 x i64> @no_mul_neg_fold_i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a) {
+; CHECK-LABEL: no_mul_neg_fold_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z1.d, #-2 // =0xfffffffffffffffe
+; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 -2)
+  %2 = call <vscale x 2 x i64> @llvm.aarch64.sve.mul.nxv2i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a, <vscale x 2 x i64> %1)
+  ret <vscale x 2 x i64> %2
+}
+
+; The merging mul is not commutative, so -1 as the first multiplicand must not fold to neg.
+define <vscale x 2 x i64> @no_mul_neg_fold_different_argument_order(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a) {
+; CHECK-LABEL: no_mul_neg_fold_different_argument_order:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z1.d, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mul z1.d, p0/m, z1.d, z0.d
+; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 -1)
+  %2 = call <vscale x 2 x i64> @llvm.aarch64.sve.mul.nxv2i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %1, <vscale x 2 x i64> %a)
+  ret <vscale x 2 x i64> %2
+}
+
+define <vscale x 8 x i16> @no_mul_neg_fold_u_i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a) {
+; CHECK-LABEL: no_mul_neg_fold_u_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mul z0.h, z0.h, #-2
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 -2)
+  %2 = call <vscale x 8 x i16> @llvm.aarch64.sve.mul.u.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a, <vscale x 8 x i16> %1)
+  ret <vscale x 8 x i16> %2
+}
+
+define <vscale x 4 x i32> @no_mul_neg_fold_u_i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a) {
+; CHECK-LABEL: no_mul_neg_fold_u_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mul z0.s, z0.s, #-2
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32 -2)
+  %2 = call <vscale x 4 x i32> @llvm.aarch64.sve.mul.u.nxv4i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a, <vscale x 4 x i32> %1)
+  ret <vscale x 4 x i32> %2
+}
+
+define <vscale x 2 x i64> @no_mul_neg_fold_u_i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a) {
+; CHECK-LABEL: no_mul_neg_fold_u_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mul z0.d, z0.d, #-2
+; CHECK-NEXT:    ret
+  %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 -2)
+  %2 = call <vscale x 2 x i64> @llvm.aarch64.sve.mul.u.nxv2i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a, <vscale x 2 x i64> %1)
+  ret <vscale x 2 x i64> %2
+}
+
+declare <vscale x 16 x i8> @llvm.aarch64.sve.dup.x.nxv16i8(i8)
+declare <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16)
+declare <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32)
+declare <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64)
+
+declare <vscale x 16 x i8> @llvm.aarch64.sve.mul.u.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+declare <vscale x 8 x i16> @llvm.aarch64.sve.mul.u.nxv8i16(<vscale x 8 x i1>, <vscale x 8 x i16>, <vscale x 8 x i16>)
+declare <vscale x 4 x i32> @llvm.aarch64.sve.mul.u.nxv4i32(<vscale x 4 x i1>, <vscale x 4 x i32>, <vscale x 4 x i32>)
+declare <vscale x 2 x i64> @llvm.aarch64.sve.mul.u.nxv2i64(<vscale x 2 x i1>, <vscale x 2 x i64>, <vscale x 2 x i64>)
+
+declare <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+declare <vscale x 8 x i16> @llvm.aarch64.sve.mul.nxv8i16(<vscale x 8 x i1>, <vscale x 8 x i16>, <vscale x 8 x i16>)
+declare <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32(<vscale x 4 x i1>, <vscale x 4 x i32>, <vscale x 4 x i32>)
+declare <vscale x 2 x i64> @llvm.aarch64.sve.mul.nxv2i64(<vscale x 2 x i1>, <vscale x 2 x i64>, <vscale x 2 x i64>)

``````````

</details>


https://github.com/llvm/llvm-project/pull/160828


More information about the llvm-commits mailing list