[llvm] 6ac4fe8 - [X86] X86FixupVectorConstants.cpp - refactor constant search loop to take array of sorted candidates

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 1 08:07:00 PST 2024


Author: Simon Pilgrim
Date: 2024-02-01T16:06:36Z
New Revision: 6ac4fe8de014336ce66d02ddd07e85db3b8e77a2

URL: https://github.com/llvm/llvm-project/commit/6ac4fe8de014336ce66d02ddd07e85db3b8e77a2
DIFF: https://github.com/llvm/llvm-project/commit/6ac4fe8de014336ce66d02ddd07e85db3b8e77a2.diff

LOG: [X86] X86FixupVectorConstants.cpp - refactor constant search loop to take array of sorted candidates

Pulled out of #79815 - refactors the internal FixupConstant logic to accept an array of vzload/broadcast candidates, pre-sorted by ascending constant pool size
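
Roughly, the new shape is the classic "first match in a size-sorted
candidate table wins" pattern. A minimal standalone sketch with toy
stand-in types (hypothetical names such as firstMatchingFixup; the real
pass mutates the MachineInstr and constant pool rather than returning an
opcode):

// toy_fixup.cpp -- illustrative sketch, not the LLVM implementation.
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <functional>
#include <vector>

struct Constant {};          // stand-in for llvm::Constant
static Constant TheConstant; // dummy constant pool entry

using RebuildFn =
    std::function<Constant *(const Constant *, unsigned, unsigned)>;

struct FixupEntry {
  int Op;         // replacement opcode; 0 marks the slot unavailable
  int NumCstElts; // element count of the shrunken constant
  int BitWidth;   // bits per element
  RebuildFn RebuildConstant;
};

// First candidate whose rebuild succeeds wins; the table must be
// pre-sorted by ascending NumCstElts * BitWidth so the smallest constant
// pool entry is preferred (with vzload listed before broadcast at equal
// sizes, mirroring the comment in the original code).
int firstMatchingFixup(const Constant *C,
                       const std::vector<FixupEntry> &Fixups) {
  assert(std::is_sorted(Fixups.begin(), Fixups.end(),
                        [](const FixupEntry &A, const FixupEntry &B) {
                          return A.NumCstElts * A.BitWidth <
                                 B.NumCstElts * B.BitWidth;
                        }) &&
         "Fixup table not sorted in ascending constant size");
  for (const FixupEntry &Fixup : Fixups)
    if (Fixup.Op) // skip candidates disabled for this subtarget
      if (Constant *NewCst =
              Fixup.RebuildConstant(C, Fixup.NumCstElts, Fixup.BitWidth))
        return Fixup.Op; // real pass also re-points the MI at NewCst
  return 0;
}

int main() {
  // Pretend only the 64-bit rebuilds succeed for this constant.
  auto Fail = [](const Constant *, unsigned, unsigned) -> Constant * {
    return nullptr;
  };
  auto Succeed = [](const Constant *, unsigned, unsigned) -> Constant * {
    return &TheConstant;
  };
  std::vector<FixupEntry> Fixups = {{101, 1, 32, Fail},     // 32-bit vzload
                                    {102, 1, 32, Fail},     // 32-bit splat
                                    {103, 1, 64, Succeed},  // 64-bit vzload
                                    {104, 1, 64, Succeed}}; // 64-bit splat
  std::printf("chose opcode %d\n", firstMatchingFixup(&TheConstant, Fixups));
  return 0;
}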

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86FixupVectorConstants.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/X86/X86FixupVectorConstants.cpp b/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
index 037a745d632fb..be3c4f0b1564c 100644
--- a/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
+++ b/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
@@ -216,8 +216,8 @@ static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
 
 // Attempt to rebuild a normalized splat vector constant of the requested splat
 // width, built up of potentially smaller scalar values.
-static Constant *rebuildSplatableConstant(const Constant *C,
-                                          unsigned SplatBitWidth) {
+static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumElts*/,
+                                 unsigned SplatBitWidth) {
   std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
   if (!Splat)
     return nullptr;
@@ -238,8 +238,8 @@ static Constant *rebuildSplatableConstant(const Constant *C,
   return rebuildConstant(OriginalType->getContext(), SclTy, *Splat, NumSclBits);
 }
 
-static Constant *rebuildZeroUpperConstant(const Constant *C,
-                                          unsigned ScalarBitWidth) {
+static Constant *rebuildZeroUpperCst(const Constant *C, unsigned /*NumElts*/,
+                                     unsigned ScalarBitWidth) {
   Type *Ty = C->getType();
   Type *SclTy = Ty->getScalarType();
   unsigned NumBits = Ty->getPrimitiveSizeInBits();
@@ -265,8 +265,6 @@ static Constant *rebuildZeroUpperConstant(const Constant *C,
   return nullptr;
 }
 
-typedef std::function<Constant *(const Constant *, unsigned)> RebuildFn;
-
 bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
                                                      MachineBasicBlock &MBB,
                                                      MachineInstr &MI) {
@@ -277,43 +275,42 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
   bool HasBWI = ST->hasBWI();
   bool HasVLX = ST->hasVLX();
 
-  auto FixupConstant =
-      [&](unsigned OpBcst256, unsigned OpBcst128, unsigned OpBcst64,
-          unsigned OpBcst32, unsigned OpBcst16, unsigned OpBcst8,
-          unsigned OpUpper64, unsigned OpUpper32, unsigned OperandNo) {
-        assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
-               "Unexpected number of operands!");
-
-        if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
-          // Attempt to detect a suitable splat/vzload from increasing constant
-          // bitwidths.
-          // Prefer vzload vs broadcast for same bitwidth to avoid domain flips.
-          std::tuple<unsigned, unsigned, RebuildFn> FixupLoad[] = {
-              {8, OpBcst8, rebuildSplatableConstant},
-              {16, OpBcst16, rebuildSplatableConstant},
-              {32, OpUpper32, rebuildZeroUpperConstant},
-              {32, OpBcst32, rebuildSplatableConstant},
-              {64, OpUpper64, rebuildZeroUpperConstant},
-              {64, OpBcst64, rebuildSplatableConstant},
-              {128, OpBcst128, rebuildSplatableConstant},
-              {256, OpBcst256, rebuildSplatableConstant},
-          };
-          for (auto [BitWidth, Op, RebuildConstant] : FixupLoad) {
-            if (Op) {
-              // Construct a suitable constant and adjust the MI to use the new
-              // constant pool entry.
-              if (Constant *NewCst = RebuildConstant(C, BitWidth)) {
-                unsigned NewCPI =
-                    CP->getConstantPoolIndex(NewCst, Align(BitWidth / 8));
-                MI.setDesc(TII->get(Op));
-                MI.getOperand(OperandNo + X86::AddrDisp).setIndex(NewCPI);
-                return true;
-              }
-            }
+  struct FixupEntry {
+    int Op;
+    int NumCstElts;
+    int BitWidth;
+    std::function<Constant *(const Constant *, unsigned, unsigned)>
+        RebuildConstant;
+  };
+  auto FixupConstant = [&](ArrayRef<FixupEntry> Fixups, unsigned OperandNo) {
+#ifdef EXPENSIVE_CHECKS
+    assert(llvm::is_sorted(Fixups,
+                           [](const FixupEntry &A, const FixupEntry &B) {
+                             return (A.NumCstElts * A.BitWidth) <
+                                    (B.NumCstElts * B.BitWidth);
+                           }) &&
+           "Constant fixup table not sorted in ascending constant size");
+#endif
+    assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
+           "Unexpected number of operands!");
+    if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
+      for (const FixupEntry &Fixup : Fixups) {
+        if (Fixup.Op) {
+          // Construct a suitable constant and adjust the MI to use the new
+          // constant pool entry.
+          if (Constant *NewCst =
+                  Fixup.RebuildConstant(C, Fixup.NumCstElts, Fixup.BitWidth)) {
+            unsigned NewCPI =
+                CP->getConstantPoolIndex(NewCst, Align(Fixup.BitWidth / 8));
+            MI.setDesc(TII->get(Fixup.Op));
+            MI.getOperand(OperandNo + X86::AddrDisp).setIndex(NewCPI);
+            return true;
           }
         }
-        return false;
-      };
+      }
+    }
+    return false;
+  };
 
   // Attempt to convert full width vector loads into broadcast/vzload loads.
   switch (Opc) {
@@ -323,82 +320,125 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
   case X86::MOVUPDrm:
   case X86::MOVUPSrm:
     // TODO: SSE3 MOVDDUP Handling
-    return FixupConstant(0, 0, 0, 0, 0, 0, X86::MOVSDrm, X86::MOVSSrm, 1);
+    return FixupConstant({{X86::MOVSSrm, 1, 32, rebuildZeroUpperCst},
+                          {X86::MOVSDrm, 1, 64, rebuildZeroUpperCst}},
+                         1);
   case X86::VMOVAPDrm:
   case X86::VMOVAPSrm:
   case X86::VMOVUPDrm:
   case X86::VMOVUPSrm:
-    return FixupConstant(0, 0, X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0, 0,
-                         X86::VMOVSDrm, X86::VMOVSSrm, 1);
+    return FixupConstant({{X86::VMOVSSrm, 1, 32, rebuildZeroUpperCst},
+                          {X86::VBROADCASTSSrm, 1, 32, rebuildSplatCst},
+                          {X86::VMOVSDrm, 1, 64, rebuildZeroUpperCst},
+                          {X86::VMOVDDUPrm, 1, 64, rebuildSplatCst}},
+                         1);
   case X86::VMOVAPDYrm:
   case X86::VMOVAPSYrm:
   case X86::VMOVUPDYrm:
   case X86::VMOVUPSYrm:
-    return FixupConstant(0, X86::VBROADCASTF128rm, X86::VBROADCASTSDYrm,
-                         X86::VBROADCASTSSYrm, 0, 0, 0, 0, 1);
+    return FixupConstant({{X86::VBROADCASTSSYrm, 1, 32, rebuildSplatCst},
+                          {X86::VBROADCASTSDYrm, 1, 64, rebuildSplatCst},
+                          {X86::VBROADCASTF128rm, 1, 128, rebuildSplatCst}},
+                         1);
   case X86::VMOVAPDZ128rm:
   case X86::VMOVAPSZ128rm:
   case X86::VMOVUPDZ128rm:
   case X86::VMOVUPSZ128rm:
-    return FixupConstant(0, 0, X86::VMOVDDUPZ128rm, X86::VBROADCASTSSZ128rm, 0,
-                         0, X86::VMOVSDZrm, X86::VMOVSSZrm, 1);
+    return FixupConstant({{X86::VMOVSSZrm, 1, 32, rebuildZeroUpperCst},
+                          {X86::VBROADCASTSSZ128rm, 1, 32, rebuildSplatCst},
+                          {X86::VMOVSDZrm, 1, 64, rebuildZeroUpperCst},
+                          {X86::VMOVDDUPZ128rm, 1, 64, rebuildSplatCst}},
+                         1);
   case X86::VMOVAPDZ256rm:
   case X86::VMOVAPSZ256rm:
   case X86::VMOVUPDZ256rm:
   case X86::VMOVUPSZ256rm:
-    return FixupConstant(0, X86::VBROADCASTF32X4Z256rm, X86::VBROADCASTSDZ256rm,
-                         X86::VBROADCASTSSZ256rm, 0, 0, 0, 0, 1);
+    return FixupConstant(
+        {{X86::VBROADCASTSSZ256rm, 1, 32, rebuildSplatCst},
+         {X86::VBROADCASTSDZ256rm, 1, 64, rebuildSplatCst},
+         {X86::VBROADCASTF32X4Z256rm, 1, 128, rebuildSplatCst}},
+        1);
   case X86::VMOVAPDZrm:
   case X86::VMOVAPSZrm:
   case X86::VMOVUPDZrm:
   case X86::VMOVUPSZrm:
-    return FixupConstant(X86::VBROADCASTF64X4rm, X86::VBROADCASTF32X4rm,
-                         X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0, 0, 0, 0,
+    return FixupConstant({{X86::VBROADCASTSSZrm, 1, 32, rebuildSplatCst},
+                          {X86::VBROADCASTSDZrm, 1, 64, rebuildSplatCst},
+                          {X86::VBROADCASTF32X4rm, 1, 128, rebuildSplatCst},
+                          {X86::VBROADCASTF64X4rm, 1, 256, rebuildSplatCst}},
                          1);
     /* Integer Loads */
   case X86::MOVDQArm:
-  case X86::MOVDQUrm:
-    return FixupConstant(0, 0, 0, 0, 0, 0, X86::MOVQI2PQIrm, X86::MOVDI2PDIrm,
+  case X86::MOVDQUrm: {
+    return FixupConstant({{X86::MOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
+                          {X86::MOVQI2PQIrm, 1, 64, rebuildZeroUpperCst}},
                          1);
+  }
   case X86::VMOVDQArm:
-  case X86::VMOVDQUrm:
-    return FixupConstant(0, 0, HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm,
-                         HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm,
-                         HasAVX2 ? X86::VPBROADCASTWrm : 0,
-                         HasAVX2 ? X86::VPBROADCASTBrm : 0, X86::VMOVQI2PQIrm,
-                         X86::VMOVDI2PDIrm, 1);
+  case X86::VMOVDQUrm: {
+    FixupEntry Fixups[] = {
+        {HasAVX2 ? X86::VPBROADCASTBrm : 0, 1, 8, rebuildSplatCst},
+        {HasAVX2 ? X86::VPBROADCASTWrm : 0, 1, 16, rebuildSplatCst},
+        {X86::VMOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
+        {HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm, 1, 32,
+         rebuildSplatCst},
+        {X86::VMOVQI2PQIrm, 1, 64, rebuildZeroUpperCst},
+        {HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm, 1, 64,
+         rebuildSplatCst},
+    };
+    return FixupConstant(Fixups, 1);
+  }
   case X86::VMOVDQAYrm:
-  case X86::VMOVDQUYrm:
-    return FixupConstant(
-        0, HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm,
-        HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm,
-        HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm,
-        HasAVX2 ? X86::VPBROADCASTWYrm : 0, HasAVX2 ? X86::VPBROADCASTBYrm : 0,
-        0, 0, 1);
+  case X86::VMOVDQUYrm: {
+    FixupEntry Fixups[] = {
+        {HasAVX2 ? X86::VPBROADCASTBYrm : 0, 1, 8, rebuildSplatCst},
+        {HasAVX2 ? X86::VPBROADCASTWYrm : 0, 1, 16, rebuildSplatCst},
+        {HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm, 1, 32,
+         rebuildSplatCst},
+        {HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm, 1, 64,
+         rebuildSplatCst},
+        {HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm, 1, 128,
+         rebuildSplatCst}};
+    return FixupConstant(Fixups, 1);
+  }
   case X86::VMOVDQA32Z128rm:
   case X86::VMOVDQA64Z128rm:
   case X86::VMOVDQU32Z128rm:
-  case X86::VMOVDQU64Z128rm:
-    return FixupConstant(0, 0, X86::VPBROADCASTQZ128rm, X86::VPBROADCASTDZ128rm,
-                         HasBWI ? X86::VPBROADCASTWZ128rm : 0,
-                         HasBWI ? X86::VPBROADCASTBZ128rm : 0,
-                         X86::VMOVQI2PQIZrm, X86::VMOVDI2PDIZrm, 1);
+  case X86::VMOVDQU64Z128rm: {
+    FixupEntry Fixups[] = {
+        {HasBWI ? X86::VPBROADCASTBZ128rm : 0, 1, 8, rebuildSplatCst},
+        {HasBWI ? X86::VPBROADCASTWZ128rm : 0, 1, 16, rebuildSplatCst},
+        {X86::VMOVDI2PDIZrm, 1, 32, rebuildZeroUpperCst},
+        {X86::VPBROADCASTDZ128rm, 1, 32, rebuildSplatCst},
+        {X86::VMOVQI2PQIZrm, 1, 64, rebuildZeroUpperCst},
+        {X86::VPBROADCASTQZ128rm, 1, 64, rebuildSplatCst}};
+    return FixupConstant(Fixups, 1);
+  }
   case X86::VMOVDQA32Z256rm:
   case X86::VMOVDQA64Z256rm:
   case X86::VMOVDQU32Z256rm:
-  case X86::VMOVDQU64Z256rm:
-    return FixupConstant(0, X86::VBROADCASTI32X4Z256rm, X86::VPBROADCASTQZ256rm,
-                         X86::VPBROADCASTDZ256rm,
-                         HasBWI ? X86::VPBROADCASTWZ256rm : 0,
-                         HasBWI ? X86::VPBROADCASTBZ256rm : 0, 0, 0, 1);
+  case X86::VMOVDQU64Z256rm: {
+    FixupEntry Fixups[] = {
+        {HasBWI ? X86::VPBROADCASTBZ256rm : 0, 1, 8, rebuildSplatCst},
+        {HasBWI ? X86::VPBROADCASTWZ256rm : 0, 1, 16, rebuildSplatCst},
+        {X86::VPBROADCASTDZ256rm, 1, 32, rebuildSplatCst},
+        {X86::VPBROADCASTQZ256rm, 1, 64, rebuildSplatCst},
+        {X86::VBROADCASTI32X4Z256rm, 1, 128, rebuildSplatCst}};
+    return FixupConstant(Fixups, 1);
+  }
   case X86::VMOVDQA32Zrm:
   case X86::VMOVDQA64Zrm:
   case X86::VMOVDQU32Zrm:
-  case X86::VMOVDQU64Zrm:
-    return FixupConstant(X86::VBROADCASTI64X4rm, X86::VBROADCASTI32X4rm,
-                         X86::VPBROADCASTQZrm, X86::VPBROADCASTDZrm,
-                         HasBWI ? X86::VPBROADCASTWZrm : 0,
-                         HasBWI ? X86::VPBROADCASTBZrm : 0, 0, 0, 1);
+  case X86::VMOVDQU64Zrm: {
+    FixupEntry Fixups[] = {
+        {HasBWI ? X86::VPBROADCASTBZrm : 0, 1, 8, rebuildSplatCst},
+        {HasBWI ? X86::VPBROADCASTWZrm : 0, 1, 16, rebuildSplatCst},
+        {X86::VPBROADCASTDZrm, 1, 32, rebuildSplatCst},
+        {X86::VPBROADCASTQZrm, 1, 64, rebuildSplatCst},
+        {X86::VBROADCASTI32X4rm, 1, 128, rebuildSplatCst},
+        {X86::VBROADCASTI64X4rm, 1, 256, rebuildSplatCst}};
+    return FixupConstant(Fixups, 1);
+  }
   }
 
   auto ConvertToBroadcastAVX512 = [&](unsigned OpSrc32, unsigned OpSrc64) {
@@ -423,7 +463,9 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
 
     if (OpBcst32 || OpBcst64) {
       unsigned OpNo = OpBcst32 == 0 ? OpNoBcst64 : OpNoBcst32;
-      return FixupConstant(0, 0, OpBcst64, OpBcst32, 0, 0, 0, 0, OpNo);
+      FixupEntry Fixups[] = {{(int)OpBcst32, 32, 32, rebuildSplatCst},
+                             {(int)OpBcst64, 64, 64, rebuildSplatCst}};
+      return FixupConstant(Fixups, OpNo);
     }
     return false;
    };