[llvm] [ISel] Propagate disjoint flag in ShrinkDemandedOp (PR #114560)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 1 09:33:21 PDT 2024


https://github.com/sdesmalen-arm updated https://github.com/llvm/llvm-project/pull/114560

>From 600b107abad428010a0ea8f933a84c8dc60abbd2 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Fri, 1 Nov 2024 15:33:33 +0000
Subject: [PATCH 1/3] Pre-commit test

---
 .../AArch64/sme-intrinsics-mova-insert.ll       | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-mova-insert.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-mova-insert.ll
index 8711a0388e34c5..f5953ee74561b5 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-mova-insert.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-mova-insert.ll
@@ -470,6 +470,23 @@ exit:
   ret void
 }
 
+define void @test_add_with_disjoint_or(i64 %idx, <vscale x 4 x i1> %pg) {
+; CHECK-LABEL: test_add_with_disjoint_or:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.s, #0 // =0x0
+; CHECK-NEXT:    mov x12, x0
+; CHECK-NEXT:    orr w13, w12, #0x1
+; CHECK-NEXT:    mov za0h.s[w12, 0], p0/m, z0.s
+; CHECK-NEXT:    mov za0h.s[w13, 0], p0/m, z0.s
+; CHECK-NEXT:    ret
+  %idx.trunc = trunc i64 %idx to i32
+  call void @llvm.aarch64.sme.write.horiz.nxv4i32(i32 0, i32 %idx.trunc, <vscale x 4 x i1> %pg, <vscale x 4 x i32> zeroinitializer)
+  %idx2 = or disjoint i64 %idx, 1
+  %idx2.trunc = trunc i64 %idx2 to i32
+  call void @llvm.aarch64.sme.write.horiz.nxv4i32(i32 0, i32 %idx2.trunc, <vscale x 4 x i1> %pg, <vscale x 4 x i32> zeroinitializer)
+  ret void
+}
+
 declare void @llvm.aarch64.sme.write.horiz.nxv16i8(i32, i32, <vscale x 16 x i1>, <vscale x 16 x i8>)
 declare void @llvm.aarch64.sme.write.horiz.nxv8i16(i32, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)
 declare void @llvm.aarch64.sme.write.horiz.nxv8f16(i32, i32, <vscale x 8 x i1>, <vscale x 8 x half>)

>From 712d18f00d3879f4fc2f07a250d2602df3fce720 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Thu, 24 Oct 2024 14:47:09 +0100
Subject: [PATCH 2/3] [ISel] Pass on disjoint flag in ShrinkDemandedOp

When trying to evaluate an expression in a narrower type, the
DAGCombine should propagate the disjoint flag, as it's equally
valid on the narrower expression.

This enables better use of addressing modes for some
Arm SME instructions, for example.
---
 llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp        | 5 +++++
 llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp         | 2 +-
 llvm/test/CodeGen/AArch64/sme-intrinsics-mova-insert.ll | 3 +--
 3 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 1f4ace1b3174dd..941136f34c154e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -609,6 +609,11 @@ bool TargetLowering::ShrinkDemandedOp(SDValue Op, unsigned BitWidth,
           Op.getOpcode(), dl, SmallVT,
           DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(0)),
           DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(1)));
+      // If the operation has the 'disjoint' flag, then the operands on
+      // the new node are also disjoint.
+      SDNodeFlags Flags = Op->getFlags();
+      X->setFlags(Flags.hasDisjoint() ? SDNodeFlags::Disjoint
+                                      : SDNodeFlags::None);
       assert(DemandedSize <= SmallVTBits && "Narrowed below demanded bits?");
       SDValue Z = DAG.getNode(ISD::ANY_EXTEND, dl, VT, X);
       return TLO.CombineTo(Op, Z);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 2120443b6ba2c5..511ab4fe7e9a39 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -7420,7 +7420,7 @@ bool AArch64DAGToDAGISel::SelectSMETileSlice(SDValue N, unsigned MaxSize,
                                              SDValue &Base, SDValue &Offset,
                                              unsigned Scale) {
   // Try to untangle an ADD node into a 'reg + offset'
-  if (N.getOpcode() == ISD::ADD)
+  if (CurDAG->isBaseWithConstantOffset(N))
     if (auto C = dyn_cast<ConstantSDNode>(N.getOperand(1))) {
       int64_t ImmOff = C->getSExtValue();
       if ((ImmOff > 0 && ImmOff <= MaxSize && (ImmOff % Scale == 0))) {
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-mova-insert.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-mova-insert.ll
index f5953ee74561b5..9c5becf6ffbf29 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-mova-insert.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-mova-insert.ll
@@ -475,9 +475,8 @@ define void @test_add_with_disjoint_or(i64 %idx, <vscale x 4 x i1> %pg) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov z0.s, #0 // =0x0
 ; CHECK-NEXT:    mov x12, x0
-; CHECK-NEXT:    orr w13, w12, #0x1
 ; CHECK-NEXT:    mov za0h.s[w12, 0], p0/m, z0.s
-; CHECK-NEXT:    mov za0h.s[w13, 0], p0/m, z0.s
+; CHECK-NEXT:    mov za0h.s[w12, 1], p0/m, z0.s
 ; CHECK-NEXT:    ret
   %idx.trunc = trunc i64 %idx to i32
   call void @llvm.aarch64.sme.write.horiz.nxv4i32(i32 0, i32 %idx.trunc, <vscale x 4 x i1> %pg, <vscale x 4 x i32> zeroinitializer)

>From 8b58ab9b99f45267b48e5ccc4300a088f0020134 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Fri, 1 Nov 2024 16:32:08 +0000
Subject: [PATCH 3/3] Pass flags directly into getNode()

---
 llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 941136f34c154e..59a82a57292e23 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -605,15 +605,15 @@ bool TargetLowering::ShrinkDemandedOp(SDValue Op, unsigned BitWidth,
     EVT SmallVT = EVT::getIntegerVT(*DAG.getContext(), SmallVTBits);
     if (isTruncateFree(VT, SmallVT) && isZExtFree(SmallVT, VT)) {
       // We found a type with free casts.
+
+      // If the operation has the 'disjoint' flag, then the
+      // operands on the new node are also disjoint.
+      SDNodeFlags Flags(Op->getFlags().hasDisjoint() ? SDNodeFlags::Disjoint
+                                                     : SDNodeFlags::None);
       SDValue X = DAG.getNode(
           Op.getOpcode(), dl, SmallVT,
           DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(0)),
-          DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(1)));
-      // If the operation has the 'disjoint' flag, then the operands on
-      // the new node are also disjoint.
-      SDNodeFlags Flags = Op->getFlags();
-      X->setFlags(Flags.hasDisjoint() ? SDNodeFlags::Disjoint
-                                      : SDNodeFlags::None);
+          DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(1)), Flags);
       assert(DemandedSize <= SmallVTBits && "Narrowed below demanded bits?");
       SDValue Z = DAG.getNode(ISD::ANY_EXTEND, dl, VT, X);
       return TLO.CombineTo(Op, Z);



More information about the llvm-commits mailing list