[Mlir-commits] [mlir] 6ad7b97 - [mlir][amx] Add Intel AMX dialect (architectural-specific vector dialect)
Aart Bik
llvmlistbot at llvm.org
Mon Mar 15 17:59:16 PDT 2021
Author: Aart Bik
Date: 2021-03-15T17:59:05-07:00
New Revision: 6ad7b97e20c22fcfbcf95561a87f5876e3bd1d1b
URL: https://github.com/llvm/llvm-project/commit/6ad7b97e20c22fcfbcf95561a87f5876e3bd1d1b
DIFF: https://github.com/llvm/llvm-project/commit/6ad7b97e20c22fcfbcf95561a87f5876e3bd1d1b.diff
LOG: [mlir][amx] Add Intel AMX dialect (architectural-specific vector dialect)
The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
multiply unit (TMUL), a tile control register (TILECFG), and eight
tile registers TMM0 through TMM7 (TILEDATA). This new MLIR dialect
provides a bridge between MLIR concepts, such as vectors and memrefs,
and the lower-level LLVM IR details of AMX.
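For illustration, a small kernel written against the new ops (mirroring the
roundtrip test added in this patch); the 16x32/16x16 shapes are just one
legal tile configuration:

  func @amx_mulf_example(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
    %c0 = constant 0 : index
    // Load the bf16 operand tile and the f32 accumulator tile.
    %a = amx.tile_load %arg0[%c0, %c0] : memref<?x?xbf16> into vector<16x32xbf16>
    %c = amx.tile_load %arg1[%c0, %c0] : memref<?x?xf32> into vector<16x16xf32>
    // Tile dot product with accumulation, then store the result tile.
    %d = amx.tile_mulf %a, %a, %c : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
    amx.tile_store %arg1[%c0, %c0], %d : memref<?x?xf32>, vector<16x16xf32>
    return
  }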
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D98470
Added:
mlir/include/mlir/Dialect/AMX/AMX.td
mlir/include/mlir/Dialect/AMX/AMXDialect.h
mlir/include/mlir/Dialect/AMX/CMakeLists.txt
mlir/include/mlir/Dialect/AMX/Transforms.h
mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h
mlir/lib/Dialect/AMX/CMakeLists.txt
mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
mlir/lib/Dialect/AMX/IR/CMakeLists.txt
mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt
mlir/test/Dialect/AMX/invalid.mlir
mlir/test/Dialect/AMX/legalize-for-llvm.mlir
mlir/test/Dialect/AMX/roundtrip.mlir
mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf.mlir
mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir
mlir/test/Integration/Dialect/Vector/CPU/AMX/test-tilezero.mlir
mlir/test/Target/LLVMIR/amx.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
mlir/include/mlir/Dialect/CMakeLists.txt
mlir/include/mlir/InitAllDialects.h
mlir/include/mlir/Target/LLVMIR/Dialect/All.h
mlir/lib/Conversion/PassDetail.h
mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Dialect/CMakeLists.txt
mlir/lib/Target/LLVMIR/CMakeLists.txt
mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
mlir/test/CMakeLists.txt
mlir/test/lit.site.cfg.py.in
mlir/test/mlir-opt/commandline.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index f2e3fc3d3d24..79af5b3d0e15 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -502,7 +502,7 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
operations. The lowering pass provides several options to control
the kinds of optimizations that are allowed. It also provides options
that enable the use of one or more architectural-specific dialects
- (AVX512, ArmNeon, ArmSVE, etc.) in combination with the
+ (AMX, AVX512, ArmNeon, ArmSVE, etc.) in combination with the
architectural-neutral vector dialect lowering.
}];
@@ -517,6 +517,10 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
"bool", /*default=*/"true",
"Allows compiler to assume indices fit in 32-bit if that yields "
"faster code">,
+ Option<"enableAMX", "enable-amx",
+ "bool", /*default=*/"false",
+ "Enables the use of AMX dialect while lowering the vector "
+ "dialect.">,
Option<"enableAVX512", "enable-avx512",
"bool", /*default=*/"false",
"Enables the use of AVX512 dialect while lowering the vector "
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 8d24803eeb1c..91ded03f84b0 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -23,7 +23,8 @@ class OperationPass;
struct LowerVectorToLLVMOptions {
LowerVectorToLLVMOptions()
: reassociateFPReductions(false), enableIndexOptimizations(true),
- enableArmNeon(false), enableArmSVE(false), enableAVX512(false) {}
+ enableArmNeon(false), enableArmSVE(false), enableAMX(false),
+ enableAVX512(false) {}
LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) {
reassociateFPReductions = b;
@@ -41,6 +42,10 @@ struct LowerVectorToLLVMOptions {
enableArmSVE = b;
return *this;
}
+ LowerVectorToLLVMOptions &setEnableAMX(bool b) {
+ enableAMX = b;
+ return *this;
+ }
LowerVectorToLLVMOptions &setEnableAVX512(bool b) {
enableAVX512 = b;
return *this;
@@ -50,6 +55,7 @@ struct LowerVectorToLLVMOptions {
bool enableIndexOptimizations;
bool enableArmNeon;
bool enableArmSVE;
+ bool enableAMX;
bool enableAVX512;
};
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
new file mode 100644
index 000000000000..710387e70b55
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -0,0 +1,292 @@
+//===-- AMX.td - AMX dialect operation definitions ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the AMX dialect.
+//
+// The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
+// multiply unit (TMUL), a tile control register (TILECFG), and eight
+// tile registers TMM0 through TMM7 (TILEDATA).
+//
+// The AMX dialect provides a bridge between MLIR concepts, such as
+// 2-d vectors, operations, and memrefs, and the lower-level details
+// of Intel AMX, such as configuration setup, tile sizes, instructions,
+// and tile release.
+//
+// Note that since configuration changes (implicit at dialect level) are
+// costly, it is highly recommended to use the AMX dialect on same-shaped
+// vectors, at least within a single method.
+//
+// https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef AMX
+#define AMX
+
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// AMX dialect definition.
+//===----------------------------------------------------------------------===//
+
+def AMX_Dialect : Dialect {
+ let name = "amx";
+ let cppNamespace = "::mlir::amx";
+}
+
+//===----------------------------------------------------------------------===//
+// AMX Op and IntrOp definitions.
+//===----------------------------------------------------------------------===//
+
+class AMX_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<AMX_Dialect, mnemonic, traits> {}
+
+// The "internal" intrinsics are meant for compiler usage.
+class AMX_IntrOp<string mnemonic, int numResults, list<OpTrait> traits = []> :
+ LLVM_IntrOpBase<AMX_Dialect, mnemonic,
+ "x86_" # !subst(".", "_", mnemonic) # "_internal",
+ [], [], traits, numResults>;
+
+//===----------------------------------------------------------------------===//
+// AMX Op definitions (user facing).
+//===----------------------------------------------------------------------===//
+
+//
+// Tile reset.
+//
+
+def TileZeroOp : AMX_Op<"tile_zero", [NoSideEffect]> {
+ let summary = "tile zero operation";
+ let description = [{
+ Zeroes the destination tile, with the shape defined by the 2-dim
+ vector type of the result. This is eventually lowered into the
+ "tilezero" instruction with the corresponding tile configuration.
+
+ Example:
+
+ ```mlir
+    %0 = amx.tile_zero : vector<16x16xbf16>
+ ```
+ }];
+ let verifier = [{ return ::verify(*this); }];
+ let results = (outs
+ VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+ let extraClassDeclaration = [{
+ VectorType getVectorType() {
+ return res().getType().cast<VectorType>();
+ }
+ }];
+ let assemblyFormat = "attr-dict `:` type($res)";
+}
+
+//
+// Tile memory operations.
+//
+
+def TileLoadOp : AMX_Op<"tile_load", [NoSideEffect]> {
+ let summary = "tile load operation";
+ let description = [{
+ Loads a tile from memory defined by a base and indices, with the
+ shape defined by the 2-dim vector type of the result. This is
+ eventually lowered into the "tileloadd" instruction with the
+ corresponding tile configuration.
+
+ Example:
+
+ ```mlir
+    %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
+ ```
+ }];
+ let verifier = [{ return ::verify(*this); }];
+ let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
+ Variadic<Index>:$indices);
+ let results = (outs
+ VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+ let extraClassDeclaration = [{
+ MemRefType getMemRefType() {
+ return base().getType().cast<MemRefType>();
+ }
+ VectorType getVectorType() {
+ return res().getType().cast<VectorType>();
+ }
+ }];
+ let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
+ "type($base) `into` type($res)";
+}
+
+def TileStoreOp : AMX_Op<"tile_store"> {
+ let summary = "tile store operation";
+ let description = [{
+ Stores a tile to memory defined by a base and indices, with the
+ shape defined by the 2-dim vector type of the value. This is
+ eventually lowered into the "tilestored" instruction with the
+ corresponding tile configuration.
+
+ Example:
+
+ ```mlir
+    amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
+ ```
+ }];
+ let verifier = [{ return ::verify(*this); }];
+ let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
+ Variadic<Index>:$indices,
+ VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
+ let extraClassDeclaration = [{
+ MemRefType getMemRefType() {
+ return base().getType().cast<MemRefType>();
+ }
+ VectorType getVectorType() {
+ return val().getType().cast<VectorType>();
+ }
+ }];
+ let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
+ "type($base) `,` type($val)";
+}
+
+//
+// Tile arithmetic operations.
+//
+
+def TileMulFOp : AMX_Op<"tile_mulf", [NoSideEffect, AllTypesMatch<["acc", "res"]>]> {
+ let summary = "tile multiplication operation (floating-point)";
+ let description = [{
+    Multiplies an "m x k" tile with a "k x n" tile and accumulates the results
+ into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
+ pairs of "bf16"). The operation is eventually lowered into the
+ "tdpbf16ps" instruction with the corresponding tile configuration.
+
+ Example:
+
+ ```mlir
+    %0 = amx.tile_mulf %a, %b, %c
+ : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+ ```
+ }];
+ let verifier = [{ return ::verify(*this); }];
+ let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
+ VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
+ VectorOfRankAndType<[2], [F32, BF16]>:$acc);
+ let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
+ let extraClassDeclaration = [{
+ VectorType getLhsVectorType() {
+ return lhs().getType().cast<VectorType>();
+ }
+ VectorType getRhsVectorType() {
+ return rhs().getType().cast<VectorType>();
+ }
+ VectorType getVectorType() {
+ return res().getType().cast<VectorType>();
+ }
+ }];
+ let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
+ "type($lhs) `,` type($rhs) `,` type($acc) ";
+}
+
+def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]>]> {
+ let summary = "tile multiplication operation (integer)";
+ let description = [{
+    Multiplies an "m x k" tile with a "k x n" tile and accumulates the results
+ into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
+ combinations (4 bytes packed into dwords in the columns of both the
+ source operand tiles; the zero or sign extension is specified with
+ the attributes). The operation is eventually lowered into one of
+ the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud" instructions with
+ the corresponding tile configuration.
+
+ Example:
+
+ ```mlir
+    %0 = amx.tile_muli %a, %b, %c [true, true]
+ : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ ```
+ }];
+ let verifier = [{ return ::verify(*this); }];
+ let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
+ VectorOfRankAndType<[2], [I32, I8]>:$rhs,
+ VectorOfRankAndType<[2], [I32, I8]>:$acc,
+ BoolArrayAttr:$zext);
+ let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
+ let extraClassDeclaration = [{
+ VectorType getLhsVectorType() {
+ return lhs().getType().cast<VectorType>();
+ }
+ VectorType getRhsVectorType() {
+ return rhs().getType().cast<VectorType>();
+ }
+ VectorType getVectorType() {
+ return res().getType().cast<VectorType>();
+ }
+ }];
+ let assemblyFormat = "$lhs `,` $rhs `,` $acc $zext attr-dict `:` "
+ "type($lhs) `,` type($rhs) `,` type($acc) ";
+}
+
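+// For reference, the two boolean entries of the `zext` attribute select the
+// target instruction as follows (this simply summarizes the lowering added in
+// LegalizeForLLVMExport.cpp; it is not part of the op definition itself):
+//   [true,  true ]  ->  tdpbuud   (zero/zero extension)
+//   [true,  false]  ->  tdpbusd   (zero/sign extension)
+//   [false, true ]  ->  tdpbsud   (sign/zero extension)
+//   [false, false]  ->  tdpbssd   (sign/sign extension)
+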
+//===----------------------------------------------------------------------===//
+// AMX IntrOp definitions (LLVM compiler facing).
+//===----------------------------------------------------------------------===//
+
+//
+// Tile reset. Parameters define the tile size.
+//
+
+def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>,
+ Arguments<(ins LLVM_AnyInteger, LLVM_AnyInteger)>;
+
+//
+// Tile memory operations. Parameters define the tile size,
+// base address, and stride between consecutive rows for the
+// memory operation.
+//
+
+def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>,
+ Arguments<(ins LLVM_AnyInteger,
+ LLVM_AnyInteger, LLVM_AnyPointer, LLVM_AnyInteger)>;
+
+def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>,
+ Arguments<(ins LLVM_AnyInteger,
+ LLVM_AnyInteger, LLVM_AnyPointer, LLVM_AnyInteger, LLVM_Type)>;
+
+//
+// Tile multiplication operations (series of dot products). Parameters
+// define the tile sizes and source and destination tiles for the
+// operation. Note that the prefix "tdp" stands for tile dot product.
+//
+
+// Dot product of bf16 tiles into f32 tile.
+def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
+ Arguments<(ins LLVM_AnyInteger,
+ LLVM_AnyInteger,
+ LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+// Dot product of i8 tiles into i32 tile (with sign/sign extension).
+def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
+ Arguments<(ins LLVM_AnyInteger,
+ LLVM_AnyInteger,
+ LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+// Dot product of i8 tiles into i32 tile (with sign/zero extension).
+def LLVM_x86_amx_tdpbsud : AMX_IntrOp<"tdpbsud", 1>,
+ Arguments<(ins LLVM_AnyInteger,
+ LLVM_AnyInteger,
+ LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+// Dot product of i8 tiles into i32 tile (with zero/sign extension).
+def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>,
+ Arguments<(ins LLVM_AnyInteger,
+ LLVM_AnyInteger,
+ LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+// Dot product of i8 tiles into i32 tile (with zero/zero extension).
+def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
+ Arguments<(ins LLVM_AnyInteger,
+ LLVM_AnyInteger,
+ LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+#endif // AMX
diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
new file mode 100644
index 000000000000..8439c2af82c2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
@@ -0,0 +1,26 @@
+//===- AMXDialect.h - MLIR Dialect for AMX ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for AMX in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AMX_AMXDIALECT_H_
+#define MLIR_DIALECT_AMX_AMXDIALECT_H_
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/AMX/AMXDialect.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AMX/AMX.h.inc"
+
+#endif // MLIR_DIALECT_AMX_AMXDIALECT_H_
diff --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
new file mode 100644
index 000000000000..4317fd84ac14
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_dialect(AMX amx)
+add_mlir_doc(AMX -gen-dialect-doc AMX Dialects/)
+
+set(LLVM_TARGET_DEFINITIONS AMX.td)
+mlir_tablegen(AMXConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(MLIRAMXConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
new file mode 100644
index 000000000000..11b3004292d4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -0,0 +1,29 @@
+//===- Transforms.h - AMX Dialect Transformation Entrypoints ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AMX_TRANSFORMS_H
+#define MLIR_DIALECT_AMX_TRANSFORMS_H
+
+namespace mlir {
+
+class LLVMConversionTarget;
+class LLVMTypeConverter;
+class OwningRewritePatternList;
+
+/// Collect a set of patterns to lower AMX ops to ops that map to LLVM
+/// intrinsics.
+void populateAMXLegalizeForLLVMExportPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
+/// Configure the target to support lowering AMX ops to ops that map to LLVM
+/// intrinsics.
+void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMX_TRANSFORMS_H
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 8abad863ba32..20bccbfb2971 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(Affine)
add_subdirectory(Async)
add_subdirectory(ArmNeon)
add_subdirectory(ArmSVE)
+add_subdirectory(AMX)
add_subdirectory(AVX512)
add_subdirectory(Complex)
add_subdirectory(DLTI)
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 92cb5a3c3bcc..edfd003dbe8e 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -14,6 +14,7 @@
#ifndef MLIR_INITALLDIALECTS_H_
#define MLIR_INITALLDIALECTS_H_
+#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
@@ -51,6 +52,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
// clang-format off
registry.insert<acc::OpenACCDialect,
AffineDialect,
+ amx::AMXDialect,
arm_neon::ArmNeonDialect,
async::AsyncDialect,
avx512::AVX512Dialect,
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h
new file mode 100644
index 000000000000..4525ec321219
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h
@@ -0,0 +1,31 @@
+//===- AMXToLLVMIRTranslation.h - AMX to LLVM IR ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for AMX dialect to LLVM IR translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H
+#define MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H
+
+namespace mlir {
+
+class DialectRegistry;
+class MLIRContext;
+
+/// Register the AMX dialect and the translation from it to the LLVM IR
+/// in the given registry.
+void registerAMXDialectTranslation(DialectRegistry &registry);
+
+/// Register the AMX dialect and the translation from it in the registry
+/// associated with the given context.
+void registerAMXDialectTranslation(MLIRContext &context);
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
index 97189cb78619..47907fde2042 100644
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
@@ -14,6 +14,7 @@
#ifndef MLIR_TARGET_LLVMIR_DIALECT_ALL_H
#define MLIR_TARGET_LLVMIR_DIALECT_ALL_H
+#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMArmSVE/LLVMArmSVEToLLVMIRTranslation.h"
@@ -29,6 +30,7 @@ class DialectRegistry;
/// corresponding translation interfaces.
static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
registerArmNeonDialectTranslation(registry);
+ registerAMXDialectTranslation(registry);
registerAVX512DialectTranslation(registry);
registerLLVMArmSVEDialectTranslation(registry);
registerLLVMDialectTranslation(registry);
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index 4e9e0861d312..f912a30803c8 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -30,7 +30,6 @@ class GPUModuleOp;
namespace LLVM {
class LLVMArmSVEDialect;
-class LLVMAVX512Dialect;
class LLVMDialect;
} // end namespace LLVM
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index ace51f8b71d4..9a5683a9168e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -14,6 +14,8 @@ add_mlir_conversion_library(MLIRVectorToLLVM
LINK_LIBS PUBLIC
MLIRArmNeon
+ MLIRAMX
+ MLIRAMXTransforms
MLIRAVX512
MLIRAVX512Transforms
MLIRArmSVE
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 207a06e584c2..85657742413f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -13,6 +13,8 @@
#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
#include "mlir/Dialect/AVX512/Transforms.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
@@ -35,6 +37,7 @@ struct LowerVectorToLLVMPass
this->enableIndexOptimizations = options.enableIndexOptimizations;
this->enableArmNeon = options.enableArmNeon;
this->enableArmSVE = options.enableArmSVE;
+ this->enableAMX = options.enableAMX;
this->enableAVX512 = options.enableAVX512;
}
// Override explicitly to allow conditional dialect dependence.
@@ -45,6 +48,8 @@ struct LowerVectorToLLVMPass
registry.insert<arm_neon::ArmNeonDialect>();
if (enableArmSVE)
registry.insert<LLVM::LLVMArmSVEDialect>();
+ if (enableAMX)
+ registry.insert<amx::AMXDialect>();
if (enableAVX512)
registry.insert<avx512::AVX512Dialect>();
}
@@ -105,6 +110,10 @@ void LowerVectorToLLVMPass::runOnOperation() {
});
populateArmSVEToLLVMConversionPatterns(converter, patterns);
}
+ if (enableAMX) {
+ configureAMXLegalizeForExportTarget(target);
+ populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
+ }
if (enableAVX512) {
configureAVX512LegalizeForExportTarget(target);
populateAVX512LegalizeForLLVMExportPatterns(converter, patterns);
diff --git a/mlir/lib/Dialect/AMX/CMakeLists.txt b/mlir/lib/Dialect/AMX/CMakeLists.txt
new file mode 100644
index 000000000000..9f57627c321f
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
new file mode 100644
index 000000000000..5ebef7efe213
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -0,0 +1,106 @@
+//===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the AMX dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+void amx::AMXDialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/AMX/AMX.cpp.inc"
+ >();
+}
+
+/// Verify that AMX supports the implied tile shape.
+static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
+ const unsigned kMaxRows = 16;
+ const unsigned kBitsPerRow = 64 * 8;
+ unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
+ if (tp.getDimSize(0) > kMaxRows)
+ return op->emitOpError("bad row height: ") << tp.getDimSize(0);
+ if (col > kBitsPerRow || col & 0x1f)
+ return op->emitOpError("bad column width: ") << (col >> 3);
+ return success();
+}
+
+/// Verify that AMX supports the multiplication.
+static LogicalResult verifyMultShape(Operation *op, VectorType atp,
+ VectorType btp, VectorType ctp,
+ unsigned scale) {
+ unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
+ unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
+ unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
+ if (cm != am || cn != bn || ak != bk)
+ return op->emitOpError("bad mult shape: ")
+ << cm << " x " << cn << " x " << ak;
+ return success();
+}
+
+static LogicalResult verify(amx::TileZeroOp op) {
+ return verifyTileSize(op, op.getVectorType());
+}
+
+static LogicalResult verify(amx::TileLoadOp op) {
+ unsigned rank = op.getMemRefType().getRank();
+ if (llvm::size(op.indices()) != rank)
+ return op.emitOpError("requires ") << rank << " indices";
+ return verifyTileSize(op, op.getVectorType());
+}
+
+static LogicalResult verify(amx::TileStoreOp op) {
+ unsigned rank = op.getMemRefType().getRank();
+ if (llvm::size(op.indices()) != rank)
+ return op.emitOpError("requires ") << rank << " indices";
+ return verifyTileSize(op, op.getVectorType());
+}
+
+static LogicalResult verify(amx::TileMulFOp op) {
+ VectorType aType = op.getLhsVectorType();
+ VectorType bType = op.getRhsVectorType();
+ VectorType cType = op.getVectorType();
+ if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) ||
+ failed(verifyTileSize(op, cType)) ||
+ failed(verifyMultShape(op, aType, bType, cType, 1)))
+ return failure();
+ Type ta = aType.getElementType();
+ Type tb = bType.getElementType();
+ Type tc = cType.getElementType();
+ if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
+ return op.emitOpError("unsupported type combination");
+ return success();
+}
+
+static LogicalResult verify(amx::TileMulIOp op) {
+ if (op.zext().size() != 2)
+ return op.emitOpError("unexpected zext length");
+ VectorType aType = op.getLhsVectorType();
+ VectorType bType = op.getRhsVectorType();
+ VectorType cType = op.getVectorType();
+ if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) ||
+ failed(verifyTileSize(op, cType)) ||
+ failed(verifyMultShape(op, aType, bType, cType, 2)))
+ return failure();
+ Type ta = aType.getElementType();
+ Type tb = bType.getElementType();
+ Type tc = cType.getElementType();
+ if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
+ return op.emitOpError("unsupported type combination");
+ return success();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AMX/AMX.cpp.inc"
diff --git a/mlir/lib/Dialect/AMX/IR/CMakeLists.txt b/mlir/lib/Dialect/AMX/IR/CMakeLists.txt
new file mode 100644
index 000000000000..8a3a7f892551
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/IR/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRAMX
+ AMXDialect.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMX
+
+ DEPENDS
+ MLIRAMXIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMIR
+ MLIRSideEffectInterfaces
+ )
diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000..d7cf9e1e58cb
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_dialect_library(MLIRAMXTransforms
+ LegalizeForLLVMExport.cpp
+
+ DEPENDS
+ MLIRAMXConversionsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRAMX
+ MLIRIR
+ MLIRLLVMIR
+ MLIRStandardToLLVM
+ )
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
new file mode 100644
index 000000000000..6e082ce790fc
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -0,0 +1,230 @@
+//===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/Transforms.h"
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::amx;
+
+namespace {
+
+/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
+/// dimension directly translates into the number of rows of the tile.
+/// The second dimension needs to be scaled by the number of bytes per element.
+std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &typeConverter,
+ VectorType vType, Location loc) {
+ Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
+ unsigned width = vType.getElementType().getIntOrFloatBitWidth();
+ assert(llvm::isPowerOf2_64(width) && width >= 8);
+ unsigned bytes = width >> 3;
+ auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0));
+ auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes);
+ return std::make_pair(
+ rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
+ rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
+}
+
+/// Verifies if the stride matches proper tile access.
+LogicalResult verifyStride(MemRefType mType) {
+ if (mType.getRank() < 2)
+ return failure();
+ int64_t last = mType.getRank() - 1;
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ if (failed(getStridesAndOffset(mType, strides, offset)) || strides[last] != 1)
+ return failure();
+ return success();
+}
+
+/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
+/// shape may "envelop" the actual tile shape, and may be dynamically sized.
+Value getStride(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &typeConverter, MemRefType mType, Value base,
+ Location loc) {
+ assert(mType.getRank() >= 2);
+ int64_t last = mType.getRank() - 1;
+ Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
+ unsigned width = mType.getElementType().getIntOrFloatBitWidth();
+ assert(llvm::isPowerOf2_64(width) && width >= 8);
+ unsigned bytes = width >> 3;
+ if (mType.isDynamicDim(last)) {
+ // Dynamic size needs code to compute the stride at runtime.
+ MemRefDescriptor memrefDescriptor(base);
+ auto attr = rewriter.getI64IntegerAttr(bytes);
+ Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
+ return rewriter.create<LLVM::MulOp>(
+ loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last));
+ }
+ // Use direct constant for static size.
+ auto attr = rewriter.getI64IntegerAttr(mType.getDimSize(last) * bytes);
+ return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
+}
+
+/// Cast any pointer to the !llvm.ptr<i8> pointer type.
+Value castPtr(ConversionPatternRewriter &rewriter, Location loc, Value ptr) {
+ auto i8Ptr =
+ LLVM::LLVMPointerType::get(IntegerType::get(ptr.getContext(), 8));
+ return rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, ptr);
+}
+
+struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
+ using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(TileZeroOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType vType = op.getVectorType();
+ // Determine m x n tile sizes.
+ std::pair<Value, Value> tsz =
+ getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+ // Replace operation with intrinsic.
+ Type resType = typeConverter->convertType(vType);
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
+ tsz.second);
+ return success();
+ }
+};
+
+struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
+ using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(TileLoadOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ TileLoadOp::Adaptor adaptor(operands);
+ MemRefType mType = op.getMemRefType();
+ VectorType vType = op.getVectorType();
+ // Determine m x n tile sizes.
+ std::pair<Value, Value> tsz =
+ getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+ // Determine stride.
+ if (failed(verifyStride(mType)))
+ return failure();
+ Value stride = getStride(rewriter, *getTypeConverter(), mType,
+ adaptor.base(), op.getLoc());
+ // Replace operation with intrinsic.
+ Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.base(),
+ adaptor.indices(), rewriter);
+ ptr = castPtr(rewriter, op.getLoc(), ptr);
+ Type resType = typeConverter->convertType(vType);
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
+ op, resType, tsz.first, tsz.second, ptr, stride);
+ return success();
+ }
+};
+
+struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
+ using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(TileStoreOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ TileStoreOp::Adaptor adaptor(operands);
+ MemRefType mType = op.getMemRefType();
+ VectorType vType = op.getVectorType();
+ // Determine m x n tile sizes.
+ std::pair<Value, Value> tsz =
+ getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+ // Determine stride.
+ if (failed(verifyStride(mType)))
+ return failure();
+ Value stride = getStride(rewriter, *getTypeConverter(), mType,
+ adaptor.base(), op.getLoc());
+ // Replace operation with intrinsic.
+ Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.base(),
+ adaptor.indices(), rewriter);
+ ptr = castPtr(rewriter, op.getLoc(), ptr);
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
+ op, tsz.first, tsz.second, ptr, stride, adaptor.val());
+ return success();
+ }
+};
+
+struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
+ using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(TileMulFOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ TileMulFOp::Adaptor adaptor(operands);
+ VectorType aType = op.getLhsVectorType();
+ VectorType bType = op.getRhsVectorType();
+ VectorType cType = op.getVectorType();
+ // Determine m x n x k tile sizes.
+ std::pair<Value, Value> tsza =
+ getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
+ std::pair<Value, Value> tszb =
+ getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
+ // Replace operation with intrinsic.
+ Type resType = typeConverter->convertType(cType);
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
+ op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
+ adaptor.lhs(), adaptor.rhs());
+ return success();
+ }
+};
+
+struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
+ using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(TileMulIOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ TileMulIOp::Adaptor adaptor(operands);
+ VectorType aType = op.getLhsVectorType();
+ VectorType bType = op.getRhsVectorType();
+ VectorType cType = op.getVectorType();
+ // Determine m x n x k tile sizes.
+ std::pair<Value, Value> tsza =
+ getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
+ std::pair<Value, Value> tszb =
+ getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
+ // Replace operation with intrinsic.
+ Type resType = typeConverter->convertType(cType);
+ bool zexta = op.zext()[0].cast<BoolAttr>().getValue();
+ bool zextb = op.zext()[1].cast<BoolAttr>().getValue();
+ if (zexta && zextb)
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
+ op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
+ adaptor.lhs(), adaptor.rhs());
+ else if (zexta && !zextb)
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd>(
+ op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
+ adaptor.lhs(), adaptor.rhs());
+ else if (!zexta && zextb)
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud>(
+ op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
+ adaptor.lhs(), adaptor.rhs());
+ else
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd>(
+ op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
+ adaptor.lhs(), adaptor.rhs());
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::populateAMXLegalizeForLLVMExportPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ patterns.insert<TileZeroConversion, TileLoadConversion, TileStoreConversion,
+ TileMulFConversion, TileMulIConversion>(converter);
+}
+
+void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
+ target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
+ x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
+ x86_amx_tdpbusd, x86_amx_tdpbuud>();
+ target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
+ TileMulFOp>();
+}
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index b21ef9308af2..b4824707131d 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(Affine)
add_subdirectory(ArmNeon)
add_subdirectory(ArmSVE)
add_subdirectory(Async)
+add_subdirectory(AMX)
add_subdirectory(AVX512)
add_subdirectory(Complex)
add_subdirectory(DLTI)
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index 59b5b850afca..4f3465440784 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -37,6 +37,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
LINK_LIBS PUBLIC
MLIRArmNeonToLLVMIRTranslation
+ MLIRAMXToLLVMIRTranslation
MLIRAVX512ToLLVMIRTranslation
MLIRLLVMArmSVEToLLVMIRTranslation
MLIRLLVMToLLVMIRTranslation
diff --git a/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp
new file mode 100644
index 000000000000..f923367796c5
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp
@@ -0,0 +1,55 @@
+//===- AMXToLLVMIRTranslation.cpp - Translate AMX to LLVM IR --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the AMX dialect and LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicsX86.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+/// Implementation of the dialect interface that converts operations belonging
+/// to the AMX dialect to LLVM IR.
+class AMXDialectLLVMIRTranslationInterface
+ : public LLVMTranslationDialectInterface {
+public:
+ using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+ /// Translates the given operation to LLVM IR using the provided IR builder
+ /// and saving the state in `moduleTranslation`.
+ LogicalResult
+ convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) const final {
+ Operation &opInst = *op;
+#include "mlir/Dialect/AMX/AMXConversions.inc"
+
+ return failure();
+ }
+};
+} // end namespace
+
+void mlir::registerAMXDialectTranslation(DialectRegistry &registry) {
+ registry.insert<amx::AMXDialect>();
+ registry.addDialectInterface<amx::AMXDialect,
+ AMXDialectLLVMIRTranslationInterface>();
+}
+
+void mlir::registerAMXDialectTranslation(MLIRContext &context) {
+ DialectRegistry registry;
+ registerAMXDialectTranslation(registry);
+ context.appendDialectRegistry(registry);
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt
new file mode 100644
index 000000000000..f7f7b583fd65
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_translation_library(MLIRAMXToLLVMIRTranslation
+ AMXToLLVMIRTranslation.cpp
+
+ DEPENDS
+ MLIRAMXConversionsIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRAMX
+ MLIRLLVMIR
+ MLIRSupport
+ MLIRTargetLLVMIRExport
+ )
diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
index b710af3bfeb7..ebf74740d06c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(ArmNeon)
+add_subdirectory(AMX)
add_subdirectory(AVX512)
add_subdirectory(LLVMArmSVE)
add_subdirectory(LLVMIR)
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index e30cca13f92a..69d123d02047 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -29,6 +29,7 @@ set(MLIR_VULKAN_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
if (MLIR_INCLUDE_INTEGRATION_TESTS)
set(INTEL_SDE_EXECUTABLE "" CACHE STRING
"If set, arch-specific integration tests are run with Intel SDE.")
+ option(MLIR_RUN_AMX_TESTS "Run AMX tests.")
option(MLIR_RUN_AVX512_TESTS "Run AVX512 tests.")
# Passed to lit.site.cfg.py.in to set up the path where to find the libraries.
set(MLIR_INTEGRATION_TEST_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
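As a point of reference, running these integration tests requires opting in at
configure time. A plausible configure line (the MLIR_* and INTEL_SDE_* option
names come from the changes in this patch; the generator, project list, and
SDE path are placeholders) would be:

  cmake -G Ninja ../llvm \
    -DLLVM_ENABLE_PROJECTS=mlir \
    -DMLIR_INCLUDE_INTEGRATION_TESTS=ON \
    -DMLIR_RUN_AMX_TESTS=ON \
    -DINTEL_SDE_EXECUTABLE=/path/to/sde

When INTEL_SDE_EXECUTABLE is set, the lit configuration added below substitutes
%lli with "<sde> -spr -- lli" so the AMX tests can run under the Intel SDE
emulator instead of requiring a Sapphire Rapids class CPU.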
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
new file mode 100644
index 000000000000..b3a7286b526a
--- /dev/null
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+func @rowheight() {
+ // expected-error@+1 {{'amx.tile_zero' op bad row height: 17}}
+ %0 = amx.tile_zero : vector<17x16xbf16>
+}
+
+// -----
+
+func @colwidth() {
+ // expected-error@+1 {{'amx.tile_zero' op bad column width: 65}}
+ %0 = amx.tile_zero : vector<16x65xi8>
+}
+
+// -----
+
+func @col4bytemultiple() {
+ // expected-error@+1 {{'amx.tile_zero' op bad column width: 5}}
+ %0 = amx.tile_zero : vector<16x5xi8>
+}
+
+// -----
+
+func @memtilesize(%arg0: memref<?x?xf32>) {
+ %0 = constant 0 : index
+ // expected-error@+1 {{'amx.tile_load' op bad column width: 68}}
+ %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into vector<16x17xf32>
+}
+
+// -----
+
+func @memindexsize(%arg0: memref<?x?xf32>) {
+ %0 = constant 0 : index
+ // expected-error@+1 {{'amx.tile_load' op requires 2 indices}}
+ %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into vector<16x16xf32>
+}
+
+// -----
+
+func @multsize() {
+ %0 = amx.tile_zero : vector<8x8xbf16>
+ %1 = amx.tile_zero : vector<8x8xbf16>
+ %2 = amx.tile_zero : vector<4x4xf32>
+ // expected-error@+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
+ %3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32>
+}
+
+// -----
+
+func @zextsize() {
+ %0 = amx.tile_zero : vector<8x8xi8>
+ %1 = amx.tile_zero : vector<8x8xi8>
+ %2 = amx.tile_zero : vector<8x8xi32>
+ // expected-error@+1 {{'amx.tile_muli' op unexpected zext length}}
+ %3 = amx.tile_muli %0, %1, %2 [true] : vector<8x8xi8>, vector<8x8xi8>, vector<8x8xi32>
+}
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
new file mode 100644
index 000000000000..f88d83d8f311
--- /dev/null
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-amx" | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: muli(
+// CHECK: amx.tilezero
+// CHECK: amx.tileloadd64
+// CHECK: amx.tileloadd64
+// CHECK: amx.tdpbuud
+// CHECK: amx.tilestored64
+// CHECK: amx.tdpbssd
+// CHECK: amx.tilestored64
+// CHECK: amx.tdpbusd
+// CHECK: amx.tilestored64
+// CHECK: amx.tdpbsud
+// CHECK: amx.tilestored64
+func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
+ %0 = constant 0 : index
+ %1 = amx.tile_zero : vector<16x64xi8>
+ %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
+ %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
+ %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
+ %5 = amx.tile_muli %1, %2, %3 [false, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, vector<16x16xi32>
+ %6 = amx.tile_muli %1, %2, %3 [true, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, vector<16x16xi32>
+ %7 = amx.tile_muli %1, %2, %3 [false, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, vector<16x16xi32>
+ return
+}
+
+// CHECK-LABEL: mulf(
+// CHECK: amx.tilezero
+// CHECK: amx.tileloadd64
+// CHECK: amx.tileloadd64
+// CHECK: amx.tdpbf16ps
+// CHECK: amx.tilestored64
+func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
+ %0 = constant 0 : index
+ %1 = amx.tile_zero : vector<16x32xbf16>
+ %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
+ %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
+ %4 = amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+ amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
+ return
+}
diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir
new file mode 100644
index 000000000000..98b8024c194d
--- /dev/null
+++ b/mlir/test/Dialect/AMX/roundtrip.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: tzero
+// CHECK: amx.tile_zero : vector<16x16xbf16>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, vector<16x16xbf16>
+func @tzero(%arg0: memref<?x?xbf16>) {
+ %0 = constant 0 : index
+ %1 = amx.tile_zero : vector<16x16xbf16>
+ amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, vector<16x16xbf16>
+ return
+}
+
+// CHECK-LABEL: tmulf
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into vector<16x32xbf16>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into vector<16x16xf32>
+// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, vector<16x16xf32>
+func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
+ %0 = constant 0 : index
+ %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
+ %3 = amx.tile_mulf %1, %1, %2 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+ amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, vector<16x16xf32>
+ return
+}
+
+// CHECK-LABEL: tmuli
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
+// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into vector<16x16xi32>
+// CHECK: %[[m:.*]] = amx.tile_muli %[[x]], %[[y]], %[[z]] [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, vector<16x16xi32>
+func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
+ %0 = constant 0 : index
+ %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
+ %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
+ %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
new file mode 100644
index 000000000000..ec6038bf427b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
@@ -0,0 +1,15 @@
+import sys
+
+# AMX tests must be enabled via build flag.
+if config.mlir_run_amx_tests != 'ON':
+ config.unsupported = True
+
+# No JIT on win32.
+if sys.platform == 'win32':
+ config.unsupported = True
+
+if config.intel_sde_executable:
+ # Run test in emulator (Intel SDE): AMX needs Sapphire Rapids CPU.
+ config.substitutions.append(('%lli', config.intel_sde_executable + ' -spr -- lli'))
+else:
+ config.substitutions.append(('%lli', 'lli'))
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf.mlir
new file mode 100644
index 000000000000..73d866af972c
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm="enable-amx" -convert-std-to-llvm | \
+// RUN: mlir-translate -mlir-to-llvmir | \
+// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Note: To run this test, your CPU must support AMX.
+
+// Multiply into zeroed destination.
+func @kernel1(%arg0: memref<2x4xbf16>,
+ %arg1: memref<2x4xbf16>,
+ %arg2: memref<2x2xf32>) {
+ %0 = constant 0 : index
+ %1 = amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
+ %3 = amx.tile_zero : vector<2x2xf32>
+ %4 = amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
+ amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32>
+ return
+}
+
+// Multiply and update into destination.
+func @kernel2(%arg0: memref<2x4xbf16>,
+ %arg1: memref<2x4xbf16>,
+ %arg2: memref<2x2xf32>) {
+ %0 = constant 0 : index
+ %1 = amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
+ %3 = amx.tile_load %arg2[%0, %0] : memref<2x2xf32> into vector<2x2xf32>
+ %4 = amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
+ amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32>
+ return
+}
+
+func @entry() {
+ %f0 = constant 0.0: f32
+ %c0 = constant 0: index
+ %c1 = constant 1: index
+ %c2 = constant 2: index
+
+ // Set up memory.
+ %a = alloc() : memref<2x4xbf16>
+ %b = alloc() : memref<2x4xbf16>
+ %c = alloc() : memref<2x2xf32>
+
+ %0 = std.constant dense<[[1.0, 2.0, 3.0, 4.0 ],
+ [5.0, 6.0, 7.0, 8.0 ]]> : vector<2x4xbf16>
+ vector.transfer_write %0, %a[%c0, %c0] : vector<2x4xbf16>, memref<2x4xbf16>
+ %1 = std.constant dense<[[ 9.0, 10.0, 11.0, 12.0 ],
+ [13.0, 14.0, 15.0, 16.0 ]]> : vector<2x4xbf16>
+ vector.transfer_write %1, %b[%c0, %c0] : vector<2x4xbf16>, memref<2x4xbf16>
+
+ // Call kernel.
+ call @kernel1(%a, %b, %c) : (memref<2x4xbf16>, memref<2x4xbf16>, memref<2x2xf32>) -> ()
+
+ // Print and verify.
+ //
+ // CHECK: ( 124, 144 )
+ // CHECK-NEXT: ( 308, 360 )
+ scf.for %i = %c0 to %c2 step %c1 {
+ %av = vector.transfer_read %c[%i, %c0], %f0: memref<2x2xf32>, vector<2xf32>
+ vector.print %av : vector<2xf32>
+ }
+
+ // Call kernel.
+ call @kernel2(%a, %b, %c) : (memref<2x4xbf16>, memref<2x4xbf16>, memref<2x2xf32>) -> ()
+
+ // Print and verify.
+ //
+ // CHECK-NEXT: ( 248, 288 )
+ // CHECK-NEXT: ( 616, 720 )
+ //
+ scf.for %i = %c0 to %c2 step %c1 {
+ %cv = vector.transfer_read %c[%i, %c0], %f0: memref<2x2xf32>, vector<2xf32>
+ vector.print %cv : vector<2xf32>
+ }
+
+ // Release resources.
+ dealloc %a : memref<2x4xbf16>
+ dealloc %b : memref<2x4xbf16>
+ dealloc %c : memref<2x2xf32>
+
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir
new file mode 100644
index 000000000000..59eff35d33cf
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm="enable-amx" -convert-std-to-llvm | \
+// RUN: mlir-translate -mlir-to-llvmir | \
+// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Note: To run this test, your CPU must support AMX.
+
+// Multiply into zeroed destination.
+func @kernel1(%arg0: memref<2x8xi8>,
+ %arg1: memref<2x8xi8>,
+ %arg2: memref<2x2xi32>) {
+ %0 = constant 0 : index
+ %1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
+ %3 = amx.tile_zero : vector<2x2xi32>
+ %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
+ amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
+ return
+}
+
+// Multiply and update into destination.
+func @kernel2(%arg0: memref<2x8xi8>,
+ %arg1: memref<2x8xi8>,
+ %arg2: memref<2x2xi32>) {
+ %0 = constant 0 : index
+ %1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
+ %3 = amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32>
+ %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
+ amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
+ return
+}
+
+func @entry() {
+ %i0 = constant 0: i32
+ %c0 = constant 0: index
+ %c1 = constant 1: index
+ %c2 = constant 2: index
+
+ // Set up memory.
+ %a = alloc() : memref<2x8xi8>
+ %b = alloc() : memref<2x8xi8>
+ %c = alloc() : memref<2x2xi32>
+
+ %0 = std.constant dense<[[1 , 2, 3 , 4 , 5, 6, 7, 8],
+ [9, 10, 11, 12, 13, 14, 15, 16]]> : vector<2x8xi8>
+ vector.transfer_write %0, %a[%c0, %c0] : vector<2x8xi8>, memref<2x8xi8>
+ %1 = std.constant dense<[[17, 18, 19, 20, 21, 22, 23, 24],
+ [25, 26, 27, 28, 29, 30, 31, 32]]> : vector<2x8xi8>
+ vector.transfer_write %1, %b[%c0, %c0] : vector<2x8xi8>, memref<2x8xi8>
+
+ // Call kernel.
+ call @kernel1(%a, %b, %c) : (memref<2x8xi8>, memref<2x8xi8>, memref<2x2xi32>) -> ()
+
+ // Print and verify.
+ //
+ // CHECK: ( 884, 1028 )
+ // CHECK-NEXT: ( 2324, 2724 )
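+  // A sketch of the expected arithmetic (assuming the quad-packed layout
+  // of the second i8 operand):
+  // c[0][0] = 1*17 + 2*18 + 3*19 + 4*20 + 5*25 + 6*26 + 7*27 + 8*28 = 884.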
+ scf.for %i = %c0 to %c2 step %c1 {
+ %av = vector.transfer_read %c[%i, %c0], %i0: memref<2x2xi32>, vector<2xi32>
+ vector.print %av : vector<2xi32>
+ }
+
+ // Call kernel.
+ call @kernel2(%a, %b, %c) : (memref<2x8xi8>, memref<2x8xi8>, memref<2x2xi32>) -> ()
+
+ // Print and verify.
+ //
+ // CHECK-NEXT: ( 1768, 2056 )
+ // CHECK-NEXT: ( 4648, 5448 )
+ //
+ scf.for %i = %c0 to %c2 step %c1 {
+ %cv = vector.transfer_read %c[%i, %c0], %i0: memref<2x2xi32>, vector<2xi32>
+ vector.print %cv : vector<2xi32>
+ }
+
+ // Release resources.
+ dealloc %a : memref<2x8xi8>
+ dealloc %b : memref<2x8xi8>
+ dealloc %c : memref<2x2xi32>
+
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-tilezero.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-tilezero.mlir
new file mode 100644
index 000000000000..f49c66e4ce4b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-tilezero.mlir
@@ -0,0 +1,96 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm="enable-amx" -convert-std-to-llvm | \
+// RUN: mlir-translate -mlir-to-llvmir | \
+// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Note: To run this test, your CPU must support AMX.
+
+func @tilezero(%arg0: memref<?x?xi32>, %i: index, %j: index) {
+ %1 = amx.tile_zero : vector<16x16xi32>
+ amx.tile_store %arg0[%i, %j], %1 : memref<?x?xi32>, vector<16x16xi32>
+ return
+}
+
+func @entry() {
+ %i0 = constant 0: i32
+ %i1 = constant 1: i32
+ %c0 = constant 0: index
+ %c1 = constant 1: index
+ %c3 = constant 3: index
+ %c19 = constant 19: index
+
+ // Set up memory.
+ %a = alloc(%c19, %c19) : memref<?x?xi32>
+ scf.for %i = %c0 to %c19 step %c1 {
+ scf.for %j = %c0 to %c19 step %c1 {
+ store %i1, %a[%i, %j] : memref<?x?xi32>
+ }
+ }
+
+ // Call kernel.
+ call @tilezero(%a, %c1, %c1) : (memref<?x?xi32>, index, index) -> ()
+
+ // Print and verify that the tilezero is correctly strided within
+ // the enveloping 19x19 buffer.
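+  // That is, each of the 16 tile rows is stored with the buffer's row
+  // stride of 19 elements, so zeros appear at rows 1-16, columns 1-16,
+  // while the surrounding border keeps its initial ones.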
+ //
+ // CHECK: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+ // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+ //
+ scf.for %i = %c0 to %c19 step %c1 {
+ %av = vector.transfer_read %a[%i, %c0], %i0: memref<?x?xi32>, vector<19xi32>
+ vector.print %av : vector<19xi32>
+ }
+
+  // Call kernel with different indices.
+ call @tilezero(%a, %c0, %c3) : (memref<?x?xi32>, index, index) -> ()
+
+ // Print and verify that the tilezero is again correctly strided
+ // within the enveloping 19x19 buffer.
+ //
+ // CHECK-NEXT: ( 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+ // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+ // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+ // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+ //
+ scf.for %i = %c0 to %c19 step %c1 {
+ %av = vector.transfer_read %a[%i, %c0], %i0: memref<?x?xi32>, vector<19xi32>
+ vector.print %av : vector<19xi32>
+ }
+
+ // Release resources.
+ dealloc %a : memref<?x?xi32>
+
+ return
+}
diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir
new file mode 100644
index 000000000000..d1f3cd6ce30a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/amx.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @target(i8* %0)
+// CHECK: %[[c:.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 16)
+// CHECK: call void @llvm.x86.tilestored64.internal(i16 16, i16 16, i8* %0, i64 32, x86_amx %[[c]]
+llvm.func @target(%ptr: !llvm.ptr<i8>) {
+ %c = llvm.mlir.constant(16 : i16) : i16
+ %s = llvm.mlir.constant(32 : i64) : i64
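+  // Assumption: the i64 constant 32 is meant as the row stride in bytes
+  // (16 bf16 elements x 2 bytes per element) passed to tilestored64.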
+ %0 = "amx.tilezero"(%c, %c) : (i16, i16) -> !llvm.array<16 x vector<16xbf16>>
+ "amx.tilestored64"(%c, %c, %ptr, %s, %0) : (i16, i16, !llvm.ptr<i8>, i64, !llvm.array<16 x vector<16xbf16>>) -> ()
+ llvm.return
+}
+
diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in
index 1d98d94d4b8f..0015c1369d7a 100644
--- a/mlir/test/lit.site.cfg.py.in
+++ b/mlir/test/lit.site.cfg.py.in
@@ -48,6 +48,7 @@ config.enable_vulkan_runner = @MLIR_VULKAN_RUNNER_ENABLED@
config.enable_bindings_python = @MLIR_BINDINGS_PYTHON_ENABLED@
config.mlir_integration_test_dir = "@MLIR_INTEGRATION_TEST_DIR@"
config.intel_sde_executable = "@INTEL_SDE_EXECUTABLE@"
+config.mlir_run_amx_tests = "@MLIR_RUN_AMX_TESTS@"
config.mlir_run_avx512_tests = "@MLIR_RUN_AVX512_TESTS@"
config.mlir_include_integration_tests = "@MLIR_INCLUDE_INTEGRATION_TESTS@"
diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
index 508a22063339..a4d8835d7da0 100644
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -2,6 +2,7 @@
// CHECK: Available Dialects:
// CHECK-NEXT: acc
// CHECK-NEXT: affine
+// CHECK-NEXT: amx
// CHECK-NEXT: arm_neon
// CHECK-NEXT: arm_sve
// CHECK-NEXT: async