[llvm] [LegalizeDAG] Optimize CodeGen for `ISD::CTLZ_ZERO_UNDEF` (PR #83039)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 26 09:33:01 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-x86

Author: Manish Kausik H (Nirhar)

<details>
<summary>Changes</summary>

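Previously, the same instructions were generated for `ISD::CTLZ` and `ISD::CTLZ_ZERO_UNDEF`, which failed to take advantage of the fact that the result of `ISD::CTLZ_ZERO_UNDEF` is undefined when the input is zero. This commit separates the codegen for the two cases so that the latter can be optimized: when the operand is promoted to a wider type, it is shifted left by the difference in bit widths before the leading zeros are counted, instead of counting on the zero-extended value and then subtracting the difference in bit widths.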

The details of the optimization are outlined in #<!-- -->82075

Fixes #<!-- -->82075

---
Full diff: https://github.com/llvm/llvm-project/pull/83039.diff


5 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+16-1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+35-14) 
- (added) llvm/test/CodeGen/AArch64/ctlz_zero_undef.ll (+18) 
- (modified) llvm/test/CodeGen/X86/clz.ll (+17-27) 
- (modified) llvm/test/CodeGen/X86/lzcnt.ll (+7-9) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 6272c3093cff6c..811eb4c3fddf3b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -4973,7 +4973,6 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
   case ISD::CTTZ:
   case ISD::CTTZ_ZERO_UNDEF:
   case ISD::CTLZ:
-  case ISD::CTLZ_ZERO_UNDEF:
   case ISD::CTPOP:
     // Zero extend the argument unless its cttz, then use any_extend.
     if (Node->getOpcode() == ISD::CTTZ ||
@@ -5003,6 +5002,22 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
     }
     Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp1));
     break;
+  case ISD::CTLZ_ZERO_UNDEF:
+    // We know that the argument is unlikely to be zero, hence we can take a
+    // different approach as compared to ISD::CTLZ
+
+    // Zero Extend the argument
+    Tmp1 = DAG.getNode(ISD::ZERO_EXTEND, dl, NVT, Node->getOperand(0));
+
+    // Tmp1 = Tmp1 << (sizeinbits(NVT) - sizeinbits(Old VT))
+    Tmp1 = DAG.getNode(
+        ISD::SHL, dl, NVT, Tmp1,
+        DAG.getConstant(NVT.getSizeInBits() - OVT.getSizeInBits(), dl, NVT));
+
+    // Perform the larger operation
+    Tmp1 = DAG.getNode(Node->getOpcode(), dl, NVT, Tmp1);
+    Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp1));
+    break;
   case ISD::BITREVERSE:
   case ISD::BSWAP: {
     unsigned DiffBits = NVT.getSizeInBits() - OVT.getSizeInBits();
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index a4ba261686c688..7848821efb51df 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -613,21 +613,42 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTLZ(SDNode *N) {
     }
   }
 
-  // Zero extend to the promoted type and do the count there.
-  SDValue Op = ZExtPromotedInteger(N->getOperand(0));
+  unsigned CtlzOpcode = N->getOpcode();
+  if (CtlzOpcode == ISD::CTLZ) {
+    // Zero extend to the promoted type and do the count there.
+    SDValue Op = ZExtPromotedInteger(N->getOperand(0));
+
+    // Subtract off the extra leading bits in the bigger type.
+    SDValue ExtractLeadingBits = DAG.getConstant(
+        NVT.getScalarSizeInBits() - OVT.getScalarSizeInBits(), dl, NVT);
+    if (!N->isVPOpcode())
+      return DAG.getNode(ISD::SUB, dl, NVT,
+                         DAG.getNode(N->getOpcode(), dl, NVT, Op),
+                         ExtractLeadingBits);
+    SDValue Mask = N->getOperand(1);
+    SDValue EVL = N->getOperand(2);
+    return DAG.getNode(ISD::VP_SUB, dl, NVT,
+                       DAG.getNode(N->getOpcode(), dl, NVT, Op, Mask, EVL),
+                       ExtractLeadingBits, Mask, EVL);
+  } else if (CtlzOpcode == ISD::CTLZ_ZERO_UNDEF) {
+    // Zero Extend the argument
+    SDValue Op = ZExtPromotedInteger(N->getOperand(0));
+
+    // Op = Op << (sizeinbits(NVT) - sizeinbits(Old VT))
+    auto ShiftConst =
+        DAG.getConstant(NVT.getSizeInBits() - OVT.getSizeInBits(), dl, NVT);
+    if (!N->isVPOpcode()) {
+      Op = DAG.getNode(ISD::SHL, dl, NVT, Op, ShiftConst);
+      return DAG.getNode(CtlzOpcode, dl, NVT, Op);
+    }
 
-  // Subtract off the extra leading bits in the bigger type.
-  SDValue ExtractLeadingBits = DAG.getConstant(
-      NVT.getScalarSizeInBits() - OVT.getScalarSizeInBits(), dl, NVT);
-  if (!N->isVPOpcode())
-    return DAG.getNode(ISD::SUB, dl, NVT,
-                       DAG.getNode(N->getOpcode(), dl, NVT, Op),
-                       ExtractLeadingBits);
-  SDValue Mask = N->getOperand(1);
-  SDValue EVL = N->getOperand(2);
-  return DAG.getNode(ISD::VP_SUB, dl, NVT,
-                     DAG.getNode(N->getOpcode(), dl, NVT, Op, Mask, EVL),
-                     ExtractLeadingBits, Mask, EVL);
+    SDValue Mask = N->getOperand(1);
+    SDValue EVL = N->getOperand(2);
+    Op = DAG.getNode(ISD::SHL, dl, NVT, Op, ShiftConst, Mask, EVL);
+    return DAG.getNode(CtlzOpcode, dl, NVT, Op, Mask, EVL);
+  } else {
+    llvm_unreachable("Invalid CTLZ Opcode");
+  }
 }
 
 SDValue DAGTypeLegalizer::PromoteIntRes_CTPOP_PARITY(SDNode *N) {
diff --git a/llvm/test/CodeGen/AArch64/ctlz_zero_undef.ll b/llvm/test/CodeGen/AArch64/ctlz_zero_undef.ll
new file mode 100644
index 00000000000000..f433e3457c3817
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/ctlz_zero_undef.ll
@@ -0,0 +1,18 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc < %s --mtriple=aarch64 | FileCheck %s
+
+declare i8 @llvm.ctlz.i8(i8, i1 immarg)
+
+define i32 @clz_nzu8(ptr %num) {
+; CHECK-LABEL: clz_nzu8:
+; CHECK:       // %bb.0: // %start
+; CHECK-NEXT:    ldrb w8, [x0]
+; CHECK-NEXT:    lsl w8, w8, #24
+; CHECK-NEXT:    clz w0, w8
+; CHECK-NEXT:    ret
+start:
+  %self = load i8, ptr %num, align 1
+  %0 = call i8 @llvm.ctlz.i8(i8 %self, i1 true)
+  %_0 = zext i8 %0 to i32
+  ret i32 %_0
+}
\ No newline at end of file
diff --git a/llvm/test/CodeGen/X86/clz.ll b/llvm/test/CodeGen/X86/clz.ll
index 92cbc165902473..39403ad9947300 100644
--- a/llvm/test/CodeGen/X86/clz.ll
+++ b/llvm/test/CodeGen/X86/clz.ll
@@ -1,4 +1,4 @@
-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
 ; RUN: llc < %s -mtriple=i686-unknown-unknown | FileCheck %s --check-prefixes=X86,X86-NOCMOV
 ; RUN: llc < %s -mtriple=i686-unknown-unknown -mattr=+cmov | FileCheck %s --check-prefixes=X86,X86-CMOV
 ; RUN: llc < %s -mtriple=x86_64-unknown-unknown | FileCheck %s --check-prefix=X64
@@ -225,33 +225,31 @@ define i8 @ctlz_i8(i8 %x) {
 ;
 ; X86-CLZ-LABEL: ctlz_i8:
 ; X86-CLZ:       # %bb.0:
-; X86-CLZ-NEXT:    movzbl {{[0-9]+}}(%esp), %eax
+; X86-CLZ-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-CLZ-NEXT:    shll $24, %eax
 ; X86-CLZ-NEXT:    lzcntl %eax, %eax
-; X86-CLZ-NEXT:    addl $-24, %eax
 ; X86-CLZ-NEXT:    # kill: def $al killed $al killed $eax
 ; X86-CLZ-NEXT:    retl
 ;
 ; X64-CLZ-LABEL: ctlz_i8:
 ; X64-CLZ:       # %bb.0:
-; X64-CLZ-NEXT:    movzbl %dil, %eax
-; X64-CLZ-NEXT:    lzcntl %eax, %eax
-; X64-CLZ-NEXT:    addl $-24, %eax
+; X64-CLZ-NEXT:    shll $24, %edi
+; X64-CLZ-NEXT:    lzcntl %edi, %eax
 ; X64-CLZ-NEXT:    # kill: def $al killed $al killed $eax
 ; X64-CLZ-NEXT:    retq
 ;
 ; X64-FASTLZCNT-LABEL: ctlz_i8:
 ; X64-FASTLZCNT:       # %bb.0:
-; X64-FASTLZCNT-NEXT:    movzbl %dil, %eax
-; X64-FASTLZCNT-NEXT:    lzcntl %eax, %eax
-; X64-FASTLZCNT-NEXT:    addl $-24, %eax
+; X64-FASTLZCNT-NEXT:    shll $24, %edi
+; X64-FASTLZCNT-NEXT:    lzcntl %edi, %eax
 ; X64-FASTLZCNT-NEXT:    # kill: def $al killed $al killed $eax
 ; X64-FASTLZCNT-NEXT:    retq
 ;
 ; X86-FASTLZCNT-LABEL: ctlz_i8:
 ; X86-FASTLZCNT:       # %bb.0:
-; X86-FASTLZCNT-NEXT:    movzbl {{[0-9]+}}(%esp), %eax
+; X86-FASTLZCNT-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-FASTLZCNT-NEXT:    shll $24, %eax
 ; X86-FASTLZCNT-NEXT:    lzcntl %eax, %eax
-; X86-FASTLZCNT-NEXT:    addl $-24, %eax
 ; X86-FASTLZCNT-NEXT:    # kill: def $al killed $al killed $eax
 ; X86-FASTLZCNT-NEXT:    retl
   %tmp2 = call i8 @llvm.ctlz.i8( i8 %x, i1 true )
@@ -831,7 +829,6 @@ define i64 @cttz_i64_zero_test(i64 %n) {
 ; X86-NOCMOV-LABEL: cttz_i64_zero_test:
 ; X86-NOCMOV:       # %bb.0:
 ; X86-NOCMOV-NEXT:    movl {{[0-9]+}}(%esp), %ecx
-; X86-NOCMOV-NOT:     rep
 ; X86-NOCMOV-NEXT:    bsfl {{[0-9]+}}(%esp), %edx
 ; X86-NOCMOV-NEXT:    movl $32, %eax
 ; X86-NOCMOV-NEXT:    je .LBB15_2
@@ -852,12 +849,10 @@ define i64 @cttz_i64_zero_test(i64 %n) {
 ; X86-CMOV-LABEL: cttz_i64_zero_test:
 ; X86-CMOV:       # %bb.0:
 ; X86-CMOV-NEXT:    movl {{[0-9]+}}(%esp), %eax
-; X86-CMOV-NOT:     rep
 ; X86-CMOV-NEXT:    bsfl {{[0-9]+}}(%esp), %ecx
 ; X86-CMOV-NEXT:    movl $32, %edx
 ; X86-CMOV-NEXT:    cmovnel %ecx, %edx
 ; X86-CMOV-NEXT:    addl $32, %edx
-; X86-CMOV-NOT:     rep
 ; X86-CMOV-NEXT:    bsfl %eax, %eax
 ; X86-CMOV-NEXT:    cmovel %edx, %eax
 ; X86-CMOV-NEXT:    xorl %edx, %edx
@@ -1154,8 +1149,8 @@ define i8 @ctlz_i8_knownbits(i8 %x)  {
 ; X86-CLZ-NEXT:    movzbl {{[0-9]+}}(%esp), %eax
 ; X86-CLZ-NEXT:    orb $64, %al
 ; X86-CLZ-NEXT:    movzbl %al, %eax
+; X86-CLZ-NEXT:    shll $24, %eax
 ; X86-CLZ-NEXT:    lzcntl %eax, %eax
-; X86-CLZ-NEXT:    addl $-24, %eax
 ; X86-CLZ-NEXT:    # kill: def $al killed $al killed $eax
 ; X86-CLZ-NEXT:    retl
 ;
@@ -1163,8 +1158,8 @@ define i8 @ctlz_i8_knownbits(i8 %x)  {
 ; X64-CLZ:       # %bb.0:
 ; X64-CLZ-NEXT:    orb $64, %dil
 ; X64-CLZ-NEXT:    movzbl %dil, %eax
+; X64-CLZ-NEXT:    shll $24, %eax
 ; X64-CLZ-NEXT:    lzcntl %eax, %eax
-; X64-CLZ-NEXT:    addl $-24, %eax
 ; X64-CLZ-NEXT:    # kill: def $al killed $al killed $eax
 ; X64-CLZ-NEXT:    retq
 ;
@@ -1172,8 +1167,8 @@ define i8 @ctlz_i8_knownbits(i8 %x)  {
 ; X64-FASTLZCNT:       # %bb.0:
 ; X64-FASTLZCNT-NEXT:    orb $64, %dil
 ; X64-FASTLZCNT-NEXT:    movzbl %dil, %eax
+; X64-FASTLZCNT-NEXT:    shll $24, %eax
 ; X64-FASTLZCNT-NEXT:    lzcntl %eax, %eax
-; X64-FASTLZCNT-NEXT:    addl $-24, %eax
 ; X64-FASTLZCNT-NEXT:    # kill: def $al killed $al killed $eax
 ; X64-FASTLZCNT-NEXT:    retq
 ;
@@ -1182,8 +1177,8 @@ define i8 @ctlz_i8_knownbits(i8 %x)  {
 ; X86-FASTLZCNT-NEXT:    movzbl {{[0-9]+}}(%esp), %eax
 ; X86-FASTLZCNT-NEXT:    orb $64, %al
 ; X86-FASTLZCNT-NEXT:    movzbl %al, %eax
+; X86-FASTLZCNT-NEXT:    shll $24, %eax
 ; X86-FASTLZCNT-NEXT:    lzcntl %eax, %eax
-; X86-FASTLZCNT-NEXT:    addl $-24, %eax
 ; X86-FASTLZCNT-NEXT:    # kill: def $al killed $al killed $eax
 ; X86-FASTLZCNT-NEXT:    retl
 
@@ -1481,13 +1476,11 @@ define i32 @PR47603_zext(i32 %a0, ptr %a1) {
 define i32 @cttz_i32_osize(i32 %x) optsize {
 ; X86-LABEL: cttz_i32_osize:
 ; X86:       # %bb.0:
-; X86-NOT:     rep
 ; X86-NEXT:    bsfl {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    retl
 ;
 ; X64-LABEL: cttz_i32_osize:
 ; X64:       # %bb.0:
-; X64-NOT:     rep
 ; X64-NEXT:    bsfl %edi, %eax
 ; X64-NEXT:    retq
 ;
@@ -1517,13 +1510,11 @@ define i32 @cttz_i32_osize(i32 %x) optsize {
 define i32 @cttz_i32_msize(i32 %x) minsize {
 ; X86-LABEL: cttz_i32_msize:
 ; X86:       # %bb.0:
-; X86-NOT:     rep
 ; X86-NEXT:    bsfl {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    retl
 ;
 ; X64-LABEL: cttz_i32_msize:
 ; X64:       # %bb.0:
-; X64-NOT:     rep
 ; X64-NEXT:    bsfl %edi, %eax
 ; X64-NEXT:    retq
 ;
@@ -1581,18 +1572,17 @@ define i8 @ctlz_xor7_i8_true(i8 %x) {
 ;
 ; X64-FASTLZCNT-LABEL: ctlz_xor7_i8_true:
 ; X64-FASTLZCNT:       # %bb.0:
-; X64-FASTLZCNT-NEXT:    movzbl %dil, %eax
-; X64-FASTLZCNT-NEXT:    lzcntl %eax, %eax
-; X64-FASTLZCNT-NEXT:    addl $-24, %eax
+; X64-FASTLZCNT-NEXT:    shll $24, %edi
+; X64-FASTLZCNT-NEXT:    lzcntl %edi, %eax
 ; X64-FASTLZCNT-NEXT:    xorb $7, %al
 ; X64-FASTLZCNT-NEXT:    # kill: def $al killed $al killed $eax
 ; X64-FASTLZCNT-NEXT:    retq
 ;
 ; X86-FASTLZCNT-LABEL: ctlz_xor7_i8_true:
 ; X86-FASTLZCNT:       # %bb.0:
-; X86-FASTLZCNT-NEXT:    movzbl {{[0-9]+}}(%esp), %eax
+; X86-FASTLZCNT-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-FASTLZCNT-NEXT:    shll $24, %eax
 ; X86-FASTLZCNT-NEXT:    lzcntl %eax, %eax
-; X86-FASTLZCNT-NEXT:    addl $-24, %eax
 ; X86-FASTLZCNT-NEXT:    xorb $7, %al
 ; X86-FASTLZCNT-NEXT:    # kill: def $al killed $al killed $eax
 ; X86-FASTLZCNT-NEXT:    retl
diff --git a/llvm/test/CodeGen/X86/lzcnt.ll b/llvm/test/CodeGen/X86/lzcnt.ll
index 68cef3f9363f99..b0004019734168 100644
--- a/llvm/test/CodeGen/X86/lzcnt.ll
+++ b/llvm/test/CodeGen/X86/lzcnt.ll
@@ -1,4 +1,4 @@
-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
 ; RUN: llc < %s -mtriple=i686-- -mattr=+lzcnt | FileCheck %s --check-prefix=X86
 ; RUN: llc < %s -mtriple=x86_64-linux-gnux32  -mattr=+lzcnt | FileCheck %s --check-prefix=X32
 ; RUN: llc < %s -mtriple=x86_64-- -mattr=+lzcnt | FileCheck %s --check-prefix=X64
@@ -106,25 +106,23 @@ define i64 @t4(i64 %x) nounwind  {
 define i8 @t5(i8 %x) nounwind  {
 ; X86-LABEL: t5:
 ; X86:       # %bb.0:
-; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    shll $24, %eax
 ; X86-NEXT:    lzcntl %eax, %eax
-; X86-NEXT:    addl $-24, %eax
 ; X86-NEXT:    # kill: def $al killed $al killed $eax
 ; X86-NEXT:    retl
 ;
 ; X32-LABEL: t5:
 ; X32:       # %bb.0:
-; X32-NEXT:    movzbl %dil, %eax
-; X32-NEXT:    lzcntl %eax, %eax
-; X32-NEXT:    addl $-24, %eax
+; X32-NEXT:    shll $24, %edi
+; X32-NEXT:    lzcntl %edi, %eax
 ; X32-NEXT:    # kill: def $al killed $al killed $eax
 ; X32-NEXT:    retq
 ;
 ; X64-LABEL: t5:
 ; X64:       # %bb.0:
-; X64-NEXT:    movzbl %dil, %eax
-; X64-NEXT:    lzcntl %eax, %eax
-; X64-NEXT:    addl $-24, %eax
+; X64-NEXT:    shll $24, %edi
+; X64-NEXT:    lzcntl %edi, %eax
 ; X64-NEXT:    # kill: def $al killed $al killed $eax
 ; X64-NEXT:    retq
 	%tmp = tail call i8 @llvm.ctlz.i8( i8 %x, i1 true )

``````````

</details>


https://github.com/llvm/llvm-project/pull/83039


More information about the llvm-commits mailing list