[Mlir-commits] [mlir] 462db62 - [mlir][AVX512] Start a primitive AVX512 dialect
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Mar 20 11:15:06 PDT 2020
Author: Nicolas Vasilache
Date: 2020-03-20T14:11:57-04:00
New Revision: 462db62053fba10d3961448c1a5bd653ada8a87d
URL: https://github.com/llvm/llvm-project/commit/462db62053fba10d3961448c1a5bd653ada8a87d
DIFF: https://github.com/llvm/llvm-project/commit/462db62053fba10d3961448c1a5bd653ada8a87d.diff
LOG: [mlir][AVX512] Start a primitive AVX512 dialect
The Vector Dialect [document](https://mlir.llvm.org/docs/Dialects/Vector/) discusses the vector abstractions that MLIR supports and the various tradeoffs involved.
One of the layers missing in OSS at the moment is the Hardware Vector Ops (HWV) level.
This revision proposes adding a new AVX512-specific dialect (Dialect/Targets/AVX512) that would directly target AVX512-specific intrinsics.
At the moment, we rely too much on LLVM’s peephole optimizer to do a good job from small insertelement/extractelement/shufflevector sequences. In the future, when possible, generic abstractions such as VP intrinsics should be preferred.
The revision will allow trading off HW-specific vs generic abstractions in MLIR.
Differential Revision: https://reviews.llvm.org/D75987
Added:
mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h
mlir/include/mlir/Dialect/AVX512/AVX512.td
mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h
mlir/include/mlir/Dialect/AVX512/CMakeLists.txt
mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td
mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h
mlir/lib/Conversion/AVX512ToLLVM/CMakeLists.txt
mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
mlir/lib/Dialect/AVX512/CMakeLists.txt
mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp
mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp
mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
mlir/test/Dialect/AVX512/roundtrip.mlir
mlir/test/Target/avx512.mlir
Modified:
mlir/include/mlir/Dialect/CMakeLists.txt
mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/InitAllDialects.h
mlir/include/mlir/InitAllPasses.h
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Dialect/CMakeLists.txt
mlir/lib/Dialect/LLVMIR/CMakeLists.txt
mlir/lib/Target/CMakeLists.txt
mlir/tools/mlir-translate/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h
new file mode 100644
index 000000000000..bd65970d5bf7
--- /dev/null
+++ b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h
@@ -0,0 +1,29 @@
+//===- ConvertAVX512ToLLVM.h - Conversion Patterns from AVX512 to LLVM ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EDGE_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_
+#define MLIR_EDGE_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_
+
+#include <memory>
+
+namespace mlir {
+class LLVMTypeConverter;
+class ModuleOp;
+template <typename T> class OpPassBase;
+class OwningRewritePatternList;
+
+/// Collect a set of patterns to convert from the AVX512 dialect to LLVM.
+void populateAVX512ToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ OwningRewritePatternList &patterns);
+
+/// Create a pass to convert AVX512 operations to the LLVMIR dialect.
+std::unique_ptr<OpPassBase<ModuleOp>> createConvertAVX512ToLLVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_EDGE_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_
diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512.td b/mlir/include/mlir/Dialect/AVX512/AVX512.td
new file mode 100644
index 000000000000..917af2e1cc04
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AVX512/AVX512.td
@@ -0,0 +1,99 @@
+//===-- AVX512Ops.td - AVX512 dialect operation definitions *- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the AVX512 dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef AVX512_OPS
+#define AVX512_OPS
+
+include "mlir/Interfaces/SideEffects.td"
+
+//===----------------------------------------------------------------------===//
+// AVX512 dialect definition
+//===----------------------------------------------------------------------===//
+
+def AVX512_Dialect : Dialect {
+ let name = "avx512";
+ let cppNamespace = "avx512";
+}
+
+//===----------------------------------------------------------------------===//
+// AVX512 op definitions
+//===----------------------------------------------------------------------===//
+
+class AVX512_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<AVX512_Dialect, mnemonic, traits> {}
+
+def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [NoSideEffect,
+ AllTypesMatch<["src", "a", "dst"]>,
+ TypesMatchWith<"imm has the same number of bits as elements in dst",
+ "dst", "imm",
+ "IntegerType::get(($_self.cast<VectorType>().getShape()[0]),"
+ " $_self.getContext())">]> {
+ let summary = "Masked roundscale op";
+ let description = [{
+ The mask.rndscale op is an AVX512 specific op that can lower to the proper
+ LLVMAVX512 operation: `llvm.mask.rndscale.ps.512` or
+ `llvm.mask.rndscale.pd.512` instruction depending on the type of vectors it
+ is applied to.
+
+ From the Intel Intrinsics Guide:
+ ================================
+ Round packed floating-point elements in `a` to the number of fraction bits
+ specified by `imm`, and store the results in `dst` using writemask `k`
+ (elements are copied from src when the corresponding mask bit is not set).
+ }];
+ // Supports vector<16xf32> and vector<8xf64>.
+ let arguments = (ins VectorOfLengthAndType<[16, 8], [F32, F64]>:$src,
+ I32:$k,
+ VectorOfLengthAndType<[16, 8], [F32, F64]>:$a,
+ AnyTypeOf<[I16, I8]>:$imm,
+ // TODO(ntv): figure rounding out (optional operand?).
+ I32:$rounding
+ );
+ let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst);
+ let assemblyFormat =
+ "$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict `:` type($dst)";
+}
+
+def MaskScaleFOp : AVX512_Op<"mask.scalef", [NoSideEffect,
+ AllTypesMatch<["src", "a", "b", "dst"]>,
+ TypesMatchWith<"k has the same number of bits as elements in dst",
+ "dst", "k",
+ "IntegerType::get(($_self.cast<VectorType>().getShape()[0]),"
+ " $_self.getContext())">]> {
+ let summary = "ScaleF op";
+ let description = [{
+ The `mask.scalef` op is an AVX512 specific op that can lower to the proper
+ LLVMAVX512 operation: `llvm.mask.scalef.ps.512` or
+ `llvm.mask.scalef.pd.512` depending on the type of MLIR vectors it is
+ applied to.
+
+ From the Intel Intrinsics Guide:
+ ================================
+ Scale the packed floating-point elements in `a` using values from `b`, and
+ store the results in `dst` using writemask `k` (elements are copied from src
+ when the corresponding mask bit is not set).
+ }];
+ // Supports vector<16xf32> and vector<8xf64>.
+ let arguments = (ins VectorOfLengthAndType<[16, 8], [F32, F64]>:$src,
+ VectorOfLengthAndType<[16, 8], [F32, F64]>:$a,
+ VectorOfLengthAndType<[16, 8], [F32, F64]>:$b,
+ AnyTypeOf<[I16, I8]>:$k,
+ // TODO(ntv): figure rounding out (optional operand?).
+ I32:$rounding
+ );
+ let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst);
+ // Fully specified by traits.
+ let assemblyFormat =
+ "$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)";
+}
+
+#endif // AVX512_OPS
diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h b/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h
new file mode 100644
index 000000000000..aeec2b728a11
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h
@@ -0,0 +1,31 @@
+//===- AVX512Dialect.h - MLIR Dialect for AVX512 ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for AVX512 in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AVX512_AVX512DIALECT_H_
+#define MLIR_DIALECT_AVX512_AVX512DIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/SideEffects.h"
+
+namespace mlir {
+namespace avx512 {
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AVX512/AVX512.h.inc"
+
+#include "mlir/Dialect/AVX512/AVX512Dialect.h.inc"
+
+} // namespace avx512
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AVX512_AVX512DIALECT_H_
diff --git a/mlir/include/mlir/Dialect/AVX512/CMakeLists.txt b/mlir/include/mlir/Dialect/AVX512/CMakeLists.txt
new file mode 100644
index 000000000000..5868760077a6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AVX512/CMakeLists.txt
@@ -0,0 +1 @@
+add_mlir_dialect(AVX512 avx512 AVX512Doc)
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 27cbe9378346..32b24264ba69 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(AffineOps)
+add_subdirectory(AVX512)
add_subdirectory(FxpMathOps)
add_subdirectory(GPU)
add_subdirectory(Linalg)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 796b4a68a2b1..99f55c2fb0ef 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -20,3 +20,9 @@ add_public_tablegen_target(MLIRNVVMConversionsIncGen)
set(LLVM_TARGET_DEFINITIONS ROCDLOps.td)
mlir_tablegen(ROCDLConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRROCDLConversionsIncGen)
+
+add_mlir_dialect(LLVMAVX512 llvm_avx512 LLVMAVX512Doc)
+
+set(LLVM_TARGET_DEFINITIONS LLVMAVX512.td)
+mlir_tablegen(LLVMAVX512Conversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(MLIRLLVMAVX512ConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td
new file mode 100644
index 000000000000..12668c4da41b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td
@@ -0,0 +1,52 @@
+//===-- LLVMAVX512.td - LLVMAVX512 dialect op definitions --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the LLVMAVX512 dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVMIR_AVX512_OPS
+#define LLVMIR_AVX512_OPS
+
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// LLVMAVX512 dialect definition
+//===----------------------------------------------------------------------===//
+
+def LLVMAVX512_Dialect : Dialect {
+ let name = "llvm_avx512";
+ let cppNamespace = "LLVM";
+}
+
+//----------------------------------------------------------------------------//
+// MLIR LLVM AVX512 intrinsics using the MLIR LLVM Dialect type system
+//----------------------------------------------------------------------------//
+
+class LLVMAVX512_IntrOp<string mnemonic, list<OpTrait> traits = []> :
+ LLVM_IntrOpBase<LLVMAVX512_Dialect, mnemonic,
+ "x86_avx512_" # !subst(".", "_", mnemonic),
+ [], [], traits, 1>;
+
+def LLVM_x86_avx512_mask_rndscale_ps_512 :
+ LLVMAVX512_IntrOp<"mask.rndscale.ps.512">,
+ Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_x86_avx512_mask_rndscale_pd_512 :
+ LLVMAVX512_IntrOp<"mask.rndscale.pd.512">,
+ Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_x86_avx512_mask_scalef_ps_512 :
+ LLVMAVX512_IntrOp<"mask.scalef.ps.512">,
+ Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_x86_avx512_mask_scalef_pd_512 :
+ LLVMAVX512_IntrOp<"mask.scalef.pd.512">,
+ Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+#endif // LLVMIR_AVX512_OPS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h
new file mode 100644
index 000000000000..27b98fd18910
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h
@@ -0,0 +1,30 @@
+//===- LLVMAVX512Dialect.h - MLIR Dialect for LLVMAVX512 --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for LLVMAVX512 in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_LLVMAVX512DIALECT_H_
+#define MLIR_DIALECT_LLVMIR_LLVMAVX512DIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace LLVM {
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMAVX512.h.inc"
+
+#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h.inc"
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_LLVMAVX512DIALECT_H_
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index dbd7b0b0b982..74b42243f236 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -510,8 +510,9 @@ class OpAsmParser {
ArrayRef<Type>(type), loc, result);
}
template <typename Operands, typename Types>
- ParseResult resolveOperands(Operands &&operands, Types &&types,
- llvm::SMLoc loc, SmallVectorImpl<Value> &result) {
+ std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
+ resolveOperands(Operands &&operands, Types &&types, llvm::SMLoc loc,
+ SmallVectorImpl<Value> &result) {
size_t operandSize = std::distance(operands.begin(), operands.end());
size_t typeSize = std::distance(types.begin(), types.end());
if (operandSize != typeSize)
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c0a7ca04081f..9a14a1586c7f 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -14,9 +14,11 @@
#ifndef MLIR_INITALLDIALECTS_H_
#define MLIR_INITALLDIALECTS_H_
+#include "mlir/Dialect/AVX512/AVX512Dialect.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -38,8 +40,10 @@ namespace mlir {
inline void registerAllDialects() {
static bool init_once = []() {
registerDialect<AffineOpsDialect>();
+ registerDialect<avx512::AVX512Dialect>();
registerDialect<fxpmath::FxpMathOpsDialect>();
registerDialect<gpu::GPUDialect>();
+ registerDialect<LLVM::LLVMAVX512Dialect>();
registerDialect<LLVM::LLVMDialect>();
registerDialect<linalg::LinalgDialect>();
registerDialect<loop::LoopOpsDialect>();
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index b358cfa8802e..c1cac45816df 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -15,6 +15,7 @@
#define MLIR_INITALLPASSES_H_
#include "mlir/Analysis/Passes.h"
+#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
@@ -78,6 +79,9 @@ inline void registerAllPasses() {
createSymbolDCEPass();
createLocationSnapshotPass({});
+ // AVX512
+ createConvertAVX512ToLLVMPass();
+
// GPUtoRODCLPass
createLowerGpuOpsToROCDLOpsPass();
diff --git a/mlir/lib/Conversion/AVX512ToLLVM/CMakeLists.txt b/mlir/lib/Conversion/AVX512ToLLVM/CMakeLists.txt
new file mode 100644
index 000000000000..5573f6ca1618
--- /dev/null
+++ b/mlir/lib/Conversion/AVX512ToLLVM/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRAVX512ToLLVM
+ ConvertAVX512ToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/AVX512ToLLVM
+)
+
+set(LIBS
+ MLIRAVX512
+ MLIRLLVMAVX512
+ MLIRLLVMIR
+ MLIRStandardToLLVM
+ MLIRTransforms
+ LLVMCore
+ LLVMSupport
+ )
+
+add_dependencies(MLIRAVX512ToLLVM ${LIBS})
+target_link_libraries(MLIRAVX512ToLLVM PUBLIC ${LIBS})
diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
new file mode 100644
index 000000000000..7a8c1e81fcb8
--- /dev/null
+++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
@@ -0,0 +1,193 @@
+//===- ConvertAVX512ToLLVM.cpp - Convert AVX512 to the LLVM dialect -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/AVX512/AVX512Dialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::vector;
+using namespace mlir::avx512;
+
+template <typename OpTy> Type getSrcVectorElementType(OpTy op) {
+ return op.src().getType().template cast<VectorType>().getElementType();
+}
+
+// TODO(ntv, zinenko): Code is currently copy-pasted and adapted from the code
+// 1-1 LLVM conversion. It would be better if it were properly exposed in core
+// and reusable.
+/// Basic lowering implementation for one-to-one rewriting from AVX512 Ops to
+/// LLVM Dialect Ops. Convert the type of the result to an LLVM type, pass
+/// operands as is, preserve attributes.
+template <typename SourceOp, typename TargetOp>
+LogicalResult matchAndRewriteOneToOne(const ConvertToLLVMPattern &lowering,
+ LLVMTypeConverter &typeConverter,
+ Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) {
+ unsigned numResults = op->getNumResults();
+
+ Type packedType;
+ if (numResults != 0) {
+ packedType = typeConverter.packFunctionResults(op->getResultTypes());
+ if (!packedType)
+ return failure();
+ }
+
+ auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
+ op->getAttrs());
+
+ // If the operation produced 0 or 1 result, return them immediately.
+ if (numResults == 0)
+ return rewriter.eraseOp(op), success();
+ if (numResults == 1)
+ return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
+ success();
+
+ // Otherwise, it had been converted to an operation producing a structure.
+ // Extract individual results from the structure and return them as list.
+ SmallVector<Value, 4> results;
+ results.reserve(numResults);
+ for (unsigned i = 0; i < numResults; ++i) {
+ auto type = typeConverter.convertType(op->getResult(i).getType());
+ results.push_back(rewriter.create<LLVM::ExtractValueOp>(
+ op->getLoc(), type, newOp.getOperation()->getResult(0),
+ rewriter.getI64ArrayAttr(i)));
+ }
+ rewriter.replaceOp(op, results);
+ return success();
+}
+
+// TODO(ntv): Patterns are too verbose due to the fact that we have 1 op (e.g.
+// MaskRndScaleOp) and different possible target ops. It would be better to take
+// a Functor so that all these conversions become 1-liners.
+struct MaskRndScaleOpPS512Conversion : public ConvertToLLVMPattern {
+ explicit MaskRndScaleOpPS512Conversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
+ typeConverter) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!getSrcVectorElementType(cast<MaskRndScaleOp>(op)).isF32())
+ return failure();
+ return matchAndRewriteOneToOne<MaskRndScaleOp,
+ LLVM::x86_avx512_mask_rndscale_ps_512>(
+ *this, this->typeConverter, op, operands, rewriter);
+ }
+};
+
+struct MaskRndScaleOpPD512Conversion : public ConvertToLLVMPattern {
+ explicit MaskRndScaleOpPD512Conversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
+ typeConverter) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!getSrcVectorElementType(cast<MaskRndScaleOp>(op)).isF64())
+ return failure();
+ return matchAndRewriteOneToOne<MaskRndScaleOp,
+ LLVM::x86_avx512_mask_rndscale_pd_512>(
+ *this, this->typeConverter, op, operands, rewriter);
+ }
+};
+
+struct ScaleFOpPS512Conversion : public ConvertToLLVMPattern {
+ explicit ScaleFOpPS512Conversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
+ typeConverter) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!getSrcVectorElementType(cast<MaskScaleFOp>(op)).isF32())
+ return failure();
+ return matchAndRewriteOneToOne<MaskScaleFOp,
+ LLVM::x86_avx512_mask_scalef_ps_512>(
+ *this, this->typeConverter, op, operands, rewriter);
+ }
+};
+
+struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern {
+ explicit ScaleFOpPD512Conversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
+ typeConverter) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!getSrcVectorElementType(cast<MaskScaleFOp>(op)).isF64())
+ return failure();
+ return matchAndRewriteOneToOne<MaskScaleFOp,
+ LLVM::x86_avx512_mask_scalef_pd_512>(
+ *this, this->typeConverter, op, operands, rewriter);
+ }
+};
+
+/// Populate the given list with patterns that convert from AVX512 to LLVM.
+void mlir::populateAVX512ToLLVMConversionPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ MLIRContext *ctx = converter.getDialect()->getContext();
+ // clang-format off
+ patterns.insert<MaskRndScaleOpPS512Conversion,
+ MaskRndScaleOpPD512Conversion,
+ ScaleFOpPS512Conversion,
+ ScaleFOpPD512Conversion>(ctx, converter);
+ // clang-format on
+}
+
+namespace {
+struct ConvertAVX512ToLLVMPass : public ModulePass<ConvertAVX512ToLLVMPass> {
+ void runOnModule() override;
+};
+} // namespace
+
+void ConvertAVX512ToLLVMPass::runOnModule() {
+ // Convert to the LLVM IR dialect.
+ OwningRewritePatternList patterns;
+ LLVMTypeConverter converter(&getContext());
+ populateAVX512ToLLVMConversionPatterns(converter, patterns);
+ populateVectorToLLVMConversionPatterns(converter, patterns);
+ populateStdToLLVMConversionPatterns(converter, patterns);
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect<LLVM::LLVMDialect>();
+ target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
+ target.addIllegalDialect<avx512::AVX512Dialect>();
+ target.addDynamicallyLegalOp<FuncOp>(
+ [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
+ if (failed(
+ applyPartialConversion(getModule(), target, patterns, &converter))) {
+ signalPassFailure();
+ }
+}
+
+std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertAVX512ToLLVMPass() {
+ return std::make_unique<ConvertAVX512ToLLVMPass>();
+}
+
+static PassRegistration<ConvertAVX512ToLLVMPass> pass(
+ "convert-avx512-to-llvm",
+ "Convert the operations from the avx512 dialect into the LLVM dialect");
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 2f1826a1e299..fbf3e1259493 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(AffineToStandard)
+add_subdirectory(AVX512ToLLVM)
add_subdirectory(GPUToCUDA)
add_subdirectory(GPUToNVVM)
add_subdirectory(GPUToROCDL)
diff --git a/mlir/lib/Dialect/AVX512/CMakeLists.txt b/mlir/lib/Dialect/AVX512/CMakeLists.txt
new file mode 100644
index 000000000000..0fc6da6240a9
--- /dev/null
+++ b/mlir/lib/Dialect/AVX512/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRAVX512
+ IR/AVX512Dialect.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AVX512
+
+ DEPENDS
+ MLIRAVX512IncGen
+ )
+target_link_libraries(MLIRAVX512
+ PUBLIC
+ MLIRIR
+ LLVMSupport
+ )
diff --git a/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp b/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp
new file mode 100644
index 000000000000..aade931ee4e7
--- /dev/null
+++ b/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp
@@ -0,0 +1,35 @@
+//===- AVX512Ops.cpp - MLIR AVX512 ops implementation ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the AVX512 dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AVX512/AVX512Dialect.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+avx512::AVX512Dialect::AVX512Dialect(MLIRContext *context)
+ : Dialect(getDialectNamespace(), context) {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/AVX512/AVX512.cpp.inc"
+ >();
+}
+
+namespace mlir {
+namespace avx512 {
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AVX512/AVX512.cpp.inc"
+} // namespace avx512
+} // namespace mlir
+
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index fe99044a90e6..0bcc794894cc 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(AVX512)
add_subdirectory(AffineOps)
add_subdirectory(FxpMathOps)
add_subdirectory(GPU)
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index 2e53d29f768d..148bc4bef3e8 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -24,6 +24,26 @@ target_link_libraries(MLIRLLVMIR
MLIRSupport
)
+add_mlir_dialect_library(MLIRLLVMAVX512
+ IR/LLVMAVX512Dialect.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
+
+ DEPENDS
+ MLIRLLVMAVX512IncGen
+ MLIRLLVMAVX512ConversionsIncGen
+ )
+target_link_libraries(MLIRLLVMAVX512
+ PUBLIC
+ LLVMAsmParser
+ MLIRIR
+ MLIRLLVMIR
+ MLIRSideEffects
+ LLVMSupport
+ LLVMCore
+ )
+
add_mlir_dialect_library(MLIRNVVMIR
IR/NVVMDialect.cpp
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp
new file mode 100644
index 000000000000..bde81144fb54
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp
@@ -0,0 +1,36 @@
+//===- LLVMAVX512Dialect.cpp - MLIR LLVMAVX512 ops implementation ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the LLVMAVX512 dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/IntrinsicsX86.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+LLVM::LLVMAVX512Dialect::LLVMAVX512Dialect(MLIRContext *context)
+ : Dialect(getDialectNamespace(), context) {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/LLVMIR/LLVMAVX512.cpp.inc"
+ >();
+}
+
+namespace mlir {
+namespace LLVM {
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMAVX512.cpp.inc"
+} // namespace LLVM
+} // namespace mlir
diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt
index b68bfa8d3cf2..9bc37cab1093 100644
--- a/mlir/lib/Target/CMakeLists.txt
+++ b/mlir/lib/Target/CMakeLists.txt
@@ -18,6 +18,22 @@ target_link_libraries(MLIRTargetLLVMIRModuleTranslation
MLIRTranslation
)
+add_mlir_library(MLIRTargetAVX512
+ LLVMIR/LLVMAVX512Intr.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
+ DEPENDS
+ MLIRLLVMAVX512ConversionsIncGen
+ )
+target_link_libraries(MLIRTargetAVX512
+ PUBLIC
+ MLIRIR
+ MLIRLLVMAVX512
+ MLIRLLVMIR
+ MLIRTargetLLVMIRModuleTranslation
+ )
+
add_mlir_library(MLIRTargetLLVMIR
LLVMIR/ConvertFromLLVMIR.cpp
LLVMIR/ConvertToLLVMIR.cpp
diff --git a/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp
new file mode 100644
index 000000000000..216ae862d4b2
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp
@@ -0,0 +1,51 @@
+//===- AVX512Intr.cpp - Convert MLIR LLVM dialect to LLVM intrinsics ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR LLVM and AVX512 dialects
+// and LLVM IR with AVX intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Translation.h"
+#include "llvm/IR/IntrinsicsX86.h"
+
+using namespace mlir;
+
+namespace {
+class LLVMAVX512ModuleTranslation : public LLVM::ModuleTranslation {
+ friend LLVM::ModuleTranslation;
+
+public:
+ using LLVM::ModuleTranslation::ModuleTranslation;
+
+protected:
+ LogicalResult convertOperation(Operation &opInst,
+ llvm::IRBuilder<> &builder) override {
+#include "mlir/Dialect/LLVMIR/LLVMAVX512Conversions.inc"
+
+ return LLVM::ModuleTranslation::convertOperation(opInst, builder);
+ }
+};
+
+std::unique_ptr<llvm::Module> translateLLVMAVX512ModuleToLLVMIR(Operation *m) {
+ return LLVM::ModuleTranslation::translateModule<LLVMAVX512ModuleTranslation>(
+ m);
+}
+} // end namespace
+
+static TranslateFromMLIRRegistration
+ reg("avx512-mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) {
+ auto llvmModule = translateLLVMAVX512ModuleToLLVMIR(module);
+ if (!llvmModule)
+ return failure();
+
+ llvmModule->print(output, nullptr);
+ return success();
+ });
diff --git a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
new file mode 100644
index 000000000000..936819e27eb9
--- /dev/null
+++ b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt %s -convert-avx512-to-llvm | mlir-opt | FileCheck %s
+
+func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8) // Despite its name, this function exercises both avx512.mask.rndscale and avx512.mask.scalef.
+ -> (vector<16xf32>, vector<8xf64>)
+{
+ // CHECK: llvm_avx512.mask.rndscale.ps.512
+ %0 = avx512.mask.rndscale %a, %i32, %a, %i16, %i32: vector<16xf32>
+ // CHECK: llvm_avx512.mask.rndscale.pd.512
+ %1 = avx512.mask.rndscale %b, %i32, %b, %i8, %i32: vector<8xf64>
+
+ // CHECK: llvm_avx512.mask.scalef.ps.512
+ %a0 = avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32>
+ // CHECK: llvm_avx512.mask.scalef.pd.512
+ %a1 = avx512.mask.scalef %b, %b, %b, %i8, %i32: vector<8xf64>
+
+ return %a0, %a1: vector<16xf32>, vector<8xf64> // Only the scalef results are returned; %0 and %1 are otherwise unused.
+}
diff --git a/mlir/test/Dialect/AVX512/roundtrip.mlir b/mlir/test/Dialect/AVX512/roundtrip.mlir
new file mode 100644
index 000000000000..bd23103fa432
--- /dev/null
+++ b/mlir/test/Dialect/AVX512/roundtrip.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8) // Round-trips avx512.mask.rndscale for the 16xf32 and 8xf64 vector types; %i16 is used with the f32 op and %i8 with the f64 op.
+ -> (vector<16xf32>, vector<8xf64>)
+{
+ // CHECK: avx512.mask.rndscale {{.*}}: vector<16xf32>
+ %0 = avx512.mask.rndscale %a, %i32, %a, %i16, %i32 : vector<16xf32>
+ // CHECK: avx512.mask.rndscale {{.*}}: vector<8xf64>
+ %1 = avx512.mask.rndscale %b, %i32, %b, %i8, %i32 : vector<8xf64>
+ return %0, %1: vector<16xf32>, vector<8xf64>
+}
+
+func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8) // Round-trips avx512.mask.scalef for the 16xf32 and 8xf64 vector types; %i16 is used with the f32 op and %i8 with the f64 op.
+ -> (vector<16xf32>, vector<8xf64>)
+{
+ // CHECK: avx512.mask.scalef {{.*}}: vector<16xf32>
+ %0 = avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32>
+ // CHECK: avx512.mask.scalef {{.*}}: vector<8xf64>
+ %1 = avx512.mask.scalef %b, %b, %b, %i8, %i32 : vector<8xf64>
+ return %0, %1: vector<16xf32>, vector<8xf64>
+}
diff --git a/mlir/test/Target/avx512.mlir b/mlir/test/Target/avx512.mlir
new file mode 100644
index 000000000000..5e75a98dc4ef
--- /dev/null
+++ b/mlir/test/Target/avx512.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | mlir-translate --avx512-mlir-to-llvmir | FileCheck %s
+
+// CHECK-LABEL: define <16 x float> @LLVM_x86_avx512_mask_ps_512
+llvm.func @LLVM_x86_avx512_mask_ps_512(%a: !llvm<"<16 x float>">, // <16 x float> case: rndscale.ps.512 and scalef.ps.512, written in generic (quoted) op form.
+ %b: !llvm.i32,
+ %c: !llvm.i16)
+ -> (!llvm<"<16 x float>">)
+{
+ // CHECK: call <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512(<16 x float>
+ %0 = "llvm_avx512.mask.rndscale.ps.512"(%a, %b, %a, %c, %b) :
+ (!llvm<"<16 x float>">, !llvm.i32, !llvm<"<16 x float>">, !llvm.i16, !llvm.i32) -> !llvm<"<16 x float>">
+ // CHECK: call <16 x float> @llvm.x86.avx512.mask.scalef.ps.512(<16 x float>
+ %1 = "llvm_avx512.mask.scalef.ps.512"(%a, %a, %a, %c, %b) :
+ (!llvm<"<16 x float>">, !llvm<"<16 x float>">, !llvm<"<16 x float>">, !llvm.i16, !llvm.i32) -> !llvm<"<16 x float>">
+ llvm.return %1: !llvm<"<16 x float>"> // Only the scalef result is returned; %0 is otherwise unused.
+}
+
+// CHECK-LABEL: define <8 x double> @LLVM_x86_avx512_mask_pd_512
+llvm.func @LLVM_x86_avx512_mask_pd_512(%a: !llvm<"<8 x double>">, // <8 x double> case: rndscale.pd.512 and scalef.pd.512, written in generic (quoted) op form.
+ %b: !llvm.i32,
+ %c: !llvm.i8)
+ -> (!llvm<"<8 x double>">)
+{
+ // CHECK: call <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512(<8 x double>
+ %0 = "llvm_avx512.mask.rndscale.pd.512"(%a, %b, %a, %c, %b) :
+ (!llvm<"<8 x double>">, !llvm.i32, !llvm<"<8 x double>">, !llvm.i8, !llvm.i32) -> !llvm<"<8 x double>">
+ // CHECK: call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double>
+ %1 = "llvm_avx512.mask.scalef.pd.512"(%a, %a, %a, %c, %b) :
+ (!llvm<"<8 x double>">, !llvm<"<8 x double>">, !llvm<"<8 x double>">, !llvm.i8, !llvm.i32) -> !llvm<"<8 x double>">
+ llvm.return %1: !llvm<"<8 x double>"> // Only the scalef result is returned; %0 is otherwise unused.
+}
diff --git a/mlir/tools/mlir-translate/CMakeLists.txt b/mlir/tools/mlir-translate/CMakeLists.txt
index d665789e5bd0..bf7a92509912 100644
--- a/mlir/tools/mlir-translate/CMakeLists.txt
+++ b/mlir/tools/mlir-translate/CMakeLists.txt
@@ -5,6 +5,7 @@ set(LIBS
MLIRPass
MLIRSPIRV
MLIRSPIRVSerialization
+ MLIRTargetAVX512
MLIRTargetLLVMIR
MLIRTargetNVVMIR
MLIRTargetROCDLIR
@@ -13,6 +14,7 @@ set(LIBS
)
set(FULL_LIBS
MLIRSPIRVSerialization
+ MLIRTargetAVX512
MLIRTargetLLVMIR
MLIRTargetNVVMIR
MLIRTargetROCDLIR
More information about the Mlir-commits
mailing list