[llvm] [WIP][X86] lowerBuildVectorAsBroadcast - don't convert constant vectors to broadcasts on AVX512VL targets (PR #73509)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 20 23:30:44 PDT 2024


================
@@ -586,43 +722,93 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
         {X86::VPMOVZXDQZrm, 8, 32, rebuildZExtCst}};
     return FixupConstant(Fixups, 512, 1);
   }
+  case X86::VMOVDQA32Zrmk:
+  case X86::VMOVDQU32Zrmk:
+    return FixupConstant({{X86::VPBROADCASTDZrmk, 1, 32, rebuildSplatCst},
+                          {X86::VBROADCASTI32X4rmk, 1, 128, rebuildSplatCst},
+                          {X86::VPMOVSXBDZrmk, 16, 8, rebuildSExtCst},
+                          {X86::VPMOVZXBDZrmk, 16, 8, rebuildZExtCst},
+                          {X86::VPMOVSXWDZrmk, 16, 16, rebuildSExtCst},
+                          {X86::VPMOVZXWDZrmk, 16, 16, rebuildZExtCst}},
+                         512, 3);
+  case X86::VMOVDQA32Zrmkz:
+  case X86::VMOVDQU32Zrmkz:
+    return FixupConstant({{X86::VPBROADCASTDZrmkz, 1, 32, rebuildSplatCst},
+                          {X86::VBROADCASTI32X4rmkz, 1, 128, rebuildSplatCst},
+                          {X86::VPMOVSXBDZrmkz, 16, 8, rebuildSExtCst},
+                          {X86::VPMOVZXBDZrmkz, 16, 8, rebuildZExtCst},
+                          {X86::VPMOVSXWDZrmkz, 16, 16, rebuildSExtCst},
+                          {X86::VPMOVZXWDZrmkz, 16, 16, rebuildZExtCst}},
+                         512, 2);
+  case X86::VMOVDQA64Zrmk:
+  case X86::VMOVDQU64Zrmk:
+    return FixupConstant({{X86::VPBROADCASTQZrmk, 1, 64, rebuildSplatCst},
+                          {X86::VPMOVSXBQZrmk, 8, 8, rebuildSExtCst},
+                          {X86::VPMOVZXBQZrmk, 8, 8, rebuildZExtCst},
+                          {X86::VPMOVSXWQZrmk, 8, 16, rebuildSExtCst},
+                          {X86::VPMOVZXWQZrmk, 8, 16, rebuildZExtCst},
+                          {X86::VBROADCASTI64X4rmk, 1, 256, rebuildSplatCst},
+                          {X86::VPMOVSXDQZrmk, 8, 32, rebuildSExtCst},
+                          {X86::VPMOVZXDQZrmk, 8, 32, rebuildZExtCst}},
+                         512, 3);
+  case X86::VMOVDQA64Zrmkz:
+  case X86::VMOVDQU64Zrmkz:
+    return FixupConstant({{X86::VPBROADCASTQZrmkz, 1, 64, rebuildSplatCst},
+                          {X86::VPMOVSXBQZrmkz, 8, 8, rebuildSExtCst},
+                          {X86::VPMOVZXBQZrmkz, 8, 8, rebuildZExtCst},
+                          {X86::VPMOVSXWQZrmkz, 8, 16, rebuildSExtCst},
+                          {X86::VPMOVZXWQZrmkz, 8, 16, rebuildZExtCst},
+                          {X86::VBROADCASTI64X4rmkz, 1, 256, rebuildSplatCst},
+                          {X86::VPMOVSXDQZrmkz, 8, 32, rebuildSExtCst},
+                          {X86::VPMOVZXDQZrmkz, 8, 32, rebuildZExtCst}},
+                         512, 2);
   }
 
-  auto ConvertToBroadcastAVX512 = [&](unsigned OpSrc32, unsigned OpSrc64) {
-    unsigned OpBcst32 = 0, OpBcst64 = 0;
-    unsigned OpNoBcst32 = 0, OpNoBcst64 = 0;
+  auto ConvertToBroadcastAVX512 = [&](unsigned OpSrc16, unsigned OpSrc32,
+                                      unsigned OpSrc64) {
+    if (OpSrc16) {
+      if (const X86FoldTableEntry *Mem2Bcst =
+              llvm::lookupBroadcastFoldTableBySize(OpSrc16, 16)) {
+        unsigned OpBcst16 = Mem2Bcst->DstOp;
+        unsigned OpNoBcst16 = Mem2Bcst->Flags & TB_INDEX_MASK;
+        FixupEntry Fixups[] = {{(int)OpBcst16, 1, 16, rebuildSplatCst}};
+        // TODO: Add support for RegBitWidth, but currently rebuildSplatCst
+        // doesn't require it (defaults to Constant::getPrimitiveSizeInBits).
+        if (FixupConstant(Fixups, 0, OpNoBcst16))
+          return true;
+      }
+    }
     if (OpSrc32) {
       if (const X86FoldTableEntry *Mem2Bcst =
               llvm::lookupBroadcastFoldTableBySize(OpSrc32, 32)) {
-        OpBcst32 = Mem2Bcst->DstOp;
-        OpNoBcst32 = Mem2Bcst->Flags & TB_INDEX_MASK;
+        unsigned OpBcst32 = Mem2Bcst->DstOp;
+        unsigned OpNoBcst32 = Mem2Bcst->Flags & TB_INDEX_MASK;
+        FixupEntry Fixups[] = {{(int)OpBcst32, 1, 32, rebuildSplatCst}};
+        // TODO: Add support for RegBitWidth, but currently rebuildSplatCst
+        // doesn't require it (defaults to Constant::getPrimitiveSizeInBits).
+        if (FixupConstant(Fixups, 0, OpNoBcst32))
+          return true;
       }
     }
     if (OpSrc64) {
       if (const X86FoldTableEntry *Mem2Bcst =
               llvm::lookupBroadcastFoldTableBySize(OpSrc64, 64)) {
-        OpBcst64 = Mem2Bcst->DstOp;
-        OpNoBcst64 = Mem2Bcst->Flags & TB_INDEX_MASK;
+        unsigned OpBcst64 = Mem2Bcst->DstOp;
+        unsigned OpNoBcst64 = Mem2Bcst->Flags & TB_INDEX_MASK;
+        FixupEntry Fixups[] = {{(int)OpBcst64, 1, 64, rebuildSplatCst}};
+        // TODO: Add support for RegBitWidth, but currently rebuildSplatCst
+        // doesn't require it (defaults to Constant::getPrimitiveSizeInBits).
+        if (FixupConstant(Fixups, 0, OpNoBcst64))
+          return true;
----------------
goldsteinn wrote:

Maybe make the inside of the condition a lambda to avoid the 3x duplication?
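
For reference, a rough sketch of what that suggestion could look like: the per-width body of `ConvertToBroadcastAVX512` factored into a local helper lambda (the name `TryBcst` is made up here, and this assumes the `FixupEntry`/`FixupConstant` definitions from the patch):

```cpp
// Hypothetical refactor: one lambda replacing the three duplicated
// OpSrc16/OpSrc32/OpSrc64 blocks inside ConvertToBroadcastAVX512.
auto TryBcst = [&](unsigned OpSrc, unsigned BitWidth) {
  if (!OpSrc)
    return false;
  const X86FoldTableEntry *Mem2Bcst =
      llvm::lookupBroadcastFoldTableBySize(OpSrc, BitWidth);
  if (!Mem2Bcst)
    return false;
  unsigned OpBcst = Mem2Bcst->DstOp;
  unsigned OpNoBcst = Mem2Bcst->Flags & TB_INDEX_MASK;
  FixupEntry Fixups[] = {{(int)OpBcst, 1, (int)BitWidth, rebuildSplatCst}};
  // TODO: Add support for RegBitWidth, but currently rebuildSplatCst
  // doesn't require it (defaults to Constant::getPrimitiveSizeInBits).
  return FixupConstant(Fixups, 0, OpNoBcst);
};
return TryBcst(OpSrc16, 16) || TryBcst(OpSrc32, 32) || TryBcst(OpSrc64, 64);
```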

https://github.com/llvm/llvm-project/pull/73509
