[Mlir-commits] [mlir] a776942 - [mlir] squash LLVM_AVX512 dialect into AVX512

Alex Zinenko llvmlistbot at llvm.org
Wed Mar 10 04:07:35 PST 2021


Author: Alex Zinenko
Date: 2021-03-10T13:07:26+01:00
New Revision: a776942ba1aa0e381dd41a01de4b54fc5dc431cd

URL: https://github.com/llvm/llvm-project/commit/a776942ba1aa0e381dd41a01de4b54fc5dc431cd
DIFF: https://github.com/llvm/llvm-project/commit/a776942ba1aa0e381dd41a01de4b54fc5dc431cd.diff

LOG: [mlir] squash LLVM_AVX512 dialect into AVX512

The dialect separation was introduced to demarcate ops operating in different
type systems. This is no longer the case after the LLVM dialect has migrated to
using built-in vector types, so the original reason for separation is no longer
valid. Squash the two dialects into one.

The code size decrease isn't very large: the ops originally in LLVM_AVX512 are
preserved because they match LLVM IR intrinsics specialized for vector element
bitwidth. However, it is still conceptually beneficial to have only one
dialect. I originally considered using TableGen multiclasses to define both
the type-polymorphic op and its two intrinsic-related instantiations, but
decided against it given both the complexity of the required TableGen input and
its dissimilarity with the rest of ODS-defined ops, both potentially resulting
in very poor maintainability.

Depends On D98327

Reviewed By: nicolasvasilache, springerm

Differential Revision: https://reviews.llvm.org/D98328

Added: 
    mlir/include/mlir/Dialect/AVX512/Transforms.h
    mlir/include/mlir/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.h
    mlir/lib/Dialect/AVX512/IR/CMakeLists.txt
    mlir/lib/Dialect/AVX512/Transforms/CMakeLists.txt
    mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp
    mlir/lib/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.cpp
    mlir/lib/Target/LLVMIR/Dialect/AVX512/CMakeLists.txt
    mlir/test/Dialect/AVX512/legalize-for-llvm.mlir

Modified: 
    mlir/include/mlir/Dialect/AVX512/AVX512.td
    mlir/include/mlir/Dialect/AVX512/CMakeLists.txt
    mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
    mlir/include/mlir/InitAllDialects.h
    mlir/include/mlir/Target/LLVMIR/Dialect/All.h
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Dialect/AVX512/CMakeLists.txt
    mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp
    mlir/lib/Dialect/LLVMIR/CMakeLists.txt
    mlir/lib/Target/LLVMIR/CMakeLists.txt
    mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
    mlir/test/Target/LLVMIR/avx512.mlir
    mlir/test/mlir-opt/commandline.mlir

Removed: 
    mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h
    mlir/include/mlir/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.h
    mlir/lib/Conversion/AVX512ToLLVM/CMakeLists.txt
    mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/LLVMAVX512/CMakeLists.txt
    mlir/lib/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.cpp
    mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir


################################################################################
diff  --git a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h
deleted file mode 100644
index 06f2958a2d5a..000000000000
--- a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h
+++ /dev/null
@@ -1,23 +0,0 @@
-//===- ConvertAVX512ToLLVM.h - Conversion Patterns from AVX512 to LLVM ----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_
-#define MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_
-
-namespace mlir {
-
-class LLVMTypeConverter;
-class OwningRewritePatternList;
-
-/// Collect a set of patterns to convert from the AVX512 dialect to LLVM.
-void populateAVX512ToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                            OwningRewritePatternList &patterns);
-
-} // namespace mlir
-
-#endif // MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_

diff  --git a/mlir/include/mlir/Dialect/AVX512/AVX512.td b/mlir/include/mlir/Dialect/AVX512/AVX512.td
index 391ce74ebdd8..0a32988d684d 100644
--- a/mlir/include/mlir/Dialect/AVX512/AVX512.td
+++ b/mlir/include/mlir/Dialect/AVX512/AVX512.td
@@ -14,6 +14,7 @@
 #define AVX512_OPS
 
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 
 //===----------------------------------------------------------------------===//
 // AVX512 dialect definition
@@ -31,6 +32,24 @@ def AVX512_Dialect : Dialect {
 class AVX512_Op<string mnemonic, list<OpTrait> traits = []> :
   Op<AVX512_Dialect, mnemonic, traits> {}
 
+class AVX512_IntrOp<string mnemonic, int numResults, list<OpTrait> traits = []> :
+  LLVM_IntrOpBase<AVX512_Dialect, "intr." # mnemonic,
+                  "x86_avx512_" # !subst(".", "_", mnemonic),
+                  [], [], traits, numResults>;
+
+// Defined by first result overload. May have to be extended for other
+// instructions in the future.
+class AVX512_IntrOverloadedOp<string mnemonic,
+                              list<OpTrait> traits = []> :
+  LLVM_IntrOpBase<AVX512_Dialect, "intr." # mnemonic,
+                  "x86_avx512_" # !subst(".", "_", mnemonic),
+                  /*list<int> overloadedResults=*/[0],
+                  /*list<int> overloadedOperands=*/[],
+                  traits, /*numResults=*/1>;
+//----------------------------------------------------------------------------//
+// MaskCompressOp
+//----------------------------------------------------------------------------//
+
 def MaskCompressOp : AVX512_Op<"mask.compress", [NoSideEffect,
   // TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could
   // then be removed from assemblyFormat.
@@ -67,6 +86,25 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [NoSideEffect,
                        " `:` type($dst) (`,` type($src)^)?";
 }
 
+def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [
+  NoSideEffect,
+  AllTypesMatch<["a", "src", "res"]>,
+  TypesMatchWith<"`k` has the same number of bits as elements in `res`",
+                 "res", "k",
+                 "VectorType::get({$_self.cast<VectorType>().getShape()[0]}, "
+                 "IntegerType::get($_self.getContext(), 1))">]> {
+  let arguments = (ins VectorOfLengthAndType<[16, 8],
+                                             [F32, I32, F64, I64]>:$a,
+                   VectorOfLengthAndType<[16, 8],
+                                         [F32, I32, F64, I64]>:$src,
+                   VectorOfLengthAndType<[16, 8],
+                                         [I1]>:$k);
+}
+
+//----------------------------------------------------------------------------//
+// MaskRndScaleOp
+//----------------------------------------------------------------------------//
+
 def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [NoSideEffect,
   AllTypesMatch<["src", "a", "dst"]>,
   TypesMatchWith<"imm has the same number of bits as elements in dst",
@@ -99,6 +137,30 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [NoSideEffect,
     "$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict `:` type($dst)";
 }
 
+def MaskRndScalePSIntrOp : AVX512_IntrOp<"mask.rndscale.ps.512", 1, [
+  NoSideEffect,
+  AllTypesMatch<["src", "a", "res"]>]> {
+  let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
+                   I32:$k,
+                   VectorOfLengthAndType<[16], [F32]>:$a,
+                   I16:$imm,
+                   LLVM_Type:$rounding);
+}
+
+def MaskRndScalePDIntrOp : AVX512_IntrOp<"mask.rndscale.pd.512", 1, [
+  NoSideEffect,
+  AllTypesMatch<["src", "a", "res"]>]> {
+  let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
+                   I32:$k,
+                   VectorOfLengthAndType<[8], [F64]>:$a,
+                   I8:$imm,
+                   LLVM_Type:$rounding);
+}
+
+//----------------------------------------------------------------------------//
+// MaskScaleFOp
+//----------------------------------------------------------------------------//
+
 def MaskScaleFOp : AVX512_Op<"mask.scalef", [NoSideEffect,
   AllTypesMatch<["src", "a", "b", "dst"]>,
   TypesMatchWith<"k has the same number of bits as elements in dst",
@@ -132,6 +194,30 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [NoSideEffect,
     "$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)";
 }
 
+def MaskScaleFPSIntrOp : AVX512_IntrOp<"mask.scalef.ps.512", 1, [
+  NoSideEffect,
+  AllTypesMatch<["src", "a", "b", "res"]>]> {
+  let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
+                   VectorOfLengthAndType<[16], [F32]>:$a,
+                   VectorOfLengthAndType<[16], [F32]>:$b,
+                   I16:$k,
+                   LLVM_Type:$rounding);
+}
+
+def MaskScaleFPDIntrOp : AVX512_IntrOp<"mask.scalef.pd.512", 1, [
+  NoSideEffect,
+  AllTypesMatch<["src", "a", "b", "res"]>]> {
+  let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
+                   VectorOfLengthAndType<[8], [F64]>:$a,
+                   VectorOfLengthAndType<[8], [F64]>:$b,
+                   I8:$k,
+                   LLVM_Type:$rounding);
+}
+
+//----------------------------------------------------------------------------//
+// Vp2IntersectOp
+//----------------------------------------------------------------------------//
+
 def Vp2IntersectOp : AVX512_Op<"vp2intersect", [NoSideEffect,
   AllTypesMatch<["a", "b"]>,
   TypesMatchWith<"k1 has the same number of bits as elements in a",
@@ -169,4 +255,16 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [NoSideEffect,
     "$a `,` $b attr-dict `:` type($a)";
 }
 
+def Vp2IntersectDIntrOp : AVX512_IntrOp<"vp2intersect.d.512", 2, [
+  NoSideEffect]> {
+  let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$a,
+                   VectorOfLengthAndType<[16], [I32]>:$b);
+}
+
+def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
+  NoSideEffect]> {
+  let arguments = (ins VectorOfLengthAndType<[8], [I64]>:$a,
+                   VectorOfLengthAndType<[8], [I64]>:$b);
+}
+
 #endif // AVX512_OPS

diff  --git a/mlir/include/mlir/Dialect/AVX512/CMakeLists.txt b/mlir/include/mlir/Dialect/AVX512/CMakeLists.txt
index fc3cb911ba80..07ea8817bc0b 100644
--- a/mlir/include/mlir/Dialect/AVX512/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AVX512/CMakeLists.txt
@@ -1,2 +1,6 @@
 add_mlir_dialect(AVX512 avx512)
 add_mlir_doc(AVX512 -gen-dialect-doc AVX512 Dialects/)
+
+set(LLVM_TARGET_DEFINITIONS AVX512.td)
+mlir_tablegen(AVX512Conversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(MLIRAVX512ConversionsIncGen)

diff  --git a/mlir/include/mlir/Dialect/AVX512/Transforms.h b/mlir/include/mlir/Dialect/AVX512/Transforms.h
new file mode 100644
index 000000000000..3506f50dc258
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AVX512/Transforms.h
@@ -0,0 +1,29 @@
+//===- Transforms.h - AVX512 Dialect Transformation Entrypoints -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AVX512_TRANSFORMS_H
+#define MLIR_DIALECT_AVX512_TRANSFORMS_H
+
+namespace mlir {
+
+class LLVMConversionTarget;
+class LLVMTypeConverter;
+class OwningRewritePatternList;
+
+/// Collect a set of patterns to lower AVX512 ops to ops that map to LLVM
+/// intrinsics.
+void populateAVX512LegalizeForLLVMExportPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
+/// Configure the target to support lowering AVX512 ops to ops that map to LLVM
+/// intrinsics.
+void configureAVX512LegalizeForExportTarget(LLVMConversionTarget &target);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AVX512_TRANSFORMS_H

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 20b989616d7a..7db04cfb614f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -35,12 +35,6 @@ set(LLVM_TARGET_DEFINITIONS ROCDLOps.td)
 mlir_tablegen(ROCDLConversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRROCDLConversionsIncGen)
 
-add_mlir_dialect(LLVMAVX512 llvm_avx512 LLVMAVX512)
-add_mlir_doc(LLVMAVX512 -gen-dialect-doc LLVMAVX512 Dialects/)
-set(LLVM_TARGET_DEFINITIONS LLVMAVX512.td)
-mlir_tablegen(LLVMAVX512Conversions.inc -gen-llvmir-conversions)
-add_public_tablegen_target(MLIRLLVMAVX512ConversionsIncGen)
-
 add_mlir_dialect(LLVMArmSVE llvm_arm_sve LLVMArmSVE)
 add_mlir_doc(LLVMArmSVE -gen-dialect-doc LLVMArmSve Dialects/)
 set(LLVM_TARGET_DEFINITIONS LLVMArmSVE.td)

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td
deleted file mode 100644
index 20fb8030c8b1..000000000000
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td
+++ /dev/null
@@ -1,74 +0,0 @@
-//===-- LLVMAVX512.td - LLVMAVX512 dialect op definitions --*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the basic operations for the LLVMAVX512 dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVMIR_AVX512_OPS
-#define LLVMIR_AVX512_OPS
-
-include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
-
-//===----------------------------------------------------------------------===//
-// LLVMAVX512 dialect definition
-//===----------------------------------------------------------------------===//
-
-def LLVMAVX512_Dialect : Dialect {
-  let name = "llvm_avx512";
-  let cppNamespace = "::mlir::LLVM";
-}
-
-//----------------------------------------------------------------------------//
-// MLIR LLVM AVX512 intrinsics using the MLIR LLVM Dialect type system
-//----------------------------------------------------------------------------//
-
-class LLVMAVX512_IntrOp<string mnemonic, int numResults, list<OpTrait> traits = []> :
-  LLVM_IntrOpBase<LLVMAVX512_Dialect, mnemonic,
-                  "x86_avx512_" # !subst(".", "_", mnemonic),
-                  [], [], traits, numResults>;
-
-// Defined by first result overload. May have to be extended for other
-// instructions in the future.
-class LLVMAVX512_IntrOverloadedOp<string mnemonic,
-                                  list<OpTrait> traits = []> :
-  LLVM_IntrOpBase<LLVMAVX512_Dialect, mnemonic,
-                  "x86_avx512_" # !subst(".", "_", mnemonic),
-                  /*list<int> overloadedResults=*/[0],
-                  /*list<int> overloadedOperands=*/[],
-                  traits, /*numResults=*/1>;
-
-def LLVM_x86_avx512_mask_rndscale_ps_512 :
-  LLVMAVX512_IntrOp<"mask.rndscale.ps.512", 1>,
-  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-def LLVM_x86_avx512_mask_rndscale_pd_512 :
-  LLVMAVX512_IntrOp<"mask.rndscale.pd.512", 1>,
-  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-def LLVM_x86_avx512_mask_scalef_ps_512 :
-  LLVMAVX512_IntrOp<"mask.scalef.ps.512", 1>,
-  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-def LLVM_x86_avx512_mask_scalef_pd_512 :
-  LLVMAVX512_IntrOp<"mask.scalef.pd.512", 1>,
-  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-def LLVM_x86_avx512_mask_compress :
-  LLVMAVX512_IntrOverloadedOp<"mask.compress">,
-  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-def LLVM_x86_avx512_vp2intersect_d_512 :
-  LLVMAVX512_IntrOp<"vp2intersect.d.512", 2>,
-  Arguments<(ins LLVM_Type, LLVM_Type)>;
-
-def LLVM_x86_avx512_vp2intersect_q_512 :
-  LLVMAVX512_IntrOp<"vp2intersect.q.512", 2>,
-  Arguments<(ins LLVM_Type, LLVM_Type)>;
-
-#endif // AVX512_OPS

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h
deleted file mode 100644
index c028fda514fe..000000000000
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h
+++ /dev/null
@@ -1,24 +0,0 @@
-//===- LLVMAVX512Dialect.h - MLIR Dialect for LLVMAVX512 --------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file declares the Target dialect for LLVMAVX512 in MLIR.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_LLVMIR_LLVMAVX512DIALECT_H_
-#define MLIR_DIALECT_LLVMIR_LLVMAVX512DIALECT_H_
-
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/OpDefinition.h"
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/LLVMIR/LLVMAVX512.h.inc"
-
-#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h.inc"
-
-#endif // MLIR_DIALECT_LLVMIR_LLVMAVX512DIALECT_H_

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c1e2a7c31bb0..f53b9e1dded6 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -21,7 +21,6 @@
 #include "mlir/Dialect/Async/IR/Async.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
@@ -55,7 +54,6 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   avx512::AVX512Dialect,
                   complex::ComplexDialect,
                   gpu::GPUDialect,
-                  LLVM::LLVMAVX512Dialect,
                   LLVM::LLVMDialect,
                   LLVM::LLVMArmSVEDialect,
                   linalg::LinalgDialect,

diff  --git a/mlir/include/mlir/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.h
new file mode 100644
index 000000000000..a9e4cf3cb3b9
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.h
@@ -0,0 +1,32 @@
+//===- AVX512ToLLVMIRTranslation.h - AVX512 to LLVM IR ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for AVX512 dialect to LLVM IR
+// translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_DIALECT_AVX512_AVX512TOLLVMIRTRANSLATION_H
+#define MLIR_TARGET_LLVMIR_DIALECT_AVX512_AVX512TOLLVMIRTRANSLATION_H
+
+namespace mlir {
+
+class DialectRegistry;
+class MLIRContext;
+
+/// Register the AVX512 dialect and the translation from it to the LLVM IR
+/// in the given registry;
+void registerAVX512DialectTranslation(DialectRegistry &registry);
+
+/// Register the AVX512 dialect and the translation from it in the registry
+/// associated with the given context.
+void registerAVX512DialectTranslation(MLIRContext &context);
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_DIALECT_AVX512_AVX512TOLLVMIRTRANSLATION_H

diff  --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
index d07e1fa8df47..97189cb78619 100644
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
@@ -14,8 +14,8 @@
 #ifndef MLIR_TARGET_LLVMIR_DIALECT_ALL_H
 #define MLIR_TARGET_LLVMIR_DIALECT_ALL_H
 
+#include "mlir/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMArmSVE/LLVMArmSVEToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
@@ -29,7 +29,7 @@ class DialectRegistry;
 /// corresponding translation interfaces.
 static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
   registerArmNeonDialectTranslation(registry);
-  registerLLVMAVX512DialectTranslation(registry);
+  registerAVX512DialectTranslation(registry);
   registerLLVMArmSVEDialectTranslation(registry);
   registerLLVMDialectTranslation(registry);
   registerNVVMDialectTranslation(registry);

diff  --git a/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.h
deleted file mode 100644
index 7a871c66edf0..000000000000
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.h
+++ /dev/null
@@ -1,32 +0,0 @@
-//===- LLVMAVX512ToLLVMIRTranslation.h - LLVMAVX512 to LLVM IR --*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This provides registration calls for LLVMAVX512 dialect to LLVM IR
-// translation.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TARGET_LLVMIR_DIALECT_LLVMAVX512_LLVMAVX512TOLLVMIRTRANSLATION_H
-#define MLIR_TARGET_LLVMIR_DIALECT_LLVMAVX512_LLVMAVX512TOLLVMIRTRANSLATION_H
-
-namespace mlir {
-
-class DialectRegistry;
-class MLIRContext;
-
-/// Register the LLVMAVX512 dialect and the translation from it to the LLVM IR
-/// in the given registry;
-void registerLLVMAVX512DialectTranslation(DialectRegistry &registry);
-
-/// Register the LLVMAVX512 dialect and the translation from it in the registry
-/// associated with the given context.
-void registerLLVMAVX512DialectTranslation(MLIRContext &context);
-
-} // namespace mlir
-
-#endif // MLIR_TARGET_LLVMIR_DIALECT_LLVMAVX512_LLVMAVX512TOLLVMIRTRANSLATION_H

diff  --git a/mlir/lib/Conversion/AVX512ToLLVM/CMakeLists.txt b/mlir/lib/Conversion/AVX512ToLLVM/CMakeLists.txt
deleted file mode 100644
index d3257b136cf1..000000000000
--- a/mlir/lib/Conversion/AVX512ToLLVM/CMakeLists.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-add_mlir_conversion_library(MLIRAVX512ToLLVM
-  ConvertAVX512ToLLVM.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/AVX512ToLLVM
-
-  DEPENDS
-  MLIRConversionPassIncGen
-
-  LINK_COMPONENTS
-  Core
-
-  LINK_LIBS PUBLIC
-  MLIRAVX512
-  MLIRLLVMAVX512
-  MLIRLLVMIR
-  MLIRStandardToLLVM
-  MLIRTransforms
-  )

diff  --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
deleted file mode 100644
index 74b919717283..000000000000
--- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
+++ /dev/null
@@ -1,143 +0,0 @@
-//===- ConvertAVX512ToLLVM.cpp - Convert AVX512 to the LLVM dialect -------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
-
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
-#include "mlir/Dialect/AVX512/AVX512Dialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/PatternMatch.h"
-
-using namespace mlir;
-using namespace mlir::vector;
-using namespace mlir::avx512;
-
-template <typename OpTy>
-static Type getSrcVectorElementType(Operation *op) {
-  return cast<OpTy>(op)
-      .src()
-      .getType()
-      .template cast<VectorType>()
-      .getElementType();
-}
-
-namespace {
-
-// TODO: turn these into simpler declarative templated patterns when we've had
-// enough.
-struct MaskRndScaleOp512Conversion : public ConvertToLLVMPattern {
-  explicit MaskRndScaleOp512Conversion(MLIRContext *context,
-                                       LLVMTypeConverter &typeConverter)
-      : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
-                             typeConverter) {}
-
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    Type elementType = getSrcVectorElementType<MaskRndScaleOp>(op);
-    if (elementType.isF32())
-      return LLVM::detail::oneToOneRewrite(
-          op, LLVM::x86_avx512_mask_rndscale_ps_512::getOperationName(),
-          operands, *getTypeConverter(), rewriter);
-    if (elementType.isF64())
-      return LLVM::detail::oneToOneRewrite(
-          op, LLVM::x86_avx512_mask_rndscale_pd_512::getOperationName(),
-          operands, *getTypeConverter(), rewriter);
-    return failure();
-  }
-};
-
-struct MaskCompressOpConversion
-    : public ConvertOpToLLVMPattern<MaskCompressOp> {
-  using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(MaskCompressOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    MaskCompressOp::Adaptor adaptor(operands);
-    auto opType = adaptor.a().getType();
-
-    Value src;
-    if (op.src()) {
-      src = adaptor.src();
-    } else if (op.constant_src()) {
-      src = rewriter.create<ConstantOp>(op.getLoc(), opType,
-                                        op.constant_srcAttr());
-    } else {
-      Attribute zeroAttr = rewriter.getZeroAttr(opType);
-      src = rewriter.create<ConstantOp>(op->getLoc(), opType, zeroAttr);
-    }
-
-    rewriter.replaceOpWithNewOp<LLVM::x86_avx512_mask_compress>(
-        op, opType, adaptor.a(), src, adaptor.k());
-
-    return success();
-  }
-};
-
-struct ScaleFOp512Conversion : public ConvertToLLVMPattern {
-  explicit ScaleFOp512Conversion(MLIRContext *context,
-                                 LLVMTypeConverter &typeConverter)
-      : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
-                             typeConverter) {}
-
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    Type elementType = getSrcVectorElementType<MaskScaleFOp>(op);
-    if (elementType.isF32())
-      return LLVM::detail::oneToOneRewrite(
-          op, LLVM::x86_avx512_mask_scalef_ps_512::getOperationName(), operands,
-          *getTypeConverter(), rewriter);
-    if (elementType.isF64())
-      return LLVM::detail::oneToOneRewrite(
-          op, LLVM::x86_avx512_mask_scalef_pd_512::getOperationName(), operands,
-          *getTypeConverter(), rewriter);
-    return failure();
-  }
-};
-
-struct Vp2IntersectOp512Conversion
-    : public ConvertOpToLLVMPattern<Vp2IntersectOp> {
-  explicit Vp2IntersectOp512Conversion(MLIRContext *context,
-                                       LLVMTypeConverter &typeConverter)
-      : ConvertOpToLLVMPattern<Vp2IntersectOp>(typeConverter) {}
-
-  LogicalResult
-  matchAndRewrite(Vp2IntersectOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    Type elementType =
-        op.a().getType().template cast<VectorType>().getElementType();
-    if (elementType.isInteger(32))
-      return LLVM::detail::oneToOneRewrite(
-          op, LLVM::x86_avx512_vp2intersect_d_512::getOperationName(), operands,
-          *getTypeConverter(), rewriter);
-    if (elementType.isInteger(64))
-      return LLVM::detail::oneToOneRewrite(
-          op, LLVM::x86_avx512_vp2intersect_q_512::getOperationName(), operands,
-          *getTypeConverter(), rewriter);
-    return failure();
-  }
-};
-} // namespace
-
-/// Populate the given list with patterns that convert from AVX512 to LLVM.
-void mlir::populateAVX512ToLLVMConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  // clang-format off
-  patterns.insert<MaskRndScaleOp512Conversion,
-                  ScaleFOp512Conversion,
-                  Vp2IntersectOp512Conversion>(&converter.getContext(),
-                                               converter);
-  patterns.insert<MaskCompressOpConversion>(converter);
-  // clang-format on
-}

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index d17ef3cf3e2e..862e0aafabf1 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_subdirectory(AffineToStandard)
 add_subdirectory(AsyncToLLVM)
-add_subdirectory(AVX512ToLLVM)
 add_subdirectory(ComplexToLLVM)
 add_subdirectory(GPUCommon)
 add_subdirectory(GPUToNVVM)

diff  --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index e15c3793eb3d..806f4d497017 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -15,8 +15,7 @@ add_mlir_conversion_library(MLIRVectorToLLVM
   LINK_LIBS PUBLIC
   MLIRArmNeon
   MLIRAVX512
-  MLIRAVX512ToLLVM
-  MLIRLLVMAVX512
+  MLIRAVX512Transforms
   MLIRArmSVE
   MLIRArmSVEToLLVM
   MLIRLLVMArmSVE

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 686ee82e450a..3be8cbfe770d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -10,14 +10,13 @@
 
 #include "../PassDetail.h"
 
-#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
 #include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/AVX512/AVX512Dialect.h"
+#include "mlir/Dialect/AVX512/Transforms.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -45,7 +44,7 @@ struct LowerVectorToLLVMPass
     if (enableArmSVE)
       registry.insert<LLVM::LLVMArmSVEDialect>();
     if (enableAVX512)
-      registry.insert<LLVM::LLVMAVX512Dialect>();
+      registry.insert<avx512::AVX512Dialect>();
   }
   void runOnOperation() override;
 };
@@ -104,9 +103,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
     populateArmSVEToLLVMConversionPatterns(converter, patterns);
   }
   if (enableAVX512) {
-    target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
-    target.addIllegalDialect<avx512::AVX512Dialect>();
-    populateAVX512ToLLVMConversionPatterns(converter, patterns);
+    configureAVX512LegalizeForExportTarget(target);
+    populateAVX512LegalizeForLLVMExportPatterns(converter, patterns);
   }
 
   if (failed(

diff  --git a/mlir/lib/Dialect/AVX512/CMakeLists.txt b/mlir/lib/Dialect/AVX512/CMakeLists.txt
index 008add875a19..9f57627c321f 100644
--- a/mlir/lib/Dialect/AVX512/CMakeLists.txt
+++ b/mlir/lib/Dialect/AVX512/CMakeLists.txt
@@ -1,13 +1,2 @@
-add_mlir_dialect_library(MLIRAVX512
-  IR/AVX512Dialect.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AVX512
-
-  DEPENDS
-  MLIRAVX512IncGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRSideEffectInterfaces
-  )
+add_subdirectory(IR)
+add_subdirectory(Transforms)

diff  --git a/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp b/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp
index 023018af8086..8fba1f8d7ef9 100644
--- a/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp
+++ b/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp
@@ -11,7 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/AVX512/AVX512Dialect.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TypeUtilities.h"

diff  --git a/mlir/lib/Dialect/AVX512/IR/CMakeLists.txt b/mlir/lib/Dialect/AVX512/IR/CMakeLists.txt
new file mode 100644
index 000000000000..15b5c635ccc7
--- /dev/null
+++ b/mlir/lib/Dialect/AVX512/IR/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRAVX512
+  AVX512Dialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AVX512
+
+  DEPENDS
+  MLIRAVX512IncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMIR
+  MLIRSideEffectInterfaces
+  )

diff  --git a/mlir/lib/Dialect/AVX512/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AVX512/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000..00ef9deee539
--- /dev/null
+++ b/mlir/lib/Dialect/AVX512/Transforms/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_dialect_library(MLIRAVX512Transforms
+  LegalizeForLLVMExport.cpp
+
+  DEPENDS
+  MLIRAVX512ConversionsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAVX512
+  MLIRIR
+  MLIRLLVMIR
+  )

diff  --git a/mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp
new file mode 100644
index 000000000000..cfe9f2b3ac02
--- /dev/null
+++ b/mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp
@@ -0,0 +1,141 @@
+//===- LegalizeForLLVMExport.cpp - Prepare AVX512 for LLVM translation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AVX512/Transforms.h"
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/AVX512/AVX512Dialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::avx512;
+
+/// Extracts the "main" vector element type from the given AVX512 operation.
+template <typename OpTy>
+static Type getSrcVectorElementType(OpTy op) {
+  return op.src().getType().template cast<VectorType>().getElementType();
+}
+template <>
+Type getSrcVectorElementType(Vp2IntersectOp op) {
+  return op.a().getType().template cast<VectorType>().getElementType();
+}
+
+namespace {
+/// Base conversion for AVX512 ops that can be lowered to one of the two
+/// intrinsics based on the bitwidth of their "main" vector element type. This
+/// relies on the to-LLVM-dialect conversion helpers to correctly pack the
+/// results of multi-result intrinsic ops.
+template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
+struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
+  explicit LowerToIntrinsic(LLVMTypeConverter &converter)
+      : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
+
+  LLVMTypeConverter &getTypeConverter() const {
+    return *static_cast<LLVMTypeConverter *>(
+        OpConversionPattern<OpTy>::getTypeConverter());
+  }
+
+  LogicalResult
+  matchAndRewrite(OpTy op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type elementType = getSrcVectorElementType<OpTy>(op);
+    unsigned bitwidth = elementType.getIntOrFloatBitWidth();
+    if (bitwidth == 32)
+      return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
+                                           operands, getTypeConverter(),
+                                           rewriter);
+    if (bitwidth == 64)
+      return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
+                                           operands, getTypeConverter(),
+                                           rewriter);
+    return rewriter.notifyMatchFailure(
+        op, "expected 'src' to be either f32 or f64");
+  }
+};
+
+struct MaskCompressOpConversion
+    : public ConvertOpToLLVMPattern<MaskCompressOp> {
+  using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(MaskCompressOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    MaskCompressOp::Adaptor adaptor(operands);
+    auto opType = adaptor.a().getType();
+
+    Value src;
+    if (op.src()) {
+      src = adaptor.src();
+    } else if (op.constant_src()) {
+      src = rewriter.create<ConstantOp>(op.getLoc(), opType,
+                                        op.constant_srcAttr());
+    } else {
+      Attribute zeroAttr = rewriter.getZeroAttr(opType);
+      src = rewriter.create<ConstantOp>(op->getLoc(), opType, zeroAttr);
+    }
+
+    rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.a(),
+                                                    src, adaptor.k());
+
+    return success();
+  }
+};
+
+/// An entry associating the "main" AVX512 op with its instantiations for
+/// vectors of 32-bit and 64-bit elements.
+template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
+struct RegEntry {
+  using MainOp = OpTy;
+  using Intr32Op = Intr32OpTy;
+  using Intr64Op = Intr64OpTy;
+};
+
+/// A container for op association entries facilitating the configuration of
+/// dialect conversion.
+template <typename... Args>
+struct RegistryImpl {
+  /// Registers the patterns specializing the "main" op to one of the
+  /// "intrinsic" ops depending on elemental type.
+  static void registerPatterns(LLVMTypeConverter &converter,
+                               OwningRewritePatternList &patterns) {
+    patterns
+        .insert<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
+                                 typename Args::Intr64Op>...>(converter);
+  }
+
+  /// Configures the conversion target to lower out "main" ops.
+  static void configureTarget(LLVMConversionTarget &target) {
+    target.addIllegalOp<typename Args::MainOp...>();
+    target.addLegalOp<typename Args::Intr32Op...>();
+    target.addLegalOp<typename Args::Intr64Op...>();
+  }
+};
+
+using Registry = RegistryImpl<
+    RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
+    RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
+    RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
+
+} // namespace
+
+/// Populate the given list with patterns that convert from AVX512 to LLVM.
+void mlir::populateAVX512LegalizeForLLVMExportPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  Registry::registerPatterns(converter, patterns);
+  patterns.insert<MaskCompressOpConversion>(converter);
+}
+
+void mlir::configureAVX512LegalizeForExportTarget(
+    LLVMConversionTarget &target) {
+  Registry::configureTarget(target);
+  target.addLegalOp<MaskCompressIntrOp>();
+  target.addIllegalOp<MaskCompressOp>();
+}

diff  --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index 1b6e4212861f..617fbdb7c9cf 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -29,27 +29,6 @@ add_mlir_dialect_library(MLIRLLVMIR
   MLIRSupport
   )
 
-add_mlir_dialect_library(MLIRLLVMAVX512
-  IR/LLVMAVX512Dialect.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
-
-  DEPENDS
-  MLIRLLVMAVX512IncGen
-  MLIRLLVMAVX512ConversionsIncGen
-  intrinsics_gen
-
-  LINK_COMPONENTS
-  AsmParser
-  Core
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRLLVMIR
-  MLIRSideEffectInterfaces
-  )
-
 add_mlir_dialect_library(MLIRLLVMArmSVE
   IR/LLVMArmSVEDialect.cpp
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp
deleted file mode 100644
index 512234cc8764..000000000000
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp
+++ /dev/null
@@ -1,31 +0,0 @@
-//===- LLVMAVX512Dialect.cpp - MLIR LLVMAVX512 ops implementation ---------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements the LLVMAVX512 dialect and its operations.
-//
-//===----------------------------------------------------------------------===//
-
-#include "llvm/IR/IntrinsicsX86.h"
-
-#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/TypeUtilities.h"
-
-using namespace mlir;
-
-void LLVM::LLVMAVX512Dialect::initialize() {
-  addOperations<
-#define GET_OP_LIST
-#include "mlir/Dialect/LLVMIR/LLVMAVX512.cpp.inc"
-      >();
-}
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/LLVMIR/LLVMAVX512.cpp.inc"

diff  --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index 162d2c42a19b..59b5b850afca 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -37,8 +37,8 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
 
   LINK_LIBS PUBLIC
   MLIRArmNeonToLLVMIRTranslation
+  MLIRAVX512ToLLVMIRTranslation
   MLIRLLVMArmSVEToLLVMIRTranslation
-  MLIRLLVMAVX512ToLLVMIRTranslation
   MLIRLLVMToLLVMIRTranslation
   MLIRNVVMToLLVMIRTranslation
   MLIROpenMPToLLVMIRTranslation

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.cpp
similarity index 60%
rename from mlir/lib/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.cpp
rename to mlir/lib/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.cpp
index 7d1e087dfdb7..2722947e72d1 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.cpp
@@ -1,4 +1,4 @@
-//===- LLVMAVX512ToLLVMIRTranslation.cpp - Translate LLVMAVX512 to LLVM IR-===//
+//===- AVX512ToLLVMIRTranslation.cpp - Translate AVX512 to LLVM IR---------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,13 +6,13 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements a translation between the MLIR LLVMAVX512 dialect and
+// This file implements a translation between the MLIR AVX512 dialect and
 // LLVM IR.
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.h"
-#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
+#include "mlir/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.h"
+#include "mlir/Dialect/AVX512/AVX512Dialect.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 
@@ -24,8 +24,8 @@ using namespace mlir::LLVM;
 
 namespace {
 /// Implementation of the dialect interface that converts operations belonging
-/// to the LLVMAVX512 dialect to LLVM IR.
-class LLVMAVX512DialectLLVMIRTranslationInterface
+/// to the AVX512 dialect to LLVM IR.
+class AVX512DialectLLVMIRTranslationInterface
     : public LLVMTranslationDialectInterface {
 public:
   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
@@ -36,21 +36,21 @@ class LLVMAVX512DialectLLVMIRTranslationInterface
   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const final {
     Operation &opInst = *op;
-#include "mlir/Dialect/LLVMIR/LLVMAVX512Conversions.inc"
+#include "mlir/Dialect/AVX512/AVX512Conversions.inc"
 
     return failure();
   }
 };
 } // end namespace
 
-void mlir::registerLLVMAVX512DialectTranslation(DialectRegistry &registry) {
-  registry.insert<LLVM::LLVMAVX512Dialect>();
-  registry.addDialectInterface<LLVM::LLVMAVX512Dialect,
-                               LLVMAVX512DialectLLVMIRTranslationInterface>();
+void mlir::registerAVX512DialectTranslation(DialectRegistry &registry) {
+  registry.insert<avx512::AVX512Dialect>();
+  registry.addDialectInterface<avx512::AVX512Dialect,
+                               AVX512DialectLLVMIRTranslationInterface>();
 }
 
-void mlir::registerLLVMAVX512DialectTranslation(MLIRContext &context) {
+void mlir::registerAVX512DialectTranslation(MLIRContext &context) {
   DialectRegistry registry;
-  registerLLVMAVX512DialectTranslation(registry);
+  registerAVX512DialectTranslation(registry);
   context.appendDialectRegistry(registry);
 }

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/AVX512/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/AVX512/CMakeLists.txt
new file mode 100644
index 000000000000..548e0e961266
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/AVX512/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_translation_library(MLIRAVX512ToLLVMIRTranslation
+  AVX512ToLLVMIRTranslation.cpp
+
+  DEPENDS
+  MLIRAVX512ConversionsIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRAVX512
+  MLIRLLVMIR
+  MLIRSupport
+  MLIRTargetLLVMIRExport
+  )

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
index d115b13d7123..b710af3bfeb7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -1,6 +1,6 @@
 add_subdirectory(ArmNeon)
+add_subdirectory(AVX512)
 add_subdirectory(LLVMArmSVE)
-add_subdirectory(LLVMAVX512)
 add_subdirectory(LLVMIR)
 add_subdirectory(NVVM)
 add_subdirectory(OpenMP)

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMAVX512/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/LLVMAVX512/CMakeLists.txt
deleted file mode 100644
index 3678ab5384ff..000000000000
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMAVX512/CMakeLists.txt
+++ /dev/null
@@ -1,16 +0,0 @@
-add_mlir_translation_library(MLIRLLVMAVX512ToLLVMIRTranslation
-  LLVMAVX512ToLLVMIRTranslation.cpp
-
-  DEPENDS
-  MLIRLLVMAVX512ConversionsIncGen
-
-  LINK_COMPONENTS
-  Core
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRLLVMAVX512
-  MLIRLLVMIR
-  MLIRSupport
-  MLIRTargetLLVMIRExport
-  )

diff  --git a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir b/mlir/test/Dialect/AVX512/legalize-for-llvm.mlir
similarity index 79%
rename from mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
rename to mlir/test/Dialect/AVX512/legalize-for-llvm.mlir
index 0d03917d06c3..527a2629461c 100644
--- a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Dialect/AVX512/legalize-for-llvm.mlir
@@ -3,14 +3,14 @@
 func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
   -> (vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>)
 {
-  // CHECK: llvm_avx512.mask.rndscale.ps.512
+  // CHECK: avx512.intr.mask.rndscale.ps.512
   %0 = avx512.mask.rndscale %a, %i32, %a, %i16, %i32: vector<16xf32>
-  // CHECK: llvm_avx512.mask.rndscale.pd.512
+  // CHECK: avx512.intr.mask.rndscale.pd.512
   %1 = avx512.mask.rndscale %b, %i32, %b, %i8, %i32: vector<8xf64>
 
-  // CHECK: llvm_avx512.mask.scalef.ps.512
+  // CHECK: avx512.intr.mask.scalef.ps.512
   %2 = avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32>
-  // CHECK: llvm_avx512.mask.scalef.pd.512
+  // CHECK: avx512.intr.mask.scalef.pd.512
   %3 = avx512.mask.scalef %b, %b, %b, %i8, %i32: vector<8xf64>
 
   // Keep results alive.
@@ -21,11 +21,11 @@ func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
                            %k2: vector<8xi1>, %a2: vector<8xi64>)
   -> (vector<16xf32>, vector<16xf32>, vector<8xi64>)
 {
-  // CHECK: llvm_avx512.mask.compress
+  // CHECK: avx512.intr.mask.compress
   %0 = avx512.mask.compress %k1, %a1 : vector<16xf32>
-  // CHECK: llvm_avx512.mask.compress
+  // CHECK: avx512.intr.mask.compress
   %1 = avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
-  // CHECK: llvm_avx512.mask.compress
+  // CHECK: avx512.intr.mask.compress
   %2 = avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
   return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
 }
@@ -33,9 +33,9 @@ func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
 func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
   -> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
 {
-  // CHECK: llvm_avx512.vp2intersect.d.512
+  // CHECK: avx512.intr.vp2intersect.d.512
   %0, %1 = avx512.vp2intersect %a, %a : vector<16xi32>
-  // CHECK: llvm_avx512.vp2intersect.q.512
+  // CHECK: avx512.intr.vp2intersect.q.512
   %2, %3 = avx512.vp2intersect %b, %b : vector<8xi64>
   return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
 }

diff  --git a/mlir/test/Target/LLVMIR/avx512.mlir b/mlir/test/Target/LLVMIR/avx512.mlir
index abf36bd153a1..6d549f6944db 100644
--- a/mlir/test/Target/LLVMIR/avx512.mlir
+++ b/mlir/test/Target/LLVMIR/avx512.mlir
@@ -7,10 +7,10 @@ llvm.func @LLVM_x86_avx512_mask_ps_512(%a: vector<16 x f32>,
 {
   %b = llvm.mlir.constant(42 : i32) : i32
   // CHECK: call <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512(<16 x float>
-  %0 = "llvm_avx512.mask.rndscale.ps.512"(%a, %b, %a, %c, %b) :
+  %0 = "avx512.intr.mask.rndscale.ps.512"(%a, %b, %a, %c, %b) :
     (vector<16 x f32>, i32, vector<16 x f32>, i16, i32) -> vector<16 x f32>
   // CHECK: call <16 x float> @llvm.x86.avx512.mask.scalef.ps.512(<16 x float>
-  %1 = "llvm_avx512.mask.scalef.ps.512"(%a, %a, %a, %c, %b) :
+  %1 = "avx512.intr.mask.scalef.ps.512"(%a, %a, %a, %c, %b) :
     (vector<16 x f32>, vector<16 x f32>, vector<16 x f32>, i16, i32) -> vector<16 x f32>
   llvm.return %1: vector<16 x f32>
 }
@@ -22,10 +22,10 @@ llvm.func @LLVM_x86_avx512_mask_pd_512(%a: vector<8xf64>,
 {
   %b = llvm.mlir.constant(42 : i32) : i32
   // CHECK: call <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512(<8 x double>
-  %0 = "llvm_avx512.mask.rndscale.pd.512"(%a, %b, %a, %c, %b) :
+  %0 = "avx512.intr.mask.rndscale.pd.512"(%a, %b, %a, %c, %b) :
     (vector<8xf64>, i32, vector<8xf64>, i8, i32) -> vector<8xf64>
   // CHECK: call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double>
-  %1 = "llvm_avx512.mask.scalef.pd.512"(%a, %a, %a, %c, %b) :
+  %1 = "avx512.intr.mask.scalef.pd.512"(%a, %a, %a, %c, %b) :
     (vector<8xf64>, vector<8xf64>, vector<8xf64>, i8, i32) -> vector<8xf64>
   llvm.return %1: vector<8xf64>
 }
@@ -35,7 +35,7 @@ llvm.func @LLVM_x86_mask_compress(%k: vector<16xi1>, %a: vector<16xf32>)
   -> vector<16xf32>
 {
   // CHECK: call <16 x float> @llvm.x86.avx512.mask.compress.v16f32(
-  %0 = "llvm_avx512.mask.compress"(%a, %a, %k) :
+  %0 = "avx512.intr.mask.compress"(%a, %a, %k) :
     (vector<16xf32>, vector<16xf32>, vector<16xi1>) -> vector<16xf32>
   llvm.return %0 : vector<16xf32>
 }
@@ -45,7 +45,7 @@ llvm.func @LLVM_x86_vp2intersect_d_512(%a: vector<16xi32>, %b: vector<16xi32>)
   -> !llvm.struct<(vector<16 x i1>, vector<16 x i1>)>
 {
   // CHECK: call { <16 x i1>, <16 x i1> } @llvm.x86.avx512.vp2intersect.d.512(<16 x i32>
-  %0 = "llvm_avx512.vp2intersect.d.512"(%a, %b) :
+  %0 = "avx512.intr.vp2intersect.d.512"(%a, %b) :
     (vector<16xi32>, vector<16xi32>) -> !llvm.struct<(vector<16 x i1>, vector<16 x i1>)>
   llvm.return %0 : !llvm.struct<(vector<16 x i1>, vector<16 x i1>)>
 }
@@ -55,7 +55,7 @@ llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
   -> !llvm.struct<(vector<8 x i1>, vector<8 x i1>)>
 {
   // CHECK: call { <8 x i1>, <8 x i1> } @llvm.x86.avx512.vp2intersect.q.512(<8 x i64>
-  %0 = "llvm_avx512.vp2intersect.q.512"(%a, %b) :
+  %0 = "avx512.intr.vp2intersect.q.512"(%a, %b) :
     (vector<8xi64>, vector<8xi64>) -> !llvm.struct<(vector<8 x i1>, vector<8 x i1>)>
   llvm.return %0 : !llvm.struct<(vector<8 x i1>, vector<8 x i1>)>
 }

diff  --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
index 467f10af0804..91d2a8524916 100644
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -11,7 +11,6 @@
 // CHECK-NEXT: linalg
 // CHECK-NEXT: llvm
 // CHECK-NEXT: llvm_arm_sve
-// CHECK-NEXT: llvm_avx512
 // CHECK-NEXT: math
 // CHECK-NEXT: nvvm
 // CHECK-NEXT: omp


        


More information about the Mlir-commits mailing list