[llvm] f74aed7 - [DAGCombiner] Add basic support for `trunc nsw/nuw` (#113808)

via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 6 08:23:58 PST 2024


Author: Yingwei Zheng
Date: 2024-11-07T00:23:53+08:00
New Revision: f74aed793819bf9e0509e802f33c5e29c350540c

URL: https://github.com/llvm/llvm-project/commit/f74aed793819bf9e0509e802f33c5e29c350540c
DIFF: https://github.com/llvm/llvm-project/commit/f74aed793819bf9e0509e802f33c5e29c350540c.diff

LOG: [DAGCombiner] Add basic support for `trunc nsw/nuw` (#113808)

This patch adds basic support for `trunc nsw/nuw` in SelectionDAG (SDAG). It
allows DAGCombiner to eliminate more redundant in-register `zext`/`sext` operations.

Added: 
    llvm/test/CodeGen/AArch64/trunc-nsw-nuw.ll
    llvm/test/CodeGen/RISCV/trunc-nsw-nuw.ll
    llvm/test/CodeGen/X86/trunc-nsw-nuw.ll

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index dcd5ca3b936e72..0b5cd2c33bf35a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2330,6 +2330,8 @@ static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
   if (N->getOpcode() == ISD::TRUNCATE) {
     Op = N->getOperand(0);
     Known = DAG.computeKnownBits(Op);
+    if (N->getFlags().hasNoUnsignedWrap())
+      Known.Zero.setBitsFrom(N.getScalarValueSizeInBits());
     return true;
   }
 
@@ -13889,23 +13891,27 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
     unsigned OpBits   = Op.getScalarValueSizeInBits();
     unsigned MidBits  = N0.getScalarValueSizeInBits();
     unsigned DestBits = VT.getScalarSizeInBits();
-    unsigned NumSignBits = DAG.ComputeNumSignBits(Op);
 
-    if (OpBits == DestBits) {
-      // Op is i32, Mid is i8, and Dest is i32.  If Op has more than 24 sign
-      // bits, it is already ready.
-      if (NumSignBits > DestBits-MidBits)
+    if (N0->getFlags().hasNoSignedWrap() ||
+        DAG.ComputeNumSignBits(Op) > OpBits - MidBits) {
+      if (OpBits == DestBits) {
+        // Op is i32, Mid is i8, and Dest is i32.  If Op has more than 24 sign
+        // bits, it is already ready.
         return Op;
-    } else if (OpBits < DestBits) {
-      // Op is i32, Mid is i8, and Dest is i64.  If Op has more than 24 sign
-      // bits, just sext from i32.
-      if (NumSignBits > OpBits-MidBits)
+      }
+
+      if (OpBits < DestBits) {
+        // Op is i32, Mid is i8, and Dest is i64.  If Op has more than 24 sign
+        // bits, just sext from i32.
         return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
-    } else {
+      }
+
       // Op is i64, Mid is i8, and Dest is i32.  If Op has more than 56 sign
       // bits, just truncate to i32.
-      if (NumSignBits > OpBits-MidBits)
-        return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
+      SDNodeFlags Flags;
+      Flags.setNoSignedWrap(true);
+      Flags.setNoUnsignedWrap(N0->getFlags().hasNoUnsignedWrap());
+      return DAG.getNode(ISD::TRUNCATE, DL, VT, Op, Flags);
     }
 
     // fold (sext (truncate x)) -> (sextinreg x).
@@ -14176,24 +14182,28 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
       unsigned OpBits = SrcVT.getScalarSizeInBits();
       unsigned MidBits = MinVT.getScalarSizeInBits();
       unsigned DestBits = VT.getScalarSizeInBits();
-      unsigned NumSignBits = DAG.ComputeNumSignBits(Op);
 
-      if (OpBits == DestBits) {
-        // Op is i32, Mid is i8, and Dest is i32.  If Op has more than 24 sign
-        // bits, it is already ready.
-        if (NumSignBits > DestBits - MidBits)
+      if (N0->getFlags().hasNoSignedWrap() ||
+          DAG.ComputeNumSignBits(Op) > OpBits - MidBits) {
+        if (OpBits == DestBits) {
+          // Op is i32, Mid is i8, and Dest is i32.  If Op has more than 24 sign
+          // bits, it is already ready.
           return Op;
-      } else if (OpBits < DestBits) {
-        // Op is i32, Mid is i8, and Dest is i64.  If Op has more than 24 sign
-        // bits, just sext from i32.
-        // FIXME: This can probably be ZERO_EXTEND nneg?
-        if (NumSignBits > OpBits - MidBits)
+        }
+
+        if (OpBits < DestBits) {
+          // Op is i32, Mid is i8, and Dest is i64.  If Op has more than 24 sign
+          // bits, just sext from i32.
+          // FIXME: This can probably be ZERO_EXTEND nneg?
           return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
-      } else {
+        }
+
         // Op is i64, Mid is i8, and Dest is i32.  If Op has more than 56 sign
         // bits, just truncate to i32.
-        if (NumSignBits > OpBits - MidBits)
-          return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
+        SDNodeFlags Flags;
+        Flags.setNoSignedWrap(true);
+        Flags.setNoUnsignedWrap(true);
+        return DAG.getNode(ISD::TRUNCATE, DL, VT, Op, Flags);
       }
     }
 

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 0b889514ad60d8..3b046aa25f5444 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -3826,7 +3826,13 @@ void SelectionDAGBuilder::visitTrunc(const User &I) {
   SDValue N = getValue(I.getOperand(0));
   EVT DestVT = DAG.getTargetLoweringInfo().getValueType(DAG.getDataLayout(),
                                                         I.getType());
-  setValue(&I, DAG.getNode(ISD::TRUNCATE, getCurSDLoc(), DestVT, N));
+  SDNodeFlags Flags;
+  if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
+    Flags.setNoSignedWrap(Trunc->hasNoSignedWrap());
+    Flags.setNoUnsignedWrap(Trunc->hasNoUnsignedWrap());
+  }
+
+  setValue(&I, DAG.getNode(ISD::TRUNCATE, getCurSDLoc(), DestVT, N, Flags));
 }
 
 void SelectionDAGBuilder::visitZExt(const User &I) {

diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 0360c1bd76f007..8287565336b54d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2604,8 +2604,12 @@ bool TargetLowering::SimplifyDemandedBits(
     unsigned OperandBitWidth = Src.getScalarValueSizeInBits();
     APInt TruncMask = DemandedBits.zext(OperandBitWidth);
     if (SimplifyDemandedBits(Src, TruncMask, DemandedElts, Known, TLO,
-                             Depth + 1))
+                             Depth + 1)) {
+      // Disable the nsw and nuw flags. We can no longer guarantee that we
+      // won't wrap after simplification.
+      Op->dropFlags(SDNodeFlags::NoWrap);
       return true;
+    }
     Known = Known.trunc(BitWidth);
 
     // Attempt to avoid multi-use ops if we don't need anything from them.

diff  --git a/llvm/test/CodeGen/AArch64/trunc-nsw-nuw.ll b/llvm/test/CodeGen/AArch64/trunc-nsw-nuw.ll
new file mode 100644
index 00000000000000..6041db74639f32
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/trunc-nsw-nuw.ll
@@ -0,0 +1,60 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=aarch64-- | FileCheck %s
+
+define zeroext i32 @trunc_nuw_nsw_urem(i64 %x) nounwind {
+; CHECK-LABEL: trunc_nuw_nsw_urem:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov w8, #5977 // =0x1759
+; CHECK-NEXT:    mov w9, #10000 // =0x2710
+; CHECK-NEXT:    movk w8, #53687, lsl #16
+; CHECK-NEXT:    mul x8, x0, x8
+; CHECK-NEXT:    lsr x8, x8, #45
+; CHECK-NEXT:    msub w0, w8, w9, w0
+; CHECK-NEXT:    ret
+entry:
+  %trunc = trunc nuw nsw i64 %x to i32
+  %rem = urem i32 %trunc, 10000
+  ret i32 %rem
+}
+
+define i64 @zext_nneg_udiv_trunc_nuw(i64 %x) nounwind {
+; CHECK-LABEL: zext_nneg_udiv_trunc_nuw:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov w8, #52429 // =0xcccd
+; CHECK-NEXT:    mul w8, w0, w8
+; CHECK-NEXT:    lsr w0, w8, #23
+; CHECK-NEXT:    ret
+entry:
+  %trunc = trunc nuw i64 %x to i16
+  %div = udiv i16 %trunc, 160
+  %ext = zext nneg i16 %div to i64
+  ret i64 %ext
+}
+
+define i64 @sext_udiv_trunc_nuw(i64 %x) nounwind {
+; CHECK-LABEL: sext_udiv_trunc_nuw:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov w8, #52429 // =0xcccd
+; CHECK-NEXT:    mul w8, w0, w8
+; CHECK-NEXT:    lsr w0, w8, #23
+; CHECK-NEXT:    ret
+entry:
+  %trunc = trunc nuw i64 %x to i16
+  %div = udiv i16 %trunc, 160
+  %ext = sext i16 %div to i64
+  ret i64 %ext
+}
+
+define ptr @gep_nusw_zext_nneg_add_trunc_nuw_nsw(ptr %p, i64 %x) nounwind {
+; CHECK-LABEL: gep_nusw_zext_nneg_add_trunc_nuw_nsw:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    add w8, w1, #5
+; CHECK-NEXT:    add x0, x0, w8, uxtw #2
+; CHECK-NEXT:    ret
+entry:
+  %trunc = trunc nuw nsw i64 %x to i32
+  %add = add nuw nsw i32 %trunc, 5
+  %offset = zext nneg i32 %add to i64
+  %gep = getelementptr nusw float, ptr %p, i64 %offset
+  ret ptr %gep
+}

diff  --git a/llvm/test/CodeGen/RISCV/trunc-nsw-nuw.ll b/llvm/test/CodeGen/RISCV/trunc-nsw-nuw.ll
new file mode 100644
index 00000000000000..f270775adcc155
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/trunc-nsw-nuw.ll
@@ -0,0 +1,78 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=riscv64 -mattr=+m | FileCheck %s
+
+define signext i8 @trunc_nsw_add(i32 signext %x) nounwind {
+; CHECK-LABEL: trunc_nsw_add:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    addiw a0, a0, 1
+; CHECK-NEXT:    ret
+entry:
+  %add = add nsw i32 %x, 1
+  %trunc = trunc nsw i32 %add to i8
+  ret i8 %trunc
+}
+
+define signext i32 @trunc_nuw_nsw_urem(i64 %x) nounwind {
+; CHECK-LABEL: trunc_nuw_nsw_urem:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    lui a1, 210
+; CHECK-NEXT:    addiw a1, a1, -1167
+; CHECK-NEXT:    slli a1, a1, 12
+; CHECK-NEXT:    addi a1, a1, 1881
+; CHECK-NEXT:    mul a1, a0, a1
+; CHECK-NEXT:    srli a1, a1, 45
+; CHECK-NEXT:    lui a2, 2
+; CHECK-NEXT:    addi a2, a2, 1808
+; CHECK-NEXT:    mul a1, a1, a2
+; CHECK-NEXT:    subw a0, a0, a1
+; CHECK-NEXT:    ret
+entry:
+  %trunc = trunc nuw nsw i64 %x to i32
+  %rem = urem i32 %trunc, 10000
+  ret i32 %rem
+}
+
+define i64 @zext_nneg_udiv_trunc_nuw(i64 %x) nounwind {
+; CHECK-LABEL: zext_nneg_udiv_trunc_nuw:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    lui a1, 13
+; CHECK-NEXT:    addi a1, a1, -819
+; CHECK-NEXT:    mul a0, a0, a1
+; CHECK-NEXT:    srliw a0, a0, 23
+; CHECK-NEXT:    ret
+entry:
+  %trunc = trunc nuw i64 %x to i16
+  %div = udiv i16 %trunc, 160
+  %ext = zext nneg i16 %div to i64
+  ret i64 %ext
+}
+
+define i64 @sext_udiv_trunc_nuw(i64 %x) nounwind {
+; CHECK-LABEL: sext_udiv_trunc_nuw:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    lui a1, 13
+; CHECK-NEXT:    addi a1, a1, -819
+; CHECK-NEXT:    mul a0, a0, a1
+; CHECK-NEXT:    srliw a0, a0, 23
+; CHECK-NEXT:    ret
+entry:
+  %trunc = trunc nuw i64 %x to i16
+  %div = udiv i16 %trunc, 160
+  %ext = sext i16 %div to i64
+  ret i64 %ext
+}
+
+define ptr @gep_nusw_zext_nneg_add_trunc_nuw_nsw(ptr %p, i64 %x) nounwind {
+; CHECK-LABEL: gep_nusw_zext_nneg_add_trunc_nuw_nsw:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    slli a1, a1, 2
+; CHECK-NEXT:    add a0, a1, a0
+; CHECK-NEXT:    addi a0, a0, 20
+; CHECK-NEXT:    ret
+entry:
+  %trunc = trunc nuw nsw i64 %x to i32
+  %add = add nuw nsw i32 %trunc, 5
+  %offset = zext nneg i32 %add to i64
+  %gep = getelementptr nusw float, ptr %p, i64 %offset
+  ret ptr %gep
+}

diff  --git a/llvm/test/CodeGen/X86/trunc-nsw-nuw.ll b/llvm/test/CodeGen/X86/trunc-nsw-nuw.ll
new file mode 100644
index 00000000000000..5c5f7045ea0306
--- /dev/null
+++ b/llvm/test/CodeGen/X86/trunc-nsw-nuw.ll
@@ -0,0 +1,83 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=x86_64 | FileCheck %s
+
+define zeroext i32 @trunc_nuw_nsw_urem(i64 %x) nounwind {
+; CHECK-LABEL: trunc_nuw_nsw_urem:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    movq %rdi, %rax
+; CHECK-NEXT:    movl $3518437209, %ecx # imm = 0xD1B71759
+; CHECK-NEXT:    imulq %rdi, %rcx
+; CHECK-NEXT:    shrq $45, %rcx
+; CHECK-NEXT:    imull $10000, %ecx, %ecx # imm = 0x2710
+; CHECK-NEXT:    subl %ecx, %eax
+; CHECK-NEXT:    # kill: def $eax killed $eax killed $rax
+; CHECK-NEXT:    retq
+entry:
+  %trunc = trunc nuw nsw i64 %x to i32
+  %rem = urem i32 %trunc, 10000
+  ret i32 %rem
+}
+
+define i64 @zext_nneg_udiv_trunc_nuw(i64 %x) nounwind {
+; CHECK-LABEL: zext_nneg_udiv_trunc_nuw:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    imull $52429, %edi, %eax # imm = 0xCCCD
+; CHECK-NEXT:    shrl $23, %eax
+; CHECK-NEXT:    retq
+entry:
+  %trunc = trunc nuw i64 %x to i16
+  %div = udiv i16 %trunc, 160
+  %ext = zext nneg i16 %div to i64
+  ret i64 %ext
+}
+
+define i64 @sext_udiv_trunc_nuw(i64 %x) nounwind {
+; CHECK-LABEL: sext_udiv_trunc_nuw:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    imull $52429, %edi, %eax # imm = 0xCCCD
+; CHECK-NEXT:    shrl $23, %eax
+; CHECK-NEXT:    retq
+entry:
+  %trunc = trunc nuw i64 %x to i16
+  %div = udiv i16 %trunc, 160
+  %ext = sext i16 %div to i64
+  ret i64 %ext
+}
+
+define ptr @gep_nusw_zext_nneg_add_trunc_nuw_nsw(ptr %p, i64 %x) nounwind {
+; CHECK-LABEL: gep_nusw_zext_nneg_add_trunc_nuw_nsw:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    leaq 20(%rdi,%rsi,4), %rax
+; CHECK-NEXT:    retq
+entry:
+  %trunc = trunc nuw nsw i64 %x to i32
+  %add = add nuw nsw i32 %trunc, 5
+  %offset = zext nneg i32 %add to i64
+  %gep = getelementptr nusw float, ptr %p, i64 %offset
+  ret ptr %gep
+}
+
+; Make sure nsw flag is dropped after we simplify the operand of TRUNCATE.
+
+define i32 @simplify_demanded_bits_drop_flag(i1 zeroext %x, i1 zeroext %y) nounwind {
+; CHECK-LABEL: simplify_demanded_bits_drop_flag:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    negl %edi
+; CHECK-NEXT:    shll $2, %esi
+; CHECK-NEXT:    xorl %edi, %esi
+; CHECK-NEXT:    movslq %esi, %rax
+; CHECK-NEXT:    imulq $-1634202141, %rax, %rax # imm = 0x9E980DE3
+; CHECK-NEXT:    movq %rax, %rcx
+; CHECK-NEXT:    shrq $63, %rcx
+; CHECK-NEXT:    sarq $44, %rax
+; CHECK-NEXT:    addl %ecx, %eax
+; CHECK-NEXT:    # kill: def $eax killed $eax killed $rax
+; CHECK-NEXT:    retq
+entry:
+  %sel = select i1 %y, i64 4, i64 0
+  %conv0 = sext i1 %x to i64
+  %xor = xor i64 %sel, %conv0
+  %conv1 = trunc nsw i64 %xor to i32
+  %div = sdiv i32 %conv1, -10765
+  ret i32 %div
+}


        


More information about the llvm-commits mailing list