[llvm] [LegalizeDAG] Optimize CodeGen for `ISD::CTLZ_ZERO_UNDEF` (PR #83039)

Manish Kausik H via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 31 10:30:09 PDT 2024


================
@@ -613,21 +613,42 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTLZ(SDNode *N) {
     }
   }
 
-  // Zero extend to the promoted type and do the count there.
-  SDValue Op = ZExtPromotedInteger(N->getOperand(0));
+  unsigned CtlzOpcode = N->getOpcode();
+  if (CtlzOpcode == ISD::CTLZ) {
+    // Zero extend to the promoted type and do the count there.
+    SDValue Op = ZExtPromotedInteger(N->getOperand(0));
+
+    // Subtract off the extra leading bits in the bigger type.
+    SDValue ExtractLeadingBits = DAG.getConstant(
+        NVT.getScalarSizeInBits() - OVT.getScalarSizeInBits(), dl, NVT);
+    if (!N->isVPOpcode())
+      return DAG.getNode(ISD::SUB, dl, NVT,
+                         DAG.getNode(N->getOpcode(), dl, NVT, Op),
+                         ExtractLeadingBits);
+    SDValue Mask = N->getOperand(1);
+    SDValue EVL = N->getOperand(2);
+    return DAG.getNode(ISD::VP_SUB, dl, NVT,
+                       DAG.getNode(N->getOpcode(), dl, NVT, Op, Mask, EVL),
+                       ExtractLeadingBits, Mask, EVL);
+  } else if (CtlzOpcode == ISD::CTLZ_ZERO_UNDEF) {
+    // Zero Extend the argument
+    SDValue Op = ZExtPromotedInteger(N->getOperand(0));
+
+    // Op = Op << (sizeinbits(NVT) - sizeinbits(Old VT))
+    auto ShiftConst =
+        DAG.getConstant(NVT.getSizeInBits() - OVT.getSizeInBits(), dl, NVT);
+    if (!N->isVPOpcode()) {
+      Op = DAG.getNode(ISD::SHL, dl, NVT, Op, ShiftConst);
+      return DAG.getNode(CtlzOpcode, dl, NVT, Op);
+    }
 
-  // Subtract off the extra leading bits in the bigger type.
-  SDValue ExtractLeadingBits = DAG.getConstant(
-      NVT.getScalarSizeInBits() - OVT.getScalarSizeInBits(), dl, NVT);
-  if (!N->isVPOpcode())
-    return DAG.getNode(ISD::SUB, dl, NVT,
-                       DAG.getNode(N->getOpcode(), dl, NVT, Op),
-                       ExtractLeadingBits);
-  SDValue Mask = N->getOperand(1);
-  SDValue EVL = N->getOperand(2);
-  return DAG.getNode(ISD::VP_SUB, dl, NVT,
-                     DAG.getNode(N->getOpcode(), dl, NVT, Op, Mask, EVL),
-                     ExtractLeadingBits, Mask, EVL);
+    SDValue Mask = N->getOperand(1);
+    SDValue EVL = N->getOperand(2);
+    Op = DAG.getNode(ISD::SHL, dl, NVT, Op, ShiftConst, Mask, EVL);
----------------
Nirhar wrote:

@topperc I changed it to ISD::VP_SHL. However, how should I test it? Should I write a test that uses masked vector loads and calls the ctlz intrinsic on the loaded values?

https://github.com/llvm/llvm-project/pull/83039


More information about the llvm-commits mailing list