[llvm] 3098200 - [ISel] Propagate disjoint flag in ShrinkDemandedOp (#114560)

via llvm-commits llvm-commits at lists.llvm.org
Sun Nov 3 11:42:08 PST 2024


Author: Sander de Smalen
Date: 2024-11-03T19:42:04Z
New Revision: 3098200fccabc781c68c0119ce33c89b500f6272

URL: https://github.com/llvm/llvm-project/commit/3098200fccabc781c68c0119ce33c89b500f6272
DIFF: https://github.com/llvm/llvm-project/commit/3098200fccabc781c68c0119ce33c89b500f6272.diff

LOG: [ISel] Propagate disjoint flag in ShrinkDemandedOp (#114560)

When trying to evaluate an expression in a narrower type, the DAG
combiner should propagate the disjoint flag, as it remains valid on
the narrower expression: if the wide operands have no set bits in
common, neither do their truncations.

For example, this enables better use of the addressing modes of some
Arm SME instructions.

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
    llvm/test/CodeGen/AArch64/sme-intrinsics-mova-insert.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index b2d74603e653d8..f21233abfa4f5d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -604,10 +604,15 @@ bool TargetLowering::ShrinkDemandedOp(SDValue Op, unsigned BitWidth,
     EVT SmallVT = EVT::getIntegerVT(*DAG.getContext(), SmallVTBits);
     if (isTruncateFree(VT, SmallVT) && isZExtFree(SmallVT, VT)) {
       // We found a type with free casts.
+
+      // If the operation has the 'disjoint' flag, then the
+      // operands on the new node are also disjoint.
+      SDNodeFlags Flags(Op->getFlags().hasDisjoint() ? SDNodeFlags::Disjoint
+                                                     : SDNodeFlags::None);
       SDValue X = DAG.getNode(
           Op.getOpcode(), dl, SmallVT,
           DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(0)),
-          DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(1)));
+          DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(1)), Flags);
       assert(DemandedSize <= SmallVTBits && "Narrowed below demanded bits?");
       SDValue Z = DAG.getNode(ISD::ANY_EXTEND, dl, VT, X);
       return TLO.CombineTo(Op, Z);

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 2120443b6ba2c5..511ab4fe7e9a39 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -7420,7 +7420,7 @@ bool AArch64DAGToDAGISel::SelectSMETileSlice(SDValue N, unsigned MaxSize,
                                              SDValue &Base, SDValue &Offset,
                                              unsigned Scale) {
   // Try to untangle an ADD node into a 'reg + offset'
-  if (N.getOpcode() == ISD::ADD)
+  if (CurDAG->isBaseWithConstantOffset(N))
     if (auto C = dyn_cast<ConstantSDNode>(N.getOperand(1))) {
       int64_t ImmOff = C->getSExtValue();
       if ((ImmOff > 0 && ImmOff <= MaxSize && (ImmOff % Scale == 0))) {

diff  --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-mova-insert.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-mova-insert.ll
index 8711a0388e34c5..9c5becf6ffbf29 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-mova-insert.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-mova-insert.ll
@@ -470,6 +470,22 @@ exit:
   ret void
 }
 
+define void @test_add_with_disjoint_or(i64 %idx, <vscale x 4 x i1> %pg) {
+; CHECK-LABEL: test_add_with_disjoint_or:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.s, #0 // =0x0
+; CHECK-NEXT:    mov x12, x0
+; CHECK-NEXT:    mov za0h.s[w12, 0], p0/m, z0.s
+; CHECK-NEXT:    mov za0h.s[w12, 1], p0/m, z0.s
+; CHECK-NEXT:    ret
+  %idx.trunc = trunc i64 %idx to i32
+  call void @llvm.aarch64.sme.write.horiz.nxv4i32(i32 0, i32 %idx.trunc, <vscale x 4 x i1> %pg, <vscale x 4 x i32> zeroinitializer)
+  %idx2 = or disjoint i64 %idx, 1
+  %idx2.trunc = trunc i64 %idx2 to i32
+  call void @llvm.aarch64.sme.write.horiz.nxv4i32(i32 0, i32 %idx2.trunc, <vscale x 4 x i1> %pg, <vscale x 4 x i32> zeroinitializer)
+  ret void
+}
+
 declare void @llvm.aarch64.sme.write.horiz.nxv16i8(i32, i32, <vscale x 16 x i1>, <vscale x 16 x i8>)
 declare void @llvm.aarch64.sme.write.horiz.nxv8i16(i32, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)
 declare void @llvm.aarch64.sme.write.horiz.nxv8f16(i32, i32, <vscale x 8 x i1>, <vscale x 8 x half>)


        


More information about the llvm-commits mailing list