[llvm] [DAG][AArch64] Handle vscale addressing modes in reassociationCanBreakAddressingModePattern (PR #89908)

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed May 1 01:47:34 PDT 2024


https://github.com/davemgreen updated https://github.com/llvm/llvm-project/pull/89908

From 99fc4680a0824529c9d7bb74147ab9ced71703ef Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Wed, 1 May 2024 08:55:33 +0100
Subject: [PATCH] [DAG][AArch64] Handle vscale addressing modes in
 reassociationCanBreakAddressingModePattern.

reassociationCanBreakAddressingModePattern tries to prevent bad add
reassociations that would break addressing mode patterns. This adds support
for vscale offset addressing modes, making sure we don't break patterns that
already exist. It does not optimize _to_ the correct addressing modes yet,
but prevents us from optimizing _away_ from them.
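
For example, the first test in sve-reassocadd.ll has roughly this form (a
sketch; the exact IR, including the @llvm.vscale.i64 idiom, is assumed from
the surrounding test checks):

  define <vscale x 16 x i8> @i8_4s_1v(ptr %b) {
  entry:
    ; address is (add (add %b, 4), 16*vscale)
    %add.ptr = getelementptr inbounds i8, ptr %b, i64 4
    %vscale = call i64 @llvm.vscale.i64()
    %mul = shl i64 %vscale, 4
    %add.ptr2 = getelementptr inbounds i8, ptr %add.ptr, i64 %mul
    %load = load <vscale x 16 x i8>, ptr %add.ptr2, align 16
    ret <vscale x 16 x i8> %load
  }

  declare i64 @llvm.vscale.i64()

Reassociating the constant 4 outwards would lose the [x8, #1, mul vl]
addressing mode, so the new check asks TLI.isLegalAddressingMode whether the
scalable offset is legal for every load/store user of the add, and if so
blocks the reassociation.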
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 42 ++++++++++++++-
 llvm/test/CodeGen/AArch64/sve-reassocadd.ll   | 54 +++++++------------
 2 files changed, 58 insertions(+), 38 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index c0bbea16a64262..245431f6cde1c1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -1083,7 +1083,44 @@ bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
   // (load/store (add, (add, x, y), offset2)) ->
   // (load/store (add, (add, x, offset2), y)).
 
-  if (Opc != ISD::ADD || N0.getOpcode() != ISD::ADD)
+  if (N0.getOpcode() != ISD::ADD)
+    return false;
+
+  // Check for vscale addressing modes.
+  // (load/store (add/sub (add x, y), vscale))
+  // (load/store (add/sub (add x, y), (lsl vscale, C)))
+  // (load/store (add/sub (add x, y), (mul vscale, C)))
+  if ((N1.getOpcode() == ISD::VSCALE ||
+       ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::MUL) &&
+        N1.getOperand(0).getOpcode() == ISD::VSCALE &&
+        isa<ConstantSDNode>(N1.getOperand(1)))) &&
+      N1.getValueType().getFixedSizeInBits() <= 64) {
+    int64_t ScalableOffset =
+        N1.getOpcode() == ISD::VSCALE
+            ? N1.getConstantOperandVal(0)
+            : (N1.getOperand(0).getConstantOperandVal(0) *
+               (N1.getOpcode() == ISD::SHL ? (1LL << N1.getConstantOperandVal(1))
+                                           : N1.getConstantOperandVal(1)));
+    if (Opc == ISD::SUB)
+      ScalableOffset = -ScalableOffset;
+    if (all_of(N->uses(), [&](SDNode *Node) {
+          if (auto *LoadStore = dyn_cast<MemSDNode>(Node);
+              LoadStore && LoadStore->getBasePtr().getNode() == N) {
+            TargetLoweringBase::AddrMode AM;
+            AM.HasBaseReg = true;
+            AM.ScalableOffset = ScalableOffset;
+            EVT VT = LoadStore->getMemoryVT();
+            unsigned AS = LoadStore->getAddressSpace();
+            Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
+            return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy,
+                                             AS);
+          }
+          return false;
+        }))
+      return true;
+  }
+
+  if (Opc != ISD::ADD)
     return false;
 
   auto *C2 = dyn_cast<ConstantSDNode>(N1);
@@ -3911,7 +3948,8 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
 
   // Hoist one-use addition by non-opaque constant:
   //   (x + C) - y  ->  (x - y) + C
-  if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
+  if (!reassociationCanBreakAddressingModePattern(ISD::SUB, DL, N, N0, N1) &&
+      N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
     return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(1));
diff --git a/llvm/test/CodeGen/AArch64/sve-reassocadd.ll b/llvm/test/CodeGen/AArch64/sve-reassocadd.ll
index c7261200a567a4..f54098b29a272f 100644
--- a/llvm/test/CodeGen/AArch64/sve-reassocadd.ll
+++ b/llvm/test/CodeGen/AArch64/sve-reassocadd.ll
@@ -22,11 +22,9 @@ entry:
 define <vscale x 16 x i8> @i8_4s_1v(ptr %b) {
 ; CHECK-LABEL: i8_4s_1v:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    rdvl x8, #1
 ; CHECK-NEXT:    ptrue p0.b
-; CHECK-NEXT:    mov w9, #4 // =0x4
-; CHECK-NEXT:    add x8, x0, x8
-; CHECK-NEXT:    ld1b { z0.b }, p0/z, [x8, x9]
+; CHECK-NEXT:    add x8, x0, #4
+; CHECK-NEXT:    ld1b { z0.b }, p0/z, [x8, #1, mul vl]
 ; CHECK-NEXT:    ret
 entry:
   %add.ptr = getelementptr inbounds i8, ptr %b, i64 4
@@ -58,11 +56,9 @@ entry:
 define <vscale x 8 x i16> @i16_8s_1v(ptr %b) {
 ; CHECK-LABEL: i16_8s_1v:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    rdvl x8, #1
 ; CHECK-NEXT:    ptrue p0.h
-; CHECK-NEXT:    mov x9, #4 // =0x4
-; CHECK-NEXT:    add x8, x0, x8
-; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x8, x9, lsl #1]
+; CHECK-NEXT:    add x8, x0, #8
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x8, #1, mul vl]
 ; CHECK-NEXT:    ret
 entry:
   %add.ptr = getelementptr inbounds i8, ptr %b, i64 8
@@ -94,11 +90,9 @@ entry:
 define <vscale x 8 x i16> @i16_8s_2v(ptr %b) {
 ; CHECK-LABEL: i16_8s_2v:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    rdvl x8, #2
 ; CHECK-NEXT:    ptrue p0.h
-; CHECK-NEXT:    mov x9, #4 // =0x4
-; CHECK-NEXT:    add x8, x0, x8
-; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x8, x9, lsl #1]
+; CHECK-NEXT:    add x8, x0, #8
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x8, #2, mul vl]
 ; CHECK-NEXT:    ret
 entry:
   %add.ptr = getelementptr inbounds i8, ptr %b, i64 8
@@ -130,11 +124,9 @@ entry:
 define <vscale x 4 x i32> @i32_16s_2v(ptr %b) {
 ; CHECK-LABEL: i32_16s_2v:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    rdvl x8, #1
 ; CHECK-NEXT:    ptrue p0.s
-; CHECK-NEXT:    mov x9, #4 // =0x4
-; CHECK-NEXT:    add x8, x0, x8
-; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x8, x9, lsl #2]
+; CHECK-NEXT:    add x8, x0, #16
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x8, #1, mul vl]
 ; CHECK-NEXT:    ret
 entry:
   %add.ptr = getelementptr inbounds i8, ptr %b, i64 16
@@ -166,11 +158,9 @@ entry:
 define <vscale x 2 x i64> @i64_32s_2v(ptr %b) {
 ; CHECK-LABEL: i64_32s_2v:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    rdvl x8, #1
 ; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    mov x9, #4 // =0x4
-; CHECK-NEXT:    add x8, x0, x8
-; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x8, x9, lsl #3]
+; CHECK-NEXT:    add x8, x0, #32
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x8, #1, mul vl]
 ; CHECK-NEXT:    ret
 entry:
   %add.ptr = getelementptr inbounds i8, ptr %b, i64 32
@@ -203,11 +193,9 @@ entry:
 define <vscale x 16 x i8> @i8_4s_m2v(ptr %b) {
 ; CHECK-LABEL: i8_4s_m2v:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    cnth x8, all, mul #4
 ; CHECK-NEXT:    ptrue p0.b
-; CHECK-NEXT:    mov w9, #4 // =0x4
-; CHECK-NEXT:    sub x8, x0, x8
-; CHECK-NEXT:    ld1b { z0.b }, p0/z, [x8, x9]
+; CHECK-NEXT:    add x8, x0, #4
+; CHECK-NEXT:    ld1b { z0.b }, p0/z, [x8, #-2, mul vl]
 ; CHECK-NEXT:    ret
 entry:
   %add.ptr = getelementptr inbounds i8, ptr %b, i64 4
@@ -239,11 +227,9 @@ entry:
 define <vscale x 8 x i16> @i16_8s_m2v(ptr %b) {
 ; CHECK-LABEL: i16_8s_m2v:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    cnth x8, all, mul #4
 ; CHECK-NEXT:    ptrue p0.h
-; CHECK-NEXT:    mov x9, #4 // =0x4
-; CHECK-NEXT:    sub x8, x0, x8
-; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x8, x9, lsl #1]
+; CHECK-NEXT:    add x8, x0, #8
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x8, #-2, mul vl]
 ; CHECK-NEXT:    ret
 entry:
   %add.ptr = getelementptr inbounds i8, ptr %b, i64 8
@@ -275,11 +261,9 @@ entry:
 define <vscale x 4 x i32> @i32_16s_m2v(ptr %b) {
 ; CHECK-LABEL: i32_16s_m2v:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    cnth x8, all, mul #4
 ; CHECK-NEXT:    ptrue p0.s
-; CHECK-NEXT:    mov x9, #4 // =0x4
-; CHECK-NEXT:    sub x8, x0, x8
-; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x8, x9, lsl #2]
+; CHECK-NEXT:    add x8, x0, #16
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x8, #-2, mul vl]
 ; CHECK-NEXT:    ret
 entry:
   %add.ptr = getelementptr inbounds i8, ptr %b, i64 16
@@ -311,11 +295,9 @@ entry:
 define <vscale x 2 x i64> @i64_32s_m2v(ptr %b) {
 ; CHECK-LABEL: i64_32s_m2v:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    cnth x8, all, mul #4
 ; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    mov x9, #4 // =0x4
-; CHECK-NEXT:    sub x8, x0, x8
-; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x8, x9, lsl #3]
+; CHECK-NEXT:    add x8, x0, #32
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x8, #-2, mul vl]
 ; CHECK-NEXT:    ret
 entry:
   %add.ptr = getelementptr inbounds i8, ptr %b, i64 32
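
The _m2v tests exercise the visitSUB change in the same way. Their input is
along these lines (again a sketch of the real test, with the negative vscale
multiple assumed from the check lines):

  define <vscale x 16 x i8> @i8_4s_m2v(ptr %b) {
  entry:
    ; address is (add (add %b, 4), -32*vscale), i.e. -2 vector lengths
    %add.ptr = getelementptr inbounds i8, ptr %b, i64 4
    %vscale = call i64 @llvm.vscale.i64()
    %mul = mul i64 %vscale, -32
    %add.ptr2 = getelementptr inbounds i8, ptr %add.ptr, i64 %mul
    %load = load <vscale x 16 x i8>, ptr %add.ptr2, align 16
    ret <vscale x 16 x i8> %load
  }

  declare i64 @llvm.vscale.i64()

Here the (x + C) - y -> (x - y) + C hoist in visitSUB would pull the +4 out
of the address, giving the removed cnth/sub/ld1b sequence in the diff above;
guarding it with reassociationCanBreakAddressingModePattern keeps the
add x8, x0, #4 plus [x8, #-2, mul vl] form instead.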



More information about the llvm-commits mailing list