[llvm] 0a42994 - [GlobalISel] Fold G_CTTZ if possible (#86224)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 25 13:55:41 PDT 2024


Author: Shilei Tian
Date: 2024-03-25T16:55:37-04:00
New Revision: 0a4299403e66871f88518f7399d511b72c634c1d

URL: https://github.com/llvm/llvm-project/commit/0a4299403e66871f88518f7399d511b72c634c1d
DIFF: https://github.com/llvm/llvm-project/commit/0a4299403e66871f88518f7399d511b72c634c1d.diff

LOG: [GlobalISel] Fold G_CTTZ if possible (#86224)

This patch tries to fold `G_CTTZ` if possible.

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/GlobalISel/Utils.h
    llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
    llvm/lib/CodeGen/GlobalISel/Utils.cpp
    llvm/unittests/CodeGen/GlobalISel/CSETest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
index f8900f3434ccaa..0241ec4f2111d0 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -308,10 +308,12 @@ std::optional<APFloat> ConstantFoldIntToFloat(unsigned Opcode, LLT DstTy,
                                               Register Src,
                                               const MachineRegisterInfo &MRI);
 
-/// Tries to constant fold a G_CTLZ operation on \p Src. If \p Src is a vector
-/// then it tries to do an element-wise constant fold.
+/// Tries to constant fold a counting-zero operation (G_CTLZ or G_CTTZ) on \p
+/// Src. If \p Src is a vector then it tries to do an element-wise constant
+/// fold.
 std::optional<SmallVector<unsigned>>
-ConstantFoldCTLZ(Register Src, const MachineRegisterInfo &MRI);
+ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
+                       std::function<unsigned(APInt)> CB);
 
 /// Test if the given value is known to have exactly one bit set. This differs
 /// from computeKnownBits in that it doesn't necessarily determine which bit is

diff --git a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
index 1869e0d41a51f6..a0bc325c6cda7b 100644
--- a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
@@ -256,10 +256,16 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
       return buildFConstant(DstOps[0], *Cst);
     break;
   }
-  case TargetOpcode::G_CTLZ: {
+  case TargetOpcode::G_CTLZ:
+  case TargetOpcode::G_CTTZ: {
     assert(SrcOps.size() == 1 && "Expected one source");
     assert(DstOps.size() == 1 && "Expected one dest");
-    auto MaybeCsts = ConstantFoldCTLZ(SrcOps[0].getReg(), *getMRI());
+    std::function<unsigned(APInt)> CB;
+    if (Opc == TargetOpcode::G_CTLZ)
+      CB = [](APInt V) -> unsigned { return V.countl_zero(); };
+    else
+      CB = [](APInt V) -> unsigned { return V.countTrailingZeros(); };
+    auto MaybeCsts = ConstantFoldCountZeros(SrcOps[0].getReg(), *getMRI(), CB);
     if (!MaybeCsts)
       break;
     if (MaybeCsts->size() == 1)

diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index a9fa73b60a097f..8c41f8b1bdcdbc 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -966,14 +966,15 @@ llvm::ConstantFoldIntToFloat(unsigned Opcode, LLT DstTy, Register Src,
 }
 
 std::optional<SmallVector<unsigned>>
-llvm::ConstantFoldCTLZ(Register Src, const MachineRegisterInfo &MRI) {
+llvm::ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
+                             std::function<unsigned(APInt)> CB) {
   LLT Ty = MRI.getType(Src);
   SmallVector<unsigned> FoldedCTLZs;
   auto tryFoldScalar = [&](Register R) -> std::optional<unsigned> {
     auto MaybeCst = getIConstantVRegVal(R, MRI);
     if (!MaybeCst)
       return std::nullopt;
-    return MaybeCst->countl_zero();
+    return CB(*MaybeCst);
   };
   if (Ty.isVector()) {
     // Try to constant fold each element.

diff --git a/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp b/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp
index 116099eff14aa9..08857de3cf4e44 100644
--- a/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp
@@ -233,4 +233,46 @@ TEST_F(AArch64GISelMITest, TestConstantFoldCTL) {
   EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF;
 }
 
+TEST_F(AArch64GISelMITest, TestConstantFoldCTT) {
+  setUp();
+  if (!TM)
+    GTEST_SKIP();
+
+  LLT s32 = LLT::scalar(32);
+
+  GISelCSEInfo CSEInfo;
+  CSEInfo.setCSEConfig(std::make_unique<CSEConfigConstantOnly>());
+  CSEInfo.analyze(*MF);
+  B.setCSEInfo(&CSEInfo);
+  CSEMIRBuilder CSEB(B.getState());
+  auto Cst8 = CSEB.buildConstant(s32, 8);
+  auto *CttzDef = &*CSEB.buildCTTZ(s32, Cst8);
+  EXPECT_TRUE(CttzDef->getOpcode() == TargetOpcode::G_CONSTANT);
+  EXPECT_TRUE(CttzDef->getOperand(1).getCImm()->getZExtValue() == 3);
+
+  // Test vector.
+  auto Cst16 = CSEB.buildConstant(s32, 16);
+  auto Cst32 = CSEB.buildConstant(s32, 32);
+  auto Cst64 = CSEB.buildConstant(s32, 64);
+  LLT VecTy = LLT::fixed_vector(4, s32);
+  auto BV = CSEB.buildBuildVector(VecTy, {Cst8.getReg(0), Cst16.getReg(0),
+                                          Cst32.getReg(0), Cst64.getReg(0)});
+  CSEB.buildCTTZ(VecTy, BV);
+
+  auto CheckStr = R"(
+  ; CHECK: [[CST8:%[0-9]+]]:_(s32) = G_CONSTANT i32 8
+  ; CHECK: [[CST3:%[0-9]+]]:_(s32) = G_CONSTANT i32 3
+  ; CHECK: [[CST16:%[0-9]+]]:_(s32) = G_CONSTANT i32 16
+  ; CHECK: [[CST32:%[0-9]+]]:_(s32) = G_CONSTANT i32 32
+  ; CHECK: [[CST64:%[0-9]+]]:_(s32) = G_CONSTANT i32 64
+  ; CHECK: [[BV1:%[0-9]+]]:_(<4 x s32>) = G_BUILD_VECTOR [[CST8]]:_(s32), [[CST16]]:_(s32), [[CST32]]:_(s32), [[CST64]]:_(s32)
+  ; CHECK: [[CST27:%[0-9]+]]:_(s32) = G_CONSTANT i32 4
+  ; CHECK: [[CST26:%[0-9]+]]:_(s32) = G_CONSTANT i32 5
+  ; CHECK: [[CST25:%[0-9]+]]:_(s32) = G_CONSTANT i32 6
+  ; CHECK: [[BV2:%[0-9]+]]:_(<4 x s32>) = G_BUILD_VECTOR [[CST3]]:_(s32), [[CST27]]:_(s32), [[CST26]]:_(s32), [[CST25]]:_(s32)
+  )";
+
+  EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF;
+}
+
 } // namespace


        


More information about the llvm-commits mailing list