[llvm] [AArch64][SVE2] Lower read-after-write mask to whilerw (PR #114028)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 29 03:30:03 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Sam Tebbs (SamTebbs33)

<details>
<summary>Changes</summary>

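This patch extends the existing `whilewr` matching in `tryWhileWRFromOR` to also match a read-after-write mask, which is recognised by an `abs` applied to the pointer difference, and lowers it to a `whilerw`.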

---
Full diff: https://github.com/llvm/llvm-project/pull/114028.diff


2 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+26-8) 
- (modified) llvm/test/CodeGen/AArch64/whilewr.ll (+133) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index bf2f0674b5b65e..a2517761afc0c9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14189,7 +14189,16 @@ SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG,
     return SDValue();
 
   SDValue Diff = Cmp.getOperand(0);
-  if (Diff.getOpcode() != ISD::SUB || Diff.getValueType() != MVT::i64)
+  SDValue NonAbsDiff = Diff;
+  bool WriteAfterRead = true;
+  // A read-after-write will have an abs call on the diff
+  if (Diff.getOpcode() == ISD::ABS) {
+    NonAbsDiff = Diff.getOperand(0);
+    WriteAfterRead = false;
+  }
+
+  if (NonAbsDiff.getOpcode() != ISD::SUB ||
+      NonAbsDiff.getValueType() != MVT::i64)
     return SDValue();
 
   if (!isNullConstant(LaneMask.getOperand(1)) ||
@@ -14210,8 +14219,13 @@ SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG,
       // it's positive, otherwise the difference plus the element size if it's
       // negative: pos_diff = diff < 0 ? (diff + 7) : diff
       SDValue Select = DiffDiv.getOperand(0);
+      SDValue SelectOp3 = Select.getOperand(3);
+      // Check for an abs in the case of a read-after-write
+      if (!WriteAfterRead && SelectOp3.getOpcode() == ISD::ABS)
+        SelectOp3 = SelectOp3.getOperand(0);
+
       // Make sure the difference is being compared by the select
-      if (Select.getOpcode() != ISD::SELECT_CC || Select.getOperand(3) != Diff)
+      if (Select.getOpcode() != ISD::SELECT_CC || SelectOp3 != NonAbsDiff)
         return SDValue();
       // Make sure it's checking if the difference is less than 0
       if (!isNullConstant(Select.getOperand(1)) ||
@@ -14243,22 +14257,26 @@ SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG,
   } else if (LaneMask.getOperand(2) != Diff)
     return SDValue();
 
-  SDValue StorePtr = Diff.getOperand(0);
-  SDValue ReadPtr = Diff.getOperand(1);
+  SDValue StorePtr = NonAbsDiff.getOperand(0);
+  SDValue ReadPtr = NonAbsDiff.getOperand(1);
 
   unsigned IntrinsicID = 0;
   switch (EltSize) {
   case 1:
-    IntrinsicID = Intrinsic::aarch64_sve_whilewr_b;
+    IntrinsicID = WriteAfterRead ? Intrinsic::aarch64_sve_whilewr_b
+                                 : Intrinsic::aarch64_sve_whilerw_b;
     break;
   case 2:
-    IntrinsicID = Intrinsic::aarch64_sve_whilewr_h;
+    IntrinsicID = WriteAfterRead ? Intrinsic::aarch64_sve_whilewr_h
+                                 : Intrinsic::aarch64_sve_whilerw_h;
     break;
   case 4:
-    IntrinsicID = Intrinsic::aarch64_sve_whilewr_s;
+    IntrinsicID = WriteAfterRead ? Intrinsic::aarch64_sve_whilewr_s
+                                 : Intrinsic::aarch64_sve_whilerw_s;
     break;
   case 8:
-    IntrinsicID = Intrinsic::aarch64_sve_whilewr_d;
+    IntrinsicID = WriteAfterRead ? Intrinsic::aarch64_sve_whilewr_d
+                                 : Intrinsic::aarch64_sve_whilerw_d;
     break;
   default:
     return SDValue();
diff --git a/llvm/test/CodeGen/AArch64/whilewr.ll b/llvm/test/CodeGen/AArch64/whilewr.ll
index 9f1ea850792384..bdea64296455eb 100644
--- a/llvm/test/CodeGen/AArch64/whilewr.ll
+++ b/llvm/test/CodeGen/AArch64/whilewr.ll
@@ -30,6 +30,36 @@ entry:
   ret <vscale x 16 x i1> %active.lane.mask.alias
 }
 
+define <vscale x 16 x i1> @whilerw_8(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
+; CHECK-SVE2-LABEL: whilerw_8:
+; CHECK-SVE2:       // %bb.0: // %entry
+; CHECK-SVE2-NEXT:    whilerw p0.b, x2, x1
+; CHECK-SVE2-NEXT:    ret
+;
+; CHECK-NOSVE2-LABEL: whilerw_8:
+; CHECK-NOSVE2:       // %bb.0: // %entry
+; CHECK-NOSVE2-NEXT:    subs x8, x2, x1
+; CHECK-NOSVE2-NEXT:    cneg x8, x8, mi
+; CHECK-NOSVE2-NEXT:    cmp x8, #1
+; CHECK-NOSVE2-NEXT:    cset w9, lt
+; CHECK-NOSVE2-NEXT:    whilelo p0.b, xzr, x8
+; CHECK-NOSVE2-NEXT:    sbfx x8, x9, #0, #1
+; CHECK-NOSVE2-NEXT:    whilelo p1.b, xzr, x8
+; CHECK-NOSVE2-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NOSVE2-NEXT:    ret
+entry:
+  %b24 = ptrtoint ptr %b to i64
+  %c25 = ptrtoint ptr %c to i64
+  %sub.diff = sub i64 %c25, %b24
+  %0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
+  %neg.compare = icmp slt i64 %0, 1
+  %.splatinsert = insertelement <vscale x 16 x i1> poison, i1 %neg.compare, i64 0
+  %.splat = shufflevector <vscale x 16 x i1> %.splatinsert, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer
+  %ptr.diff.lane.mask = tail call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %0)
+  %active.lane.mask.alias = or <vscale x 16 x i1> %ptr.diff.lane.mask, %.splat
+  ret <vscale x 16 x i1> %active.lane.mask.alias
+}
+
 define <vscale x 16 x i1> @whilewr_commutative(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
 ; CHECK-LABEL: whilewr_commutative:
 ; CHECK:       // %bb.0: // %entry
@@ -89,6 +119,39 @@ entry:
   ret <vscale x 8 x i1> %active.lane.mask.alias
 }
 
+define <vscale x 8 x i1> @whilerw_16(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
+; CHECK-SVE2-LABEL: whilerw_16:
+; CHECK-SVE2:       // %bb.0: // %entry
+; CHECK-SVE2-NEXT:    whilerw p0.h, x2, x1
+; CHECK-SVE2-NEXT:    ret
+;
+; CHECK-NOSVE2-LABEL: whilerw_16:
+; CHECK-NOSVE2:       // %bb.0: // %entry
+; CHECK-NOSVE2-NEXT:    subs x8, x2, x1
+; CHECK-NOSVE2-NEXT:    cneg x8, x8, mi
+; CHECK-NOSVE2-NEXT:    cmp x8, #2
+; CHECK-NOSVE2-NEXT:    add x8, x8, x8, lsr #63
+; CHECK-NOSVE2-NEXT:    cset w9, lt
+; CHECK-NOSVE2-NEXT:    sbfx x9, x9, #0, #1
+; CHECK-NOSVE2-NEXT:    asr x8, x8, #1
+; CHECK-NOSVE2-NEXT:    whilelo p0.h, xzr, x9
+; CHECK-NOSVE2-NEXT:    whilelo p1.h, xzr, x8
+; CHECK-NOSVE2-NEXT:    mov p0.b, p1/m, p1.b
+; CHECK-NOSVE2-NEXT:    ret
+entry:
+  %b24 = ptrtoint ptr %b to i64
+  %c25 = ptrtoint ptr %c to i64
+  %sub.diff = sub i64 %c25, %b24
+  %0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
+  %diff = sdiv i64 %0, 2
+  %neg.compare = icmp slt i64 %0, 2
+  %.splatinsert = insertelement <vscale x 8 x i1> poison, i1 %neg.compare, i64 0
+  %.splat = shufflevector <vscale x 8 x i1> %.splatinsert, <vscale x 8 x i1> poison, <vscale x 8 x i32> zeroinitializer
+  %ptr.diff.lane.mask = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 0, i64 %diff)
+  %active.lane.mask.alias = or <vscale x 8 x i1> %ptr.diff.lane.mask, %.splat
+  ret <vscale x 8 x i1> %active.lane.mask.alias
+}
+
 define <vscale x 4 x i1> @whilewr_32(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
 ; CHECK-LABEL: whilewr_32:
 ; CHECK:       // %bb.0: // %entry
@@ -122,6 +185,41 @@ entry:
   ret <vscale x 4 x i1> %active.lane.mask.alias
 }
 
+define <vscale x 4 x i1> @whilerw_32(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
+; CHECK-SVE2-LABEL: whilerw_32:
+; CHECK-SVE2:       // %bb.0: // %entry
+; CHECK-SVE2-NEXT:    whilerw p0.s, x2, x1
+; CHECK-SVE2-NEXT:    ret
+;
+; CHECK-NOSVE2-LABEL: whilerw_32:
+; CHECK-NOSVE2:       // %bb.0: // %entry
+; CHECK-NOSVE2-NEXT:    subs x8, x2, x1
+; CHECK-NOSVE2-NEXT:    cneg x8, x8, mi
+; CHECK-NOSVE2-NEXT:    add x9, x8, #3
+; CHECK-NOSVE2-NEXT:    cmp x8, #0
+; CHECK-NOSVE2-NEXT:    csel x9, x9, x8, lt
+; CHECK-NOSVE2-NEXT:    cmp x8, #4
+; CHECK-NOSVE2-NEXT:    cset w8, lt
+; CHECK-NOSVE2-NEXT:    asr x9, x9, #2
+; CHECK-NOSVE2-NEXT:    sbfx x8, x8, #0, #1
+; CHECK-NOSVE2-NEXT:    whilelo p1.s, xzr, x9
+; CHECK-NOSVE2-NEXT:    whilelo p0.s, xzr, x8
+; CHECK-NOSVE2-NEXT:    mov p0.b, p1/m, p1.b
+; CHECK-NOSVE2-NEXT:    ret
+entry:
+  %b24 = ptrtoint ptr %b to i64
+  %c25 = ptrtoint ptr %c to i64
+  %sub.diff = sub i64 %c25, %b24
+  %0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
+  %diff = sdiv i64 %0, 4
+  %neg.compare = icmp slt i64 %0, 4
+  %.splatinsert = insertelement <vscale x 4 x i1> poison, i1 %neg.compare, i64 0
+  %.splat = shufflevector <vscale x 4 x i1> %.splatinsert, <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer
+  %ptr.diff.lane.mask = tail call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 %diff)
+  %active.lane.mask.alias = or <vscale x 4 x i1> %ptr.diff.lane.mask, %.splat
+  ret <vscale x 4 x i1> %active.lane.mask.alias
+}
+
 define <vscale x 2 x i1> @whilewr_64(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
 ; CHECK-LABEL: whilewr_64:
 ; CHECK:       // %bb.0: // %entry
@@ -155,6 +253,41 @@ entry:
   ret <vscale x 2 x i1> %active.lane.mask.alias
 }
 
+define <vscale x 2 x i1> @whilerw_64(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
+; CHECK-SVE2-LABEL: whilerw_64:
+; CHECK-SVE2:       // %bb.0: // %entry
+; CHECK-SVE2-NEXT:    whilerw p0.d, x2, x1
+; CHECK-SVE2-NEXT:    ret
+;
+; CHECK-NOSVE2-LABEL: whilerw_64:
+; CHECK-NOSVE2:       // %bb.0: // %entry
+; CHECK-NOSVE2-NEXT:    subs x8, x2, x1
+; CHECK-NOSVE2-NEXT:    cneg x8, x8, mi
+; CHECK-NOSVE2-NEXT:    add x9, x8, #7
+; CHECK-NOSVE2-NEXT:    cmp x8, #0
+; CHECK-NOSVE2-NEXT:    csel x9, x9, x8, lt
+; CHECK-NOSVE2-NEXT:    cmp x8, #8
+; CHECK-NOSVE2-NEXT:    cset w8, lt
+; CHECK-NOSVE2-NEXT:    asr x9, x9, #3
+; CHECK-NOSVE2-NEXT:    sbfx x8, x8, #0, #1
+; CHECK-NOSVE2-NEXT:    whilelo p1.d, xzr, x9
+; CHECK-NOSVE2-NEXT:    whilelo p0.d, xzr, x8
+; CHECK-NOSVE2-NEXT:    mov p0.b, p1/m, p1.b
+; CHECK-NOSVE2-NEXT:    ret
+entry:
+  %b24 = ptrtoint ptr %b to i64
+  %c25 = ptrtoint ptr %c to i64
+  %sub.diff = sub i64 %c25, %b24
+  %0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
+  %diff = sdiv i64 %0, 8
+  %neg.compare = icmp slt i64 %0, 8
+  %.splatinsert = insertelement <vscale x 2 x i1> poison, i1 %neg.compare, i64 0
+  %.splat = shufflevector <vscale x 2 x i1> %.splatinsert, <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer
+  %ptr.diff.lane.mask = tail call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 0, i64 %diff)
+  %active.lane.mask.alias = or <vscale x 2 x i1> %ptr.diff.lane.mask, %.splat
+  ret <vscale x 2 x i1> %active.lane.mask.alias
+}
+
 define <vscale x 1 x i1> @no_whilewr_128(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
 ; CHECK-LABEL: no_whilewr_128:
 ; CHECK:       // %bb.0: // %entry

``````````

</details>


https://github.com/llvm/llvm-project/pull/114028


More information about the llvm-commits mailing list