[llvm] Optimize count leading ones if promoted type (PR #99591)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 18 21:29:47 PDT 2024


https://github.com/v01dXYZ updated https://github.com/llvm/llvm-project/pull/99591

>From 27075abea73c78ee7e78cd9bfe32e719e7a63d1b Mon Sep 17 00:00:00 2001
From: v01dxyz <v01dxyz at v01d.xyz>
Date: Tue, 9 Jul 2024 16:42:39 +0200
Subject: [PATCH 1/4] Optimisation for count leading ones when promotion is
 necessary

(CTLZ (XOR Op -1)) --> (CTLZ_ZERO_UNDEF (XOR (SHIFT Op ShiftAmount) -1))
---
 llvm/lib/CodeGen/CodeGenPrepare.cpp           |  20 +++-
 .../CodeGen/GlobalISel/LegalizerHelper.cpp    |  36 +++++++
 llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp |  45 ++++++++
 .../SelectionDAG/LegalizeIntegerTypes.cpp     |  38 +++++++
 llvm/test/CodeGen/AArch64/ctlo.ll             | 100 ++++++++++++++++++
 llvm/test/CodeGen/X86/ctlo.ll                 |  26 ++---
 6 files changed, 249 insertions(+), 16 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/ctlo.ll

diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 22d0708f54786..9a019b42a60ac 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -2298,8 +2298,26 @@ static bool despeculateCountZeros(IntrinsicInst *CountZeros,
   if (match(CountZeros->getOperand(1), m_One()))
     return false;
 
-  // If it's cheap to speculate, there's nothing to do.
   Type *Ty = CountZeros->getType();
+  EVT VTy = TLI->getValueType(*DL, Ty);
+
+  // do not despeculate if we have (ctlz (xor -1 op)) if the operand is
+  // promoted as legalisation would later transform to:
+  //
+  // (ctlz (lshift (xor -1 (extend op))
+  //               lshiftamount))
+  //
+  // Despeculation is not only useless but also not wanted with SelectionDAG
+  // as XOR and CTLZ would be in different basic blocks.
+  if (ConstantInt * C;
+      (TLI->getTypeConversion(CountZeros->getContext(), VTy).first ==
+           TargetLowering::TypePromoteInteger ||
+       TLI->getOperationAction(ISD::CTLZ, VTy) == TargetLowering::Promote) &&
+      match(CountZeros->getOperand(0), m_Xor(m_Value(), m_ConstantInt(C))) &&
+      C->isMinusOne())
+    return false;
+
+  // If it's cheap to speculate, there's nothing to do.
   auto IntrinsicID = CountZeros->getIntrinsicID();
   if ((IntrinsicID == Intrinsic::cttz && TLI->isCheapToSpeculateCttz(Ty)) ||
       (IntrinsicID == Intrinsic::ctlz && TLI->isCheapToSpeculateCtlz(Ty)))
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 3f1094e0ac703..a334b14e0cc8b 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -2356,6 +2356,35 @@ LegalizerHelper::widenScalarMulo(MachineInstr &MI, unsigned TypeIdx,
   return Legalized;
 }
 
+/// Widen ctlz(G_XOR X, -1) by pushing the extension inside the NOT:
+///   (G_CTLZ_ZERO_UNDEF (G_XOR (G_SHL (G_ANYEXT X), SizeDiff), -1))
+/// Since WideTy is wider, the shift clears the low SizeDiff bits; the XOR
+/// sets them, so the CTLZ operand is never zero and no count fixup is needed.
+/// Returns false, without building anything, if MI's source is not a NOT.
+static bool extendCtlzNot(const MachineInstr &MI, MachineIRBuilder &MIRBuilder,
+                          MachineRegisterInfo &MRI, LLT WideTy) {
+  Register SrcReg = MI.getOperand(1).getReg();
+  Register XorSrc;
+  Register CstReg;
+  if (!mi_match(SrcReg, MRI, m_GXor(m_Reg(XorSrc), m_Reg(CstReg))))
+    return false;
+
+  // The other XOR operand may not be a constant: check before dereferencing.
+  auto OptCst = getIConstantVRegValWithLookThrough(CstReg, MRI);
+  if (!OptCst || !OptCst->Value.isAllOnes())
+    return false;
+
+  unsigned WideBits = WideTy.getScalarSizeInBits();
+  unsigned SizeDiff = WideBits - MRI.getType(SrcReg).getScalarSizeInBits();
+  auto AllOnes = MIRBuilder.buildConstant(WideTy, APInt::getAllOnes(WideBits));
+  auto Ext = MIRBuilder.buildAnyExt(WideTy, XorSrc);
+  auto ShiftAmt = MIRBuilder.buildConstant(WideTy, SizeDiff);
+  auto Shl = MIRBuilder.buildShl(WideTy, Ext, ShiftAmt);
+  auto Inverted = MIRBuilder.buildXor(WideTy, Shl, AllOnes);
+  MIRBuilder.buildCTLZ_ZERO_UNDEF(MI.getOperand(0), Inverted);
+  return true;
+}
+
 LegalizerHelper::LegalizeResult
 LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
   switch (MI.getOpcode()) {
@@ -2449,6 +2478,13 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
     auto MIBSrc = MIRBuilder.buildInstr(ExtOpc, {WideTy}, {SrcReg});
     LLT CurTy = MRI.getType(SrcReg);
     unsigned NewOpc = MI.getOpcode();
+
+    if ((MI.getOpcode() == TargetOpcode::G_CTLZ ||
+         MI.getOpcode() == TargetOpcode::G_CTLZ_ZERO_UNDEF) &&
+        extendCtlzNot(MI, MIRBuilder, MRI, WideTy)) {
+      MI.eraseFromParent();
+      return Legalized;
+    }
     if (NewOpc == TargetOpcode::G_CTTZ) {
       // The count is the same in the larger type except if the original
       // value was zero.  This can be handled by setting the bit just off
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index d6a0dd9ae9b20..f062ad2543dd9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -5049,6 +5049,40 @@ static MVT getPromotedVectorElementType(const TargetLowering &TLI,
   return MidVT;
 }
 
+// Promote a count-leading-zeros of a bitwise NOT by performing the NOT in the
+// promoted type after shifting the operand to the most significant bits:
+//   (CTLZ (XOR Op, -1)) -->
+//     (TRUNCATE (CTLZ_ZERO_UNDEF (XOR (SHL (ANY_EXTEND Op), Diff), -1)))
+static bool ExtendCtlzNot(SDNode *Node, SDValue &Result, SDLoc &dl, MVT OVT,
+                          MVT NVT, SelectionDAG &DAG) {
+  // Bail out, leaving Result untouched, unless the operand is an XOR whose
+  // second operand is an all-ones ConstantSDNode.
+  SDValue XorOp = Node->getOperand(0);
+  if (XorOp.getOpcode() != ISD::XOR)
+    return false;
+  auto *RHS = dyn_cast<ConstantSDNode>(XorOp.getOperand(1));
+  if (RHS == nullptr || !RHS->isAllOnes())
+    return false;
+
+  const unsigned WideBits = NVT.getScalarSizeInBits();
+  const unsigned Diff = WideBits - OVT.getScalarSizeInBits();
+
+  // After the shift the Diff low bits are zero; the NOT turns them into ones,
+  // which keeps the CTLZ input non-zero and its count equal to the narrow one.
+  SDValue Wide = DAG.getNode(ISD::ANY_EXTEND, dl, NVT, XorOp.getOperand(0));
+  SDValue Amt = DAG.getShiftAmountConstant(Diff, Wide.getValueType(), dl);
+  SDValue Shifted = DAG.getNode(ISD::SHL, dl, NVT, Wide, Amt);
+
+  SDValue Ones = DAG.getConstant(APInt::getAllOnes(WideBits), dl, NVT);
+  SDValue Inverted =
+      DAG.getNode(ISD::XOR, dl, NVT, Shifted, Ones, XorOp->getFlags());
+
+  // The count already matches the original type, so just truncate it back.
+  SDValue Count = DAG.getNode(ISD::CTLZ_ZERO_UNDEF, dl, NVT, Inverted);
+  Result = DAG.getNode(ISD::TRUNCATE, dl, OVT, Count);
+  return true;
+}
+
 void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
   LLVM_DEBUG(dbgs() << "Trying to promote node\n");
   SmallVector<SDValue, 8> Results;
@@ -5084,6 +5118,13 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
   case ISD::CTTZ_ZERO_UNDEF:
   case ISD::CTLZ:
   case ISD::CTPOP: {
+    // If the operand of CTLZ is NOT, push the extend in the NOT.
+    if (Node->getOpcode() == ISD::CTLZ &&
+        ExtendCtlzNot(Node, Tmp1, dl, OVT, NVT, DAG)) {
+      Results.push_back(Tmp1);
+      break;
+    }
+
     // Zero extend the argument unless its cttz, then use any_extend.
     if (Node->getOpcode() == ISD::CTTZ ||
         Node->getOpcode() == ISD::CTTZ_ZERO_UNDEF)
@@ -5115,6 +5156,10 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
     break;
   }
   case ISD::CTLZ_ZERO_UNDEF: {
+    if (ExtendCtlzNot(Node, Tmp1, dl, OVT, NVT, DAG)) {
+      Results.push_back(Tmp1);
+      break;
+    }
     // We know that the argument is unlikely to be zero, hence we can take a
     // different approach as compared to ISD::CTLZ
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index fed5ebcc3c903..58e519c8657e0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -638,6 +638,37 @@ SDValue DAGTypeLegalizer::PromoteIntRes_Constant(SDNode *N) {
   return Result;
 }
 
+// Promote a count-leading-zeros of a bitwise NOT by performing the NOT in the
+// promoted type after shifting the operand to the most significant bits:
+//   (CTLZ (XOR Op, -1)) -->
+//     (CTLZ_ZERO_UNDEF (XOR (SHL (ANY_EXTEND Op), Diff), -1))
+static bool ExtendCtlzNot(SDNode *Node, SDValue &Result, SDLoc &dl, EVT OVT,
+                          EVT NVT, SelectionDAG &DAG) {
+  // Bail out, leaving Result untouched, unless the operand is an XOR whose
+  // second operand is an all-ones ConstantSDNode.
+  SDValue XorOp = Node->getOperand(0);
+  if (XorOp.getOpcode() != ISD::XOR)
+    return false;
+  auto *RHS = dyn_cast<ConstantSDNode>(XorOp.getOperand(1));
+  if (RHS == nullptr || !RHS->isAllOnes())
+    return false;
+
+  const unsigned WideBits = NVT.getScalarSizeInBits();
+  const unsigned Diff = WideBits - OVT.getScalarSizeInBits();
+
+  // After the shift the Diff low bits are zero; the NOT turns them into ones,
+  // which keeps the CTLZ input non-zero and its count equal to the narrow one.
+  SDValue Wide = DAG.getNode(ISD::ANY_EXTEND, dl, NVT, XorOp.getOperand(0));
+  SDValue Amt = DAG.getShiftAmountConstant(Diff, Wide.getValueType(), dl);
+  SDValue Shifted = DAG.getNode(ISD::SHL, dl, NVT, Wide, Amt);
+  SDValue Ones = DAG.getConstant(APInt::getAllOnes(WideBits), dl, NVT);
+  SDValue Inverted =
+      DAG.getNode(ISD::XOR, dl, NVT, Shifted, Ones, XorOp->getFlags());
+
+  Result = DAG.getNode(ISD::CTLZ_ZERO_UNDEF, dl, NVT, Inverted);
+  return true;
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_CTLZ(SDNode *N) {
   EVT OVT = N->getValueType(0);
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), OVT);
@@ -656,6 +687,13 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTLZ(SDNode *N) {
   }
 
   unsigned CtlzOpcode = N->getOpcode();
+  // If the operand of CTLZ is NOT, push the extend in the NOT.
+  if (SDValue Res;
+      (CtlzOpcode == ISD::CTLZ || CtlzOpcode == ISD::CTLZ_ZERO_UNDEF) &&
+      ExtendCtlzNot(N, Res, dl, OVT, NVT, DAG)) {
+    return Res;
+  }
+
   if (CtlzOpcode == ISD::CTLZ || CtlzOpcode == ISD::VP_CTLZ) {
     // Subtract off the extra leading bits in the bigger type.
     SDValue ExtractLeadingBits = DAG.getConstant(
diff --git a/llvm/test/CodeGen/AArch64/ctlo.ll b/llvm/test/CodeGen/AArch64/ctlo.ll
new file mode 100644
index 0000000000000..5f15f540f458d
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/ctlo.ll
@@ -0,0 +1,100 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s --mtriple=aarch64 -verify-machineinstrs | FileCheck %s
+; RUN: llc < %s --mtriple=aarch64 -global-isel -verify-machineinstrs | FileCheck %s
+
+declare i8 @llvm.ctlz.i8(i8, i1)
+declare i16 @llvm.ctlz.i16(i16, i1)
+declare i32 @llvm.ctlz.i32(i32, i1)
+declare i64 @llvm.ctlz.i64(i64, i1)
+
+define i8 @ctlo_i8(i8 %x) {
+; CHECK-LABEL: ctlo_i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #-1 // =0xffffffff
+; CHECK-NEXT:    eor w8, w8, w0, lsl #24
+; CHECK-NEXT:    clz w0, w8
+; CHECK-NEXT:    ret
+  %tmp1 = xor i8 %x, -1
+  %tmp2 = call i8 @llvm.ctlz.i8( i8 %tmp1, i1 false )
+  ret i8 %tmp2
+}
+
+define i8 @ctlo_i8_undef(i8 %x) {
+; CHECK-LABEL: ctlo_i8_undef:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #-1 // =0xffffffff
+; CHECK-NEXT:    eor w8, w8, w0, lsl #24
+; CHECK-NEXT:    clz w0, w8
+; CHECK-NEXT:    ret
+  %tmp1 = xor i8 %x, -1
+  %tmp2 = call i8 @llvm.ctlz.i8( i8 %tmp1, i1 true )
+  ret i8 %tmp2
+}
+
+define i16 @ctlo_i16(i16 %x) {
+; CHECK-LABEL: ctlo_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #-1 // =0xffffffff
+; CHECK-NEXT:    eor w8, w8, w0, lsl #16
+; CHECK-NEXT:    clz w0, w8
+; CHECK-NEXT:    ret
+  %tmp1 = xor i16 %x, -1
+  %tmp2 = call i16 @llvm.ctlz.i16( i16 %tmp1, i1 false )
+  ret i16 %tmp2
+}
+
+define i16 @ctlo_i16_undef(i16 %x) {
+; CHECK-LABEL: ctlo_i16_undef:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #-1 // =0xffffffff
+; CHECK-NEXT:    eor w8, w8, w0, lsl #16
+; CHECK-NEXT:    clz w0, w8
+; CHECK-NEXT:    ret
+  %tmp1 = xor i16 %x, -1
+  %tmp2 = call i16 @llvm.ctlz.i16( i16 %tmp1, i1 true )
+  ret i16 %tmp2
+}
+
+define i32 @ctlo_i32(i32 %x) {
+; CHECK-LABEL: ctlo_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn w8, w0
+; CHECK-NEXT:    clz w0, w8
+; CHECK-NEXT:    ret
+  %tmp1 = xor i32 %x, -1
+  %tmp2 = call i32 @llvm.ctlz.i32( i32 %tmp1, i1 false )
+  ret i32 %tmp2
+}
+
+define i32 @ctlo_i32_undef(i32 %x) {
+; CHECK-LABEL: ctlo_i32_undef:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn w8, w0
+; CHECK-NEXT:    clz w0, w8
+; CHECK-NEXT:    ret
+  %tmp1 = xor i32 %x, -1
+  %tmp2 = call i32 @llvm.ctlz.i32( i32 %tmp1, i1 true )
+  ret i32 %tmp2
+}
+
+define i64 @ctlo_i64(i64 %x) {
+; CHECK-LABEL: ctlo_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn x8, x0
+; CHECK-NEXT:    clz x0, x8
+; CHECK-NEXT:    ret
+  %tmp1 = xor i64 %x, -1
+  %tmp2 = call i64 @llvm.ctlz.i64( i64 %tmp1, i1 false )
+  ret i64 %tmp2
+}
+
+define i64 @ctlo_i64_undef(i64 %x) {
+; CHECK-LABEL: ctlo_i64_undef:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn x8, x0
+; CHECK-NEXT:    clz x0, x8
+; CHECK-NEXT:    ret
+  %tmp1 = xor i64 %x, -1
+  %tmp2 = call i64 @llvm.ctlz.i64( i64 %tmp1, i1 true )
+  ret i64 %tmp2
+}
diff --git a/llvm/test/CodeGen/X86/ctlo.ll b/llvm/test/CodeGen/X86/ctlo.ll
index 7431f94f0fdf2..020d6d1b80136 100644
--- a/llvm/test/CodeGen/X86/ctlo.ll
+++ b/llvm/test/CodeGen/X86/ctlo.ll
@@ -46,20 +46,18 @@ define i8 @ctlo_i8(i8 %x) {
 ;
 ; X86-CLZ-LABEL: ctlo_i8:
 ; X86-CLZ:       # %bb.0:
-; X86-CLZ-NEXT:    movzbl {{[0-9]+}}(%esp), %eax
-; X86-CLZ-NEXT:    notb %al
-; X86-CLZ-NEXT:    movzbl %al, %eax
+; X86-CLZ-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-CLZ-NEXT:    shll $24, %eax
+; X86-CLZ-NEXT:    notl %eax
 ; X86-CLZ-NEXT:    lzcntl %eax, %eax
-; X86-CLZ-NEXT:    addl $-24, %eax
 ; X86-CLZ-NEXT:    # kill: def $al killed $al killed $eax
 ; X86-CLZ-NEXT:    retl
 ;
 ; X64-CLZ-LABEL: ctlo_i8:
 ; X64-CLZ:       # %bb.0:
-; X64-CLZ-NEXT:    notb %dil
-; X64-CLZ-NEXT:    movzbl %dil, %eax
-; X64-CLZ-NEXT:    lzcntl %eax, %eax
-; X64-CLZ-NEXT:    addl $-24, %eax
+; X64-CLZ-NEXT:    shll $24, %edi
+; X64-CLZ-NEXT:    notl %edi
+; X64-CLZ-NEXT:    lzcntl %edi, %eax
 ; X64-CLZ-NEXT:    # kill: def $al killed $al killed $eax
 ; X64-CLZ-NEXT:    retq
   %tmp1 = xor i8 %x, -1
@@ -89,20 +87,18 @@ define i8 @ctlo_i8_undef(i8 %x) {
 ;
 ; X86-CLZ-LABEL: ctlo_i8_undef:
 ; X86-CLZ:       # %bb.0:
-; X86-CLZ-NEXT:    movzbl {{[0-9]+}}(%esp), %eax
-; X86-CLZ-NEXT:    notb %al
-; X86-CLZ-NEXT:    movzbl %al, %eax
+; X86-CLZ-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-CLZ-NEXT:    shll $24, %eax
+; X86-CLZ-NEXT:    notl %eax
 ; X86-CLZ-NEXT:    lzcntl %eax, %eax
 ; X86-CLZ-NEXT:    # kill: def $al killed $al killed $eax
 ; X86-CLZ-NEXT:    retl
 ;
 ; X64-CLZ-LABEL: ctlo_i8_undef:
 ; X64-CLZ:       # %bb.0:
-; X64-CLZ-NEXT:    notb %dil
-; X64-CLZ-NEXT:    movzbl %dil, %eax
-; X64-CLZ-NEXT:    shll $24, %eax
-; X64-CLZ-NEXT:    lzcntl %eax, %eax
+; X64-CLZ-NEXT:    shll $24, %edi
+; X64-CLZ-NEXT:    notl %edi
+; X64-CLZ-NEXT:    lzcntl %edi, %eax
 ; X64-CLZ-NEXT:    # kill: def $al killed $al killed $eax
 ; X64-CLZ-NEXT:    retq
   %tmp1 = xor i8 %x, -1

>From c939054f9deac10e4bb5a5a47f21d91572dbf000 Mon Sep 17 00:00:00 2001
From: v01dxyz <v01dxyz at v01d.xyz>
Date: Thu, 18 Jul 2024 23:32:11 +0200
Subject: [PATCH 2/4] (XFAIL) Make AMDGPU uses SelectDAG Legaliser for CTLZ

LegalizeTypes does not allow overriding the promoted type for a given
(operation, type) pair.

The failing test is CodeGen/AMDGPU/ctlz_zero_undef.ll.
---
 .../SelectionDAG/LegalizeIntegerTypes.cpp     | 25 ++++++++++++++++---
 .../CodeGen/SelectionDAG/LegalizeTypes.cpp    | 11 +++++---
 llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp |  2 +-
 3 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 58e519c8657e0..5a0379a4dce9d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -674,6 +674,17 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTLZ(SDNode *N) {
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), OVT);
   SDLoc dl(N);
 
+  bool IsDefaultPromotion =
+      TLI.getOperationAction(N->getOpcode(), OVT) != TargetLowering::Promote;
+  if (!IsDefaultPromotion) {
+    EVT NVTPromote = TLI.getTypeToPromoteTo(N->getOpcode(), OVT.getSimpleVT());
+
+    if (NVT == NVTPromote)
+      IsDefaultPromotion = false;
+    else
+      NVT = NVTPromote;
+  }
+
   // If the larger CTLZ isn't supported by the target, try to expand now.
   // If we expand later we'll end up with more operations since we lost the
   // original type.
@@ -701,7 +712,10 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTLZ(SDNode *N) {
 
     if (!N->isVPOpcode()) {
       // Zero extend to the promoted type and do the count there.
-      SDValue Op = ZExtPromotedInteger(N->getOperand(0));
+      SDValue Op = IsDefaultPromotion
+                       ? ZExtPromotedInteger(N->getOperand(0))
+                       : DAG.getZExtOrTrunc(N->getOperand(0), dl, NVT);
+
       return DAG.getNode(ISD::SUB, dl, NVT,
                          DAG.getNode(N->getOpcode(), dl, NVT, Op),
                          ExtractLeadingBits);
@@ -709,7 +723,10 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTLZ(SDNode *N) {
     SDValue Mask = N->getOperand(1);
     SDValue EVL = N->getOperand(2);
     // Zero extend to the promoted type and do the count there.
-    SDValue Op = VPZExtPromotedInteger(N->getOperand(0), Mask, EVL);
+    SDValue Op =
+        IsDefaultPromotion
+            ? VPZExtPromotedInteger(N->getOperand(0), Mask, EVL)
+            : DAG.getVPZExtOrTrunc(dl, NVT, N->getOperand(0), Mask, EVL);
     return DAG.getNode(ISD::VP_SUB, dl, NVT,
                        DAG.getNode(N->getOpcode(), dl, NVT, Op, Mask, EVL),
                        ExtractLeadingBits, Mask, EVL);
@@ -717,7 +734,9 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTLZ(SDNode *N) {
   if (CtlzOpcode == ISD::CTLZ_ZERO_UNDEF ||
       CtlzOpcode == ISD::VP_CTLZ_ZERO_UNDEF) {
     // Any Extend the argument
-    SDValue Op = GetPromotedInteger(N->getOperand(0));
+    SDValue Op = IsDefaultPromotion
+                     ? GetPromotedInteger(N->getOperand(0))
+                     : DAG.getAnyExtOrTrunc(N->getOperand(0), dl, NVT);
     // Op = Op << (sizeinbits(NVT) - sizeinbits(Old VT))
     unsigned SHLAmount = NVT.getScalarSizeInBits() - OVT.getScalarSizeInBits();
     auto ShiftConst =
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
index cb6d3fe4db8a4..699c45dfa5a4d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
@@ -705,9 +705,14 @@ void DAGTypeLegalizer::ReplaceValueWith(SDValue From, SDValue To) {
 }
 
 void DAGTypeLegalizer::SetPromotedInteger(SDValue Op, SDValue Result) {
-  assert(Result.getValueType() ==
-         TLI.getTypeToTransformTo(*DAG.getContext(), Op.getValueType()) &&
-         "Invalid type for promoted integer");
+  assert(
+      Result.getValueType() ==
+          ((TLI.getOperationAction(Op.getOpcode(), Op.getValueType()) !=
+            TargetLowering::Promote)
+               ? TLI.getTypeToTransformTo(*DAG.getContext(), Op.getValueType())
+               : TLI.getTypeToPromoteTo(Op.getOpcode(),
+                                        Op.getSimpleValueType())) &&
+      "Invalid type for promoted integer");
   AnalyzeNewValue(Result);
 
   auto &OpIdEntry = PromotedIntegers[getTableId(Op)];
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index ef30bf6d993fa..5a98329cea8c9 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -498,7 +498,7 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
       MVT::i64, Custom);
 
   for (auto VT : {MVT::i8, MVT::i16})
-    setOperationAction({ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, VT, Custom);
+    setOperationPromotedToType({ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, VT, MVT::i32);
 
   static const MVT::SimpleValueType VectorIntTypes[] = {
       MVT::v2i32, MVT::v3i32, MVT::v4i32, MVT::v5i32, MVT::v6i32, MVT::v7i32,

>From 04e67386d1fd68870b5cbd424c123ddc75f01387 Mon Sep 17 00:00:00 2001
From: v01dxyz <v01dxyz at v01d.xyz>
Date: Fri, 19 Jul 2024 02:12:17 +0200
Subject: [PATCH 3/4] Revert "(XFAIL) Make AMDGPU uses SelectDAG Legaliser for
 CTLZ"

This reverts commit 0a4ba146fd84a05489cdd2c91cbc59297bc8a03b.
---
 .../SelectionDAG/LegalizeIntegerTypes.cpp     | 25 +++----------------
 .../CodeGen/SelectionDAG/LegalizeTypes.cpp    | 11 +++-----
 llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp |  2 +-
 3 files changed, 7 insertions(+), 31 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 5a0379a4dce9d..58e519c8657e0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -674,17 +674,6 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTLZ(SDNode *N) {
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), OVT);
   SDLoc dl(N);
 
-  bool IsDefaultPromotion =
-      TLI.getOperationAction(N->getOpcode(), OVT) != TargetLowering::Promote;
-  if (!IsDefaultPromotion) {
-    EVT NVTPromote = TLI.getTypeToPromoteTo(N->getOpcode(), OVT.getSimpleVT());
-
-    if (NVT == NVTPromote)
-      IsDefaultPromotion = false;
-    else
-      NVT = NVTPromote;
-  }
-
   // If the larger CTLZ isn't supported by the target, try to expand now.
   // If we expand later we'll end up with more operations since we lost the
   // original type.
@@ -712,10 +701,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTLZ(SDNode *N) {
 
     if (!N->isVPOpcode()) {
       // Zero extend to the promoted type and do the count there.
-      SDValue Op = IsDefaultPromotion
-                       ? ZExtPromotedInteger(N->getOperand(0))
-                       : DAG.getZExtOrTrunc(N->getOperand(0), dl, NVT);
-
+      SDValue Op = ZExtPromotedInteger(N->getOperand(0));
       return DAG.getNode(ISD::SUB, dl, NVT,
                          DAG.getNode(N->getOpcode(), dl, NVT, Op),
                          ExtractLeadingBits);
@@ -723,10 +709,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTLZ(SDNode *N) {
     SDValue Mask = N->getOperand(1);
     SDValue EVL = N->getOperand(2);
     // Zero extend to the promoted type and do the count there.
-    SDValue Op =
-        IsDefaultPromotion
-            ? VPZExtPromotedInteger(N->getOperand(0), Mask, EVL)
-            : DAG.getVPZExtOrTrunc(dl, NVT, N->getOperand(0), Mask, EVL);
+    SDValue Op = VPZExtPromotedInteger(N->getOperand(0), Mask, EVL);
     return DAG.getNode(ISD::VP_SUB, dl, NVT,
                        DAG.getNode(N->getOpcode(), dl, NVT, Op, Mask, EVL),
                        ExtractLeadingBits, Mask, EVL);
@@ -734,9 +717,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTLZ(SDNode *N) {
   if (CtlzOpcode == ISD::CTLZ_ZERO_UNDEF ||
       CtlzOpcode == ISD::VP_CTLZ_ZERO_UNDEF) {
     // Any Extend the argument
-    SDValue Op = IsDefaultPromotion
-                     ? GetPromotedInteger(N->getOperand(0))
-                     : DAG.getAnyExtOrTrunc(N->getOperand(0), dl, NVT);
+    SDValue Op = GetPromotedInteger(N->getOperand(0));
     // Op = Op << (sizeinbits(NVT) - sizeinbits(Old VT))
     unsigned SHLAmount = NVT.getScalarSizeInBits() - OVT.getScalarSizeInBits();
     auto ShiftConst =
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
index 699c45dfa5a4d..cb6d3fe4db8a4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
@@ -705,14 +705,9 @@ void DAGTypeLegalizer::ReplaceValueWith(SDValue From, SDValue To) {
 }
 
 void DAGTypeLegalizer::SetPromotedInteger(SDValue Op, SDValue Result) {
-  assert(
-      Result.getValueType() ==
-          ((TLI.getOperationAction(Op.getOpcode(), Op.getValueType()) !=
-            TargetLowering::Promote)
-               ? TLI.getTypeToTransformTo(*DAG.getContext(), Op.getValueType())
-               : TLI.getTypeToPromoteTo(Op.getOpcode(),
-                                        Op.getSimpleValueType())) &&
-      "Invalid type for promoted integer");
+  assert(Result.getValueType() ==
+         TLI.getTypeToTransformTo(*DAG.getContext(), Op.getValueType()) &&
+         "Invalid type for promoted integer");
   AnalyzeNewValue(Result);
 
   auto &OpIdEntry = PromotedIntegers[getTableId(Op)];
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 5a98329cea8c9..ef30bf6d993fa 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -498,7 +498,7 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
       MVT::i64, Custom);
 
   for (auto VT : {MVT::i8, MVT::i16})
-    setOperationPromotedToType({ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, VT, MVT::i32);
+    setOperationAction({ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, VT, Custom);
 
   static const MVT::SimpleValueType VectorIntTypes[] = {
       MVT::v2i32, MVT::v3i32, MVT::v4i32, MVT::v5i32, MVT::v6i32, MVT::v7i32,

>From 7b9d5b77df71ff868ad44e8002aa9d2afa381ebe Mon Sep 17 00:00:00 2001
From: v01dxyz <v01dxyz at v01d.xyz>
Date: Fri, 19 Jul 2024 06:29:14 +0200
Subject: [PATCH 4/4] Update LoongArch ctlz tests

---
 .../test/CodeGen/LoongArch/ctlz-cttz-ctpop.ll | 24 +++++++------------
 1 file changed, 8 insertions(+), 16 deletions(-)

diff --git a/llvm/test/CodeGen/LoongArch/ctlz-cttz-ctpop.ll b/llvm/test/CodeGen/LoongArch/ctlz-cttz-ctpop.ll
index f17cec231f323..e993ecfcdf3b8 100644
--- a/llvm/test/CodeGen/LoongArch/ctlz-cttz-ctpop.ll
+++ b/llvm/test/CodeGen/LoongArch/ctlz-cttz-ctpop.ll
@@ -89,18 +89,14 @@ define i64 @test_ctlz_i64(i64 %a) nounwind {
 define i8 @test_not_ctlz_i8(i8 %a) nounwind {
 ; LA32-LABEL: test_not_ctlz_i8:
 ; LA32:       # %bb.0:
-; LA32-NEXT:    ori $a1, $zero, 255
-; LA32-NEXT:    andn $a0, $a1, $a0
-; LA32-NEXT:    clz.w $a0, $a0
-; LA32-NEXT:    addi.w $a0, $a0, -24
+; LA32-NEXT:    slli.w $a0, $a0, 24
+; LA32-NEXT:    clo.w $a0, $a0
 ; LA32-NEXT:    ret
 ;
 ; LA64-LABEL: test_not_ctlz_i8:
 ; LA64:       # %bb.0:
-; LA64-NEXT:    ori $a1, $zero, 255
-; LA64-NEXT:    andn $a0, $a1, $a0
-; LA64-NEXT:    clz.d $a0, $a0
-; LA64-NEXT:    addi.d $a0, $a0, -56
+; LA64-NEXT:    slli.d $a0, $a0, 56
+; LA64-NEXT:    clo.d $a0, $a0
 ; LA64-NEXT:    ret
   %neg = xor i8 %a, -1
   %tmp = call i8 @llvm.ctlz.i8(i8 %neg, i1 false)
@@ -110,18 +106,14 @@ define i8 @test_not_ctlz_i8(i8 %a) nounwind {
 define i16 @test_not_ctlz_i16(i16 %a) nounwind {
 ; LA32-LABEL: test_not_ctlz_i16:
 ; LA32:       # %bb.0:
-; LA32-NEXT:    nor $a0, $a0, $zero
-; LA32-NEXT:    bstrpick.w $a0, $a0, 15, 0
-; LA32-NEXT:    clz.w $a0, $a0
-; LA32-NEXT:    addi.w $a0, $a0, -16
+; LA32-NEXT:    slli.w $a0, $a0, 16
+; LA32-NEXT:    clo.w $a0, $a0
 ; LA32-NEXT:    ret
 ;
 ; LA64-LABEL: test_not_ctlz_i16:
 ; LA64:       # %bb.0:
-; LA64-NEXT:    nor $a0, $a0, $zero
-; LA64-NEXT:    bstrpick.d $a0, $a0, 15, 0
-; LA64-NEXT:    clz.d $a0, $a0
-; LA64-NEXT:    addi.d $a0, $a0, -48
+; LA64-NEXT:    slli.d $a0, $a0, 48
+; LA64-NEXT:    clo.d $a0, $a0
 ; LA64-NEXT:    ret
   %neg = xor i16 %a, -1
   %tmp = call i16 @llvm.ctlz.i16(i16 %neg, i1 false)



More information about the llvm-commits mailing list