[Mlir-commits] [mlir] 6ad7b97 - [mlir][amx] Add Intel AMX dialect (architectural-specific vector dialect)

Aart Bik llvmlistbot at llvm.org
Mon Mar 15 17:59:16 PDT 2021


Author: Aart Bik
Date: 2021-03-15T17:59:05-07:00
New Revision: 6ad7b97e20c22fcfbcf95561a87f5876e3bd1d1b

URL: https://github.com/llvm/llvm-project/commit/6ad7b97e20c22fcfbcf95561a87f5876e3bd1d1b
DIFF: https://github.com/llvm/llvm-project/commit/6ad7b97e20c22fcfbcf95561a87f5876e3bd1d1b.diff

LOG: [mlir][amx] Add Intel AMX dialect (architectural-specific vector dialect)

The Intel Advanced Matrix Extensions (AMX) provides a tile matrix
multiply unit (TMUL), a tile control register (TILECFG), and eight
tile registers TMM0 through TMM7 (TILEDATA). This new MLIR dialect
provides a bridge between MLIR concepts like vectors and memrefs
and the lower level LLVM IR details of AMX.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D98470

Added: 
    mlir/include/mlir/Dialect/AMX/AMX.td
    mlir/include/mlir/Dialect/AMX/AMXDialect.h
    mlir/include/mlir/Dialect/AMX/CMakeLists.txt
    mlir/include/mlir/Dialect/AMX/Transforms.h
    mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h
    mlir/lib/Dialect/AMX/CMakeLists.txt
    mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
    mlir/lib/Dialect/AMX/IR/CMakeLists.txt
    mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
    mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
    mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp
    mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt
    mlir/test/Dialect/AMX/invalid.mlir
    mlir/test/Dialect/AMX/legalize-for-llvm.mlir
    mlir/test/Dialect/AMX/roundtrip.mlir
    mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
    mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf.mlir
    mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir
    mlir/test/Integration/Dialect/Vector/CPU/AMX/test-tilezero.mlir
    mlir/test/Target/LLVMIR/amx.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
    mlir/include/mlir/Dialect/CMakeLists.txt
    mlir/include/mlir/InitAllDialects.h
    mlir/include/mlir/Target/LLVMIR/Dialect/All.h
    mlir/lib/Conversion/PassDetail.h
    mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Dialect/CMakeLists.txt
    mlir/lib/Target/LLVMIR/CMakeLists.txt
    mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
    mlir/test/CMakeLists.txt
    mlir/test/lit.site.cfg.py.in
    mlir/test/mlir-opt/commandline.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index f2e3fc3d3d24..79af5b3d0e15 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -502,7 +502,7 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
     operations. The lowering pass provides several options to control
     the kinds of optimizations that are allowed. It also provides options
     that enable the use of one or more architectural-specific dialects
-    (AVX512, ArmNeon, ArmSVE, etc.) in combination with the
+    (AMX, AVX512, ArmNeon, ArmSVE, etc.) in combination with the
     architectural-neutral vector dialect lowering.
 
   }];
@@ -517,6 +517,10 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
            "bool", /*default=*/"true",
            "Allows compiler to assume indices fit in 32-bit if that yields "
 	   "faster code">,
+    Option<"enableAMX", "enable-amx",
+           "bool", /*default=*/"false",
+           "Enables the use of AMX dialect while lowering the vector "
+	   "dialect.">,
     Option<"enableAVX512", "enable-avx512",
            "bool", /*default=*/"false",
            "Enables the use of AVX512 dialect while lowering the vector "

diff  --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 8d24803eeb1c..91ded03f84b0 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -23,7 +23,8 @@ class OperationPass;
 struct LowerVectorToLLVMOptions {
   LowerVectorToLLVMOptions()
       : reassociateFPReductions(false), enableIndexOptimizations(true),
-        enableArmNeon(false), enableArmSVE(false), enableAVX512(false) {}
+        enableArmNeon(false), enableArmSVE(false), enableAMX(false),
+        enableAVX512(false) {}
 
   LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) {
     reassociateFPReductions = b;
@@ -41,6 +42,10 @@ struct LowerVectorToLLVMOptions {
     enableArmSVE = b;
     return *this;
   }
+  LowerVectorToLLVMOptions &setEnableAMX(bool b) {
+    enableAMX = b;
+    return *this;
+  }
   LowerVectorToLLVMOptions &setEnableAVX512(bool b) {
     enableAVX512 = b;
     return *this;
@@ -50,6 +55,7 @@ struct LowerVectorToLLVMOptions {
   bool enableIndexOptimizations;
   bool enableArmNeon;
   bool enableArmSVE;
+  bool enableAMX;
   bool enableAVX512;
 };
 

diff  --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
new file mode 100644
index 000000000000..710387e70b55
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -0,0 +1,292 @@
+//===-- AMX.td - AMX dialect operation definitions *- tablegen -*----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the AMX dialect.
+//
+// The Intel Advanced Matrix Extensions (AMX) provides a tile matrix
+// multiply unit (TMUL), a tile control register (TILECFG), and eight
+// tile registers TMM0 through TMM7 (TILEDATA).
+//
+// The AMX dialect provides a bridge between MLIR concepts, such as
+// 2-d vectors, operations, and memrefs, and the lower-level details
+// of Intel AMX, such as configuration setup, tile sizes, instructions,
+// and tile release.
+//
+// Note that since configuration changes (implicit at dialect level) are
+// costly, it is highly recommended to use the AMX dialect on same-shaped
+// vectors, at least within a single method.
+//
+// https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef AMX
+#define AMX
+
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// AMX dialect definition.
+//===----------------------------------------------------------------------===//
+
+def AMX_Dialect : Dialect {
+  let name = "amx";
+  let cppNamespace = "::mlir::amx";
+}
+
+//===----------------------------------------------------------------------===//
+// AMX Op and IntrOp definitions.
+//===----------------------------------------------------------------------===//
+
+class AMX_Op<string mnemonic, list<OpTrait> traits = []> :
+  Op<AMX_Dialect, mnemonic, traits> {}
+
+// The "internal" intrinsics are meant for compiler usage.
+class AMX_IntrOp<string mnemonic, int numResults, list<OpTrait> traits = []> :
+  LLVM_IntrOpBase<AMX_Dialect, mnemonic,
+                  "x86_" # !subst(".", "_", mnemonic) # "_internal",
+                  [], [], traits, numResults>;
+
+//===----------------------------------------------------------------------===//
+// AMX Op definitions (user facing).
+//===----------------------------------------------------------------------===//
+
+//
+// Tile reset.
+//
+
+def TileZeroOp : AMX_Op<"tile_zero", [NoSideEffect]> {
+  let summary = "tile zero operation";
+  let description = [{
+    Zeroes the destination tile, with the shape defined by the 2-dim
+    vector type of the result. This is eventually lowered into the
+    "tilezero" instruction with the corresponding tile configuration.
+
+    Example:
+
+    ```mlir
+      %0 = amx.tile_zero : vector<16x16xbf16>
+    ```
+  }];
+  let verifier = [{ return ::verify(*this); }];
+  let results = (outs
+    VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return res().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "attr-dict `:` type($res)";
+}
+
+//
+// Tile memory operations.
+//
+
+def TileLoadOp : AMX_Op<"tile_load", [NoSideEffect]> {
+  let summary = "tile load operation";
+  let description = [{
+    Loads a tile from memory defined by a base and indices, with the
+    shape defined by the 2-dim vector type of the result. This is
+    eventually lowered into the "tileloadd" instruction with the
+    corresponding tile configuration.
+
+    Example:
+
+    ```mlir
+      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
+    ```
+  }];
+  let verifier = [{ return ::verify(*this); }];
+  let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
+                   Variadic<Index>:$indices);
+  let results = (outs
+    VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+    VectorType getVectorType() {
+      return res().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
+                       "type($base) `into` type($res)";
+}
+
+def TileStoreOp : AMX_Op<"tile_store"> {
+  let summary = "tile store operation";
+  let description = [{
+    Stores a tile to memory defined by a base and indices, with the
+    shape defined by the 2-dim vector type of the value. This is
+    eventually lowered into the "tilestored" instruction with the
+    corresponding tile configuration.
+
+    Example:
+
+    ```mlir
+      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
+    ```
+  }];
+  let verifier = [{ return ::verify(*this); }];
+  let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
+                   Variadic<Index>:$indices,
+                   VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+    VectorType getVectorType() {
+      return val().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
+                       "type($base) `,` type($val)";
+}
+
+//
+// Tile arithmetic operations.
+//
+
+def TileMulFOp : AMX_Op<"tile_mulf", [NoSideEffect, AllTypesMatch<["acc", "res"]>]> {
+  let summary = "tile multiplication operation (floating-point)";
+  let description = [{
+    Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
+    into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
+    pairs of "bf16"). The operation is eventually lowered into the
+    "tdpbf16ps" instruction with the corresponding tile configuration.
+
+    Example:
+
+    ```mlir
+      %0 = amx.tile_mulf %a, %b, %c
+        : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+    ```
+  }];
+  let verifier = [{ return ::verify(*this); }];
+  let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
+                       VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
+                       VectorOfRankAndType<[2], [F32, BF16]>:$acc);
+  let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
+  let extraClassDeclaration = [{
+    VectorType getLhsVectorType() {
+      return lhs().getType().cast<VectorType>();
+    }
+    VectorType getRhsVectorType() {
+      return rhs().getType().cast<VectorType>();
+    }
+    VectorType getVectorType() {
+      return res().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
+                       "type($lhs) `,` type($rhs) `,` type($acc) ";
+}
+
+def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]>]> {
+  let summary = "tile multiplication operation (integer)";
+  let description = [{
+    Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
+    into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
+    combinations (4 bytes packed into dwords in the columns of both the
+    source operand tiles; the zero or sign extension is specified with
+    the attributes). The operation is eventually lowered into one of
+    the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud" instructions with
+    the corresponding tile configuration.
+
+    Example:
+
+    ```mlir
+      %0 = amx.tile_muli %a, %b, %c [true, true]
+        : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+    ```
+  }];
+  let verifier = [{ return ::verify(*this); }];
+  let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
+                       VectorOfRankAndType<[2], [I32, I8]>:$rhs,
+                       VectorOfRankAndType<[2], [I32, I8]>:$acc,
+                       BoolArrayAttr:$zext);
+  let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
+  let extraClassDeclaration = [{
+    VectorType getLhsVectorType() {
+      return lhs().getType().cast<VectorType>();
+    }
+    VectorType getRhsVectorType() {
+      return rhs().getType().cast<VectorType>();
+    }
+    VectorType getVectorType() {
+      return res().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$lhs `,` $rhs `,` $acc $zext attr-dict `:` "
+                       "type($lhs) `,` type($rhs) `,` type($acc) ";
+}
+
+//===----------------------------------------------------------------------===//
+// AMX IntrOp definitions (LLVM compiler facing).
+//===----------------------------------------------------------------------===//
+
+//
+// Tile reset. Parameters define the tile size.
+//
+
+def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>,
+  Arguments<(ins LLVM_AnyInteger, LLVM_AnyInteger)>;
+
+//
+// Tile memory operations. Parameters define the tile size,
+// base address, and stride between consecutive rows for the
+// memory operation.
+//
+
+def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>,
+  Arguments<(ins LLVM_AnyInteger,
+                 LLVM_AnyInteger, LLVM_AnyPointer, LLVM_AnyInteger)>;
+
+def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>,
+  Arguments<(ins LLVM_AnyInteger,
+                 LLVM_AnyInteger, LLVM_AnyPointer, LLVM_AnyInteger, LLVM_Type)>;
+
+//
+// Tile multiplication operations (series of dot products). Parameters
+// define the tile sizes and source and destination tiles for the
+// operation. Note that the prefix "tdp" stands for tile dot product.
+//
+
+// Dot product of bf16 tiles into f32 tile.
+def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
+  Arguments<(ins LLVM_AnyInteger,
+                 LLVM_AnyInteger,
+		 LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+// Dot product of i8 tiles into i32 tile (with sign/sign extension).
+def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
+  Arguments<(ins LLVM_AnyInteger,
+                 LLVM_AnyInteger,
+		 LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+// Dot product of i8 tiles into i32 tile (with sign/zero extension).
+def LLVM_x86_amx_tdpbsud : AMX_IntrOp<"tdpbsud", 1>,
+  Arguments<(ins LLVM_AnyInteger,
+                 LLVM_AnyInteger,
+		 LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+// Dot product of i8 tiles into i32 tile (with zero/sign extension).
+def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>,
+  Arguments<(ins LLVM_AnyInteger,
+                 LLVM_AnyInteger,
+		 LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+// Dot product of i8 tiles into i32 tile (with zero/zero extension).
+def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
+  Arguments<(ins LLVM_AnyInteger,
+                 LLVM_AnyInteger,
+		 LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+#endif // AMX

diff  --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
new file mode 100644
index 000000000000..8439c2af82c2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
@@ -0,0 +1,26 @@
+//===- AMXDialect.h - MLIR Dialect for AMX ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for AMX in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AMX_AMXDIALECT_H_
+#define MLIR_DIALECT_AMX_AMXDIALECT_H_
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/AMX/AMXDialect.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AMX/AMX.h.inc"
+
+#endif // MLIR_DIALECT_AMX_AMXDIALECT_H_

diff  --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
new file mode 100644
index 000000000000..4317fd84ac14
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_dialect(AMX amx)
+add_mlir_doc(AMX -gen-dialect-doc AMX Dialects/)
+
+set(LLVM_TARGET_DEFINITIONS AMX.td)
+mlir_tablegen(AMXConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(MLIRAMXConversionsIncGen)

diff  --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
new file mode 100644
index 000000000000..11b3004292d4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -0,0 +1,29 @@
+//===- Transforms.h - AMX Dialect Transformation Entrypoints ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AMX_TRANSFORMS_H
+#define MLIR_DIALECT_AMX_TRANSFORMS_H
+
+namespace mlir {
+
+class LLVMConversionTarget;
+class LLVMTypeConverter;
+class OwningRewritePatternList;
+
+/// Collect a set of patterns to lower AMX ops to ops that map to LLVM
+/// intrinsics.
+void populateAMXLegalizeForLLVMExportPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
+/// Configure the target to support lowering AMX ops to ops that map to LLVM
+/// intrinsics.
+void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMX_TRANSFORMS_H

diff  --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 8abad863ba32..20bccbfb2971 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(Affine)
 add_subdirectory(Async)
 add_subdirectory(ArmNeon)
 add_subdirectory(ArmSVE)
+add_subdirectory(AMX)
 add_subdirectory(AVX512)
 add_subdirectory(Complex)
 add_subdirectory(DLTI)

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 92cb5a3c3bcc..edfd003dbe8e 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_INITALLDIALECTS_H_
 #define MLIR_INITALLDIALECTS_H_
 
+#include "mlir/Dialect/AMX/AMXDialect.h"
 #include "mlir/Dialect/AVX512/AVX512Dialect.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
@@ -51,6 +52,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   // clang-format off
   registry.insert<acc::OpenACCDialect,
                   AffineDialect,
+                  amx::AMXDialect,
                   arm_neon::ArmNeonDialect,
                   async::AsyncDialect,
                   avx512::AVX512Dialect,

diff  --git a/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h
new file mode 100644
index 000000000000..4525ec321219
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h
@@ -0,0 +1,31 @@
+//===- AMXToLLVMIRTranslation.h - AMX to LLVM IR ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for AMX dialect to LLVM IR translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H
+#define MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H
+
+namespace mlir {
+
+class DialectRegistry;
+class MLIRContext;
+
+/// Register the AMX dialect and the translation from it to the LLVM IR
+/// in the given registry.
+void registerAMXDialectTranslation(DialectRegistry &registry);
+
+/// Register the AMX dialect and the translation from it in the registry
+/// associated with the given context.
+void registerAMXDialectTranslation(MLIRContext &context);
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H

diff  --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
index 97189cb78619..47907fde2042 100644
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_TARGET_LLVMIR_DIALECT_ALL_H
 #define MLIR_TARGET_LLVMIR_DIALECT_ALL_H
 
+#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMArmSVE/LLVMArmSVEToLLVMIRTranslation.h"
@@ -29,6 +30,7 @@ class DialectRegistry;
 /// corresponding translation interfaces.
 static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
   registerArmNeonDialectTranslation(registry);
+  registerAMXDialectTranslation(registry);
   registerAVX512DialectTranslation(registry);
   registerLLVMArmSVEDialectTranslation(registry);
   registerLLVMDialectTranslation(registry);

diff  --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index 4e9e0861d312..f912a30803c8 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -30,7 +30,6 @@ class GPUModuleOp;
 
 namespace LLVM {
 class LLVMArmSVEDialect;
-class LLVMAVX512Dialect;
 class LLVMDialect;
 } // end namespace LLVM
 

diff  --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index ace51f8b71d4..9a5683a9168e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -14,6 +14,8 @@ add_mlir_conversion_library(MLIRVectorToLLVM
 
   LINK_LIBS PUBLIC
   MLIRArmNeon
+  MLIRAMX
+  MLIRAMXTransforms
   MLIRAVX512
   MLIRAVX512Transforms
   MLIRArmSVE

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 207a06e584c2..85657742413f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -13,6 +13,8 @@
 #include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/AVX512/AVX512Dialect.h"
 #include "mlir/Dialect/AVX512/Transforms.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
@@ -35,6 +37,7 @@ struct LowerVectorToLLVMPass
     this->enableIndexOptimizations = options.enableIndexOptimizations;
     this->enableArmNeon = options.enableArmNeon;
     this->enableArmSVE = options.enableArmSVE;
+    this->enableAMX = options.enableAMX;
     this->enableAVX512 = options.enableAVX512;
   }
   // Override explicitly to allow conditional dialect dependence.
@@ -45,6 +48,8 @@ struct LowerVectorToLLVMPass
       registry.insert<arm_neon::ArmNeonDialect>();
     if (enableArmSVE)
       registry.insert<LLVM::LLVMArmSVEDialect>();
+    if (enableAMX)
+      registry.insert<amx::AMXDialect>();
     if (enableAVX512)
       registry.insert<avx512::AVX512Dialect>();
   }
@@ -105,6 +110,10 @@ void LowerVectorToLLVMPass::runOnOperation() {
         });
     populateArmSVEToLLVMConversionPatterns(converter, patterns);
   }
+  if (enableAMX) {
+    configureAMXLegalizeForExportTarget(target);
+    populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
+  }
   if (enableAVX512) {
     configureAVX512LegalizeForExportTarget(target);
     populateAVX512LegalizeForLLVMExportPatterns(converter, patterns);

diff  --git a/mlir/lib/Dialect/AMX/CMakeLists.txt b/mlir/lib/Dialect/AMX/CMakeLists.txt
new file mode 100644
index 000000000000..9f57627c321f
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(IR)
+add_subdirectory(Transforms)

diff  --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
new file mode 100644
index 000000000000..5ebef7efe213
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -0,0 +1,106 @@
+//===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the AMX dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+void amx::AMXDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/AMX/AMX.cpp.inc"
+      >();
+}
+
+/// Verify that AMX supports the implied tile shape.
+static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
+  const unsigned kMaxRows = 16;
+  const unsigned kBitsPerRow = 64 * 8;
+  unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
+  if (tp.getDimSize(0) > kMaxRows)
+    return op->emitOpError("bad row height: ") << tp.getDimSize(0);
+  if (col > kBitsPerRow || col & 0x1f)
+    return op->emitOpError("bad column width: ") << (col >> 3);
+  return success();
+}
+
+/// Verify that AMX supports the multiplication.
+static LogicalResult verifyMultShape(Operation *op, VectorType atp,
+                                     VectorType btp, VectorType ctp,
+                                     unsigned scale) {
+  unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
+  unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
+  unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
+  if (cm != am || cn != bn || ak != bk)
+    return op->emitOpError("bad mult shape: ")
+           << cm << " x " << cn << " x " << ak;
+  return success();
+}
+
+static LogicalResult verify(amx::TileZeroOp op) {
+  return verifyTileSize(op, op.getVectorType());
+}
+
+static LogicalResult verify(amx::TileLoadOp op) {
+  unsigned rank = op.getMemRefType().getRank();
+  if (llvm::size(op.indices()) != rank)
+    return op.emitOpError("requires ") << rank << " indices";
+  return verifyTileSize(op, op.getVectorType());
+}
+
+static LogicalResult verify(amx::TileStoreOp op) {
+  unsigned rank = op.getMemRefType().getRank();
+  if (llvm::size(op.indices()) != rank)
+    return op.emitOpError("requires ") << rank << " indices";
+  return verifyTileSize(op, op.getVectorType());
+}
+
+static LogicalResult verify(amx::TileMulFOp op) {
+  VectorType aType = op.getLhsVectorType();
+  VectorType bType = op.getRhsVectorType();
+  VectorType cType = op.getVectorType();
+  if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) ||
+      failed(verifyTileSize(op, cType)) ||
+      failed(verifyMultShape(op, aType, bType, cType, 1)))
+    return failure();
+  Type ta = aType.getElementType();
+  Type tb = bType.getElementType();
+  Type tc = cType.getElementType();
+  if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
+    return op.emitOpError("unsupported type combination");
+  return success();
+}
+
+static LogicalResult verify(amx::TileMulIOp op) {
+  if (op.zext().size() != 2)
+    return op.emitOpError("unexpected zext length");
+  VectorType aType = op.getLhsVectorType();
+  VectorType bType = op.getRhsVectorType();
+  VectorType cType = op.getVectorType();
+  if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) ||
+      failed(verifyTileSize(op, cType)) ||
+      failed(verifyMultShape(op, aType, bType, cType, 2)))
+    return failure();
+  Type ta = aType.getElementType();
+  Type tb = bType.getElementType();
+  Type tc = cType.getElementType();
+  if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
+    return op.emitOpError("unsupported type combination");
+  return success();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AMX/AMX.cpp.inc"

diff  --git a/mlir/lib/Dialect/AMX/IR/CMakeLists.txt b/mlir/lib/Dialect/AMX/IR/CMakeLists.txt
new file mode 100644
index 000000000000..8a3a7f892551
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/IR/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRAMX
+  AMXDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMX
+
+  DEPENDS
+  MLIRAMXIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMIR
+  MLIRSideEffectInterfaces
+  )

diff  --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000..d7cf9e1e58cb
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_dialect_library(MLIRAMXTransforms
+  LegalizeForLLVMExport.cpp
+
+  DEPENDS
+  MLIRAMXConversionsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAMX
+  MLIRIR
+  MLIRLLVMIR
+  MLIRStandardToLLVM
+  )

diff  --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
new file mode 100644
index 000000000000..6e082ce790fc
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -0,0 +1,230 @@
+//===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/Transforms.h"
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::amx;
+
+namespace {
+
+/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
+/// dimension directly translates into the number of rows of the tiles.
+/// The second dimensions needs to be scaled by the number of bytes.
+/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
+/// dimension directly translates into the number of rows of the tiles.
+/// The second dimension needs to be scaled by the element size in bytes.
+///
+/// Returns (rows, row-bytes) as two i16 LLVM constants. They are created
+/// in that order.
+std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
+                                     LLVMTypeConverter &typeConverter,
+                                     VectorType vType, Location loc) {
+  Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
+  unsigned width = vType.getElementType().getIntOrFloatBitWidth();
+  assert(llvm::isPowerOf2_64(width) && width >= 8);
+  unsigned bytes = width >> 3;
+  auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0));
+  auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes);
+  // Create the constants in separate statements. The evaluation order of
+  // function arguments (e.g. to std::make_pair) is unspecified, which would
+  // make the order of the emitted ops compiler-dependent.
+  Value m = rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr);
+  Value n = rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr);
+  return std::make_pair(m, n);
+}
+
+/// Verifies if the stride matches proper tile access.
+/// Succeeds when `mType` has rank >= 2, is representable as a strided
+/// layout, and has unit stride in its innermost dimension.
+LogicalResult verifyStride(MemRefType mType) {
+  if (mType.getRank() < 2)
+    return failure();
+  SmallVector<int64_t, 4> strides;
+  int64_t offset;
+  if (failed(getStridesAndOffset(mType, strides, offset)))
+    return failure();
+  // Elements within a row must be contiguous.
+  return success(strides.back() == 1);
+}
+
+/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
+/// shape may "envelop" the actual tile shape, and may be dynamically sized.
+/// Maps the 2-dim memref shape to the 64-bit row stride in bytes. Note that
+/// the buffer shape may "envelop" the actual tile shape, and may be
+/// dynamically sized.
+///
+/// The row stride is read from the memref layout (the stride of dimension
+/// rank-2). It is not derived from the size of the innermost dimension, so
+/// non-identity strided layouts (e.g. subviews) are addressed correctly. For
+/// identity layouts both give the same value. Callers must have checked
+/// verifyStride() first.
+Value getStride(ConversionPatternRewriter &rewriter,
+                LLVMTypeConverter &typeConverter, MemRefType mType, Value base,
+                Location loc) {
+  assert(mType.getRank() >= 2);
+  int64_t preLast = mType.getRank() - 2;
+  Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
+  unsigned width = mType.getElementType().getIntOrFloatBitWidth();
+  assert(llvm::isPowerOf2_64(width) && width >= 8);
+  unsigned bytes = width >> 3;
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  LogicalResult stridesOk = getStridesAndOffset(mType, strides, offset);
+  assert(succeeded(stridesOk) && "expected strided memref (see verifyStride)");
+  (void)stridesOk;
+  if (MemRefType::isDynamicStrideOrOffset(strides[preLast])) {
+    // Dynamic stride: read it from the descriptor at runtime and scale it.
+    MemRefDescriptor memrefDescriptor(base);
+    auto attr = rewriter.getI64IntegerAttr(bytes);
+    Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
+    return rewriter.create<LLVM::MulOp>(
+        loc, llvmInt64Type, scale,
+        memrefDescriptor.stride(rewriter, loc, preLast));
+  }
+  // Use direct constant for static stride.
+  auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
+  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
+}
+
+/// Cast any pointer to the !llvm.ptr<i8> pointer type.
+/// Bitcasts `ptr` to an LLVM pointer-to-i8.
+Value castPtr(ConversionPatternRewriter &rewriter, Location loc, Value ptr) {
+  MLIRContext *ctx = ptr.getContext();
+  Type bytePtrType = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
+  return rewriter.create<LLVM::BitcastOp>(loc, bytePtrType, ptr);
+}
+
+/// Lowers amx.tile_zero to the x86_amx_tilezero intrinsic op.
+struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
+  using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(TileZeroOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType tileType = op.getVectorType();
+    // Row count and row width (in bytes) of the tile to clear.
+    std::pair<Value, Value> sizes =
+        getTileSizes(rewriter, *getTypeConverter(), tileType, op.getLoc());
+    Value rows = sizes.first;
+    Value rowBytes = sizes.second;
+    rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(
+        op, typeConverter->convertType(tileType), rows, rowBytes);
+    return success();
+  }
+};
+
+struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
+  using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;
+
+  /// Rewrites amx.tile_load into the x86_amx_tileloadd64 intrinsic op.
+  /// The memref layout is checked before any IR is created, so a bail-out
+  /// neither creates ops nor needs them rolled back, and it records a reason.
+  LogicalResult
+  matchAndRewrite(TileLoadOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    TileLoadOp::Adaptor adaptor(operands);
+    MemRefType mType = op.getMemRefType();
+    VectorType vType = op.getVectorType();
+    // Reject unsupported layouts up front.
+    if (failed(verifyStride(mType)))
+      return rewriter.notifyMatchFailure(
+          op, "expected a strided memref of rank >= 2 with unit innermost "
+              "stride");
+    // Determine m x n tile sizes.
+    std::pair<Value, Value> tsz =
+        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+    // Determine the row stride (in bytes).
+    Value stride = getStride(rewriter, *getTypeConverter(), mType,
+                             adaptor.base(), op.getLoc());
+    // Replace operation with intrinsic.
+    Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.base(),
+                                     adaptor.indices(), rewriter);
+    ptr = castPtr(rewriter, op.getLoc(), ptr);
+    Type resType = typeConverter->convertType(vType);
+    rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
+        op, resType, tsz.first, tsz.second, ptr, stride);
+    return success();
+  }
+};
+
+struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
+  using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
+
+  /// Rewrites amx.tile_store into the x86_amx_tilestored64 intrinsic op.
+  /// The memref layout is checked before any IR is created, so a bail-out
+  /// neither creates ops nor needs them rolled back, and it records a reason.
+  LogicalResult
+  matchAndRewrite(TileStoreOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    TileStoreOp::Adaptor adaptor(operands);
+    MemRefType mType = op.getMemRefType();
+    VectorType vType = op.getVectorType();
+    // Reject unsupported layouts up front.
+    if (failed(verifyStride(mType)))
+      return rewriter.notifyMatchFailure(
+          op, "expected a strided memref of rank >= 2 with unit innermost "
+              "stride");
+    // Determine m x n tile sizes.
+    std::pair<Value, Value> tsz =
+        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+    // Determine the row stride (in bytes).
+    Value stride = getStride(rewriter, *getTypeConverter(), mType,
+                             adaptor.base(), op.getLoc());
+    // Replace operation with intrinsic.
+    Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.base(),
+                                     adaptor.indices(), rewriter);
+    ptr = castPtr(rewriter, op.getLoc(), ptr);
+    rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
+        op, tsz.first, tsz.second, ptr, stride, adaptor.val());
+    return success();
+  }
+};
+
+/// Lowers amx.tile_mulf to the x86_amx_tdpbf16ps intrinsic op.
+struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
+  using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(TileMulFOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    TileMulFOp::Adaptor adaptor(operands);
+    Location loc = op.getLoc();
+    VectorType lhsTy = op.getLhsVectorType();
+    VectorType rhsTy = op.getRhsVectorType();
+    // (rows, row-bytes) of each source tile.
+    std::pair<Value, Value> lhsSizes =
+        getTileSizes(rewriter, *getTypeConverter(), lhsTy, loc);
+    std::pair<Value, Value> rhsSizes =
+        getTileSizes(rewriter, *getTypeConverter(), rhsTy, loc);
+    // m = lhs rows, n = rhs row bytes, k = lhs row bytes.
+    Value m = lhsSizes.first;
+    Value n = rhsSizes.second;
+    Value k = lhsSizes.second;
+    Type resultTy = typeConverter->convertType(op.getVectorType());
+    rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
+        op, resultTy, m, n, k, adaptor.acc(), adaptor.lhs(), adaptor.rhs());
+    return success();
+  }
+};
+
+/// Lowers amx.tile_muli to one of the four x86_amx_tdpb??d intrinsic ops,
+/// selected by the two zext flags.
+struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
+  using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(TileMulIOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    TileMulIOp::Adaptor adaptor(operands);
+    Location loc = op.getLoc();
+    VectorType lhsTy = op.getLhsVectorType();
+    VectorType rhsTy = op.getRhsVectorType();
+    // (rows, row-bytes) of each source tile.
+    std::pair<Value, Value> lhsSizes =
+        getTileSizes(rewriter, *getTypeConverter(), lhsTy, loc);
+    std::pair<Value, Value> rhsSizes =
+        getTileSizes(rewriter, *getTypeConverter(), rhsTy, loc);
+    // m = lhs rows, n = rhs row bytes, k = lhs row bytes.
+    Value m = lhsSizes.first;
+    Value n = rhsSizes.second;
+    Value k = lhsSizes.second;
+    Type resultTy = typeConverter->convertType(op.getVectorType());
+    Value acc = adaptor.acc();
+    Value lhs = adaptor.lhs();
+    Value rhs = adaptor.rhs();
+    // In the intrinsic name, the first of the two middle letters belongs to
+    // lhs and the second to rhs: 'u' for a zext flag of true, 's' for false.
+    bool lhsZext = op.zext()[0].cast<BoolAttr>().getValue();
+    bool rhsZext = op.zext()[1].cast<BoolAttr>().getValue();
+    if (lhsZext) {
+      if (rhsZext)
+        rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(op, resultTy, m, n,
+                                                          k, acc, lhs, rhs);
+      else
+        rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd>(op, resultTy, m, n,
+                                                          k, acc, lhs, rhs);
+    } else if (rhsZext) {
+      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud>(op, resultTy, m, n, k,
+                                                        acc, lhs, rhs);
+    } else {
+      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd>(op, resultTy, m, n, k,
+                                                        acc, lhs, rhs);
+    }
+    return success();
+  }
+};
+
+} // namespace
+
+/// Adds the AMX-to-LLVM-intrinsic conversion patterns to `patterns`.
+void mlir::populateAMXLegalizeForLLVMExportPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  // Tile creation and data movement.
+  patterns.insert<TileZeroConversion, TileLoadConversion, TileStoreConversion>(
+      converter);
+  // Tile multiplication (floating-point, then integer).
+  patterns.insert<TileMulFConversion, TileMulIConversion>(converter);
+}
+
+/// Marks the LLVM-level AMX intrinsic ops legal and the high-level AMX ops
+/// illegal on `target`.
+void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
+  // Intrinsic ops for tile creation and data movement.
+  target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64,
+                    x86_amx_tilestored64>();
+  // Intrinsic ops for tile multiplication.
+  target.addLegalOp<x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
+                    x86_amx_tdpbusd, x86_amx_tdpbuud>();
+  // High-level ops that must be rewritten.
+  target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp>();
+  target.addIllegalOp<TileMulIOp, TileMulFOp>();
+}

diff  --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index b21ef9308af2..b4824707131d 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(Affine)
 add_subdirectory(ArmNeon)
 add_subdirectory(ArmSVE)
 add_subdirectory(Async)
+add_subdirectory(AMX)
 add_subdirectory(AVX512)
 add_subdirectory(Complex)
 add_subdirectory(DLTI)

diff  --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index 59b5b850afca..4f3465440784 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -37,6 +37,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
 
   LINK_LIBS PUBLIC
   MLIRArmNeonToLLVMIRTranslation
+  MLIRAMXToLLVMIRTranslation
   MLIRAVX512ToLLVMIRTranslation
   MLIRLLVMArmSVEToLLVMIRTranslation
   MLIRLLVMToLLVMIRTranslation

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp
new file mode 100644
index 000000000000..f923367796c5
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp
@@ -0,0 +1,55 @@
+//===- AMXToLLVMIRTranslation.cpp - Translate AMX to LLVM IR --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the AMX dialect and LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicsX86.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+/// Implementation of the dialect interface that converts operations belonging
+/// to the AMX dialect to LLVM IR.
+class AMXDialectLLVMIRTranslationInterface
+    : public LLVMTranslationDialectInterface {
+public:
+  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+  /// Translates the given operation to LLVM IR using the provided IR builder
+  /// and saving the state in `moduleTranslation`.
+  ///
+  /// The per-op translation code is textually included from the generated
+  /// AMXConversions.inc. Any op the included code does not return for
+  /// reaches the trailing `return failure()`.
+  LogicalResult
+  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) const final {
+    // NOTE(review): `opInst` is not used in this function itself; presumably
+    // the generated .inc refers to it (and to `builder` and
+    // `moduleTranslation`) by name. Confirm against the generated file.
+    Operation &opInst = *op;
+#include "mlir/Dialect/AMX/AMXConversions.inc"
+
+    return failure();
+  }
+};
+} // end namespace
+
+/// Registers the AMX dialect and attaches its LLVM IR translation interface.
+void mlir::registerAMXDialectTranslation(DialectRegistry &registry) {
+  using amx::AMXDialect;
+  // Insert the dialect before attaching its interface; keep this order.
+  registry.insert<AMXDialect>();
+  registry
+      .addDialectInterface<AMXDialect, AMXDialectLLVMIRTranslationInterface>();
+}
+
+/// Convenience overload: builds a registry holding the AMX translation and
+/// appends it to `context`.
+void mlir::registerAMXDialectTranslation(MLIRContext &context) {
+  DialectRegistry amxRegistry;
+  registerAMXDialectTranslation(amxRegistry);
+  context.appendDialectRegistry(amxRegistry);
+}

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt
new file mode 100644
index 000000000000..f7f7b583fd65
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Translation of AMX dialect ops to LLVM IR (AMXToLLVMIRTranslation.cpp).
+add_mlir_translation_library(MLIRAMXToLLVMIRTranslation
+  AMXToLLVMIRTranslation.cpp
+
+  # Presumably generates AMXConversions.inc, which the translation
+  # interface textually includes.
+  DEPENDS
+  MLIRAMXConversionsIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRAMX
+  MLIRLLVMIR
+  MLIRSupport
+  MLIRTargetLLVMIRExport
+  )

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
index b710af3bfeb7..ebf74740d06c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(ArmNeon)
+add_subdirectory(AMX)
 add_subdirectory(AVX512)
 add_subdirectory(LLVMArmSVE)
 add_subdirectory(LLVMIR)

diff  --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index e30cca13f92a..69d123d02047 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -29,6 +29,7 @@ set(MLIR_VULKAN_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
 if (MLIR_INCLUDE_INTEGRATION_TESTS)
   set(INTEL_SDE_EXECUTABLE "" CACHE STRING
       "If set, arch-specific integration tests are run with Intel SDE.")
+  option(MLIR_RUN_AMX_TESTS "Run AMX tests.")
   option(MLIR_RUN_AVX512_TESTS "Run AVX512 tests.")
   # Passed to lit.site.cfg.py.in to set up the path where to find the libraries.
   set(MLIR_INTEGRATION_TEST_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})

diff  --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
new file mode 100644
index 000000000000..b3a7286b526a
--- /dev/null
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+// A 17-row tile is rejected by the row-height check.
+func @rowheight() {
+  // expected-error at +1 {{'amx.tile_zero' op bad row height: 17}}
+  %0 = amx.tile_zero : vector<17x16xbf16>
+}
+
+// -----
+
+// A 65-byte row (65 x i8) is rejected by the column-width check.
+func @colwidth() {
+  // expected-error at +1 {{'amx.tile_zero' op bad column width: 65}}
+  %0 = amx.tile_zero : vector<16x65xi8>
+}
+
+// -----
+
+// A 5-byte row is rejected (the width is not a multiple of 4 bytes).
+func @col4bytemultiple() {
+  // expected-error at +1 {{'amx.tile_zero' op bad column width: 5}}
+  %0 = amx.tile_zero : vector<16x5xi8>
+}
+
+// -----
+
+// The column width is measured in bytes: 17 x f32 = 68 bytes is rejected.
+func @memtilesize(%arg0: memref<?x?xf32>) {
+  %0 = constant 0 : index
+  // expected-error at +1 {{'amx.tile_load' op bad column width: 68}}
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into vector<16x17xf32>
+}
+
+// -----
+
+// A rank-2 memref needs exactly two indices.
+func @memindexsize(%arg0: memref<?x?xf32>) {
+  %0 = constant 0 : index
+  // expected-error at +1 {{'amx.tile_load' op requires 2 indices}}
+  %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into vector<16x16xf32>
+}
+
+// -----
+
+// Accumulator shape inconsistent with the source tiles.
+func @multsize() {
+  %0 = amx.tile_zero : vector<8x8xbf16>
+  %1 = amx.tile_zero : vector<8x8xbf16>
+  %2 = amx.tile_zero : vector<4x4xf32>
+  // expected-error at +1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
+  %3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32>
+}
+
+// -----
+
+// tile_muli needs exactly two zext flags.
+func @zextsize() {
+  %0 = amx.tile_zero : vector<8x8xi8>
+  %1 = amx.tile_zero : vector<8x8xi8>
+  %2 = amx.tile_zero : vector<8x8xi32>
+  // expected-error at +1 {{'amx.tile_muli' op unexpected zext length}}
+  %3 = amx.tile_muli %0, %1, %2 [true] : vector<8x8xi8>, vector<8x8xi8>, vector<8x8xi32>
+}

diff  --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
new file mode 100644
index 000000000000..f88d83d8f311
--- /dev/null
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-amx" | mlir-opt | FileCheck %s
+
+// Each zext flag pair on amx.tile_muli selects a different intrinsic:
+// [true, true] -> tdpbuud, [false, false] -> tdpbssd,
+// [true, false] -> tdpbusd, [false, true] -> tdpbsud.
+// CHECK-LABEL: muli(
+// CHECK: amx.tilezero
+// CHECK: amx.tileloadd64
+// CHECK: amx.tileloadd64
+// CHECK: amx.tdpbuud
+// CHECK: amx.tilestored64
+// CHECK: amx.tdpbssd
+// CHECK: amx.tilestored64
+// CHECK: amx.tdpbusd
+// CHECK: amx.tilestored64
+// CHECK: amx.tdpbsud
+// CHECK: amx.tilestored64
+func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
+  %0 = constant 0 : index
+  %1 = amx.tile_zero : vector<16x64xi8>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
+  %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
+  %5 = amx.tile_muli %1, %2, %3 [false, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, vector<16x16xi32>
+  %6 = amx.tile_muli %1, %2, %3 [true, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, vector<16x16xi32>
+  %7 = amx.tile_muli %1, %2, %3 [false, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, vector<16x16xi32>
+  return
+}
+
+// The bf16 tile multiply lowers to the tdpbf16ps intrinsic.
+// CHECK-LABEL: mulf(
+// CHECK: amx.tilezero
+// CHECK: amx.tileloadd64
+// CHECK: amx.tileloadd64
+// CHECK: amx.tdpbf16ps
+// CHECK: amx.tilestored64
+func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
+  %0 = constant 0 : index
+  %1 = amx.tile_zero : vector<16x32xbf16>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
+  %4 = amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
+  return
+}

diff  --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir
new file mode 100644
index 000000000000..98b8024c194d
--- /dev/null
+++ b/mlir/test/Dialect/AMX/roundtrip.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+// Round-trips amx.tile_zero followed by amx.tile_store.
+// CHECK-LABEL: tzero
+// CHECK: amx.tile_zero : vector<16x16xbf16>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, vector<16x16xbf16>
+func @tzero(%arg0: memref<?x?xbf16>) {
+  %0 = constant 0 : index
+  %1 = amx.tile_zero : vector<16x16xbf16>
+  amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, vector<16x16xbf16>
+  return
+}
+
+// Round-trips a bf16 x bf16 -> f32 tile multiply, using the same tile as
+// both sources.
+// CHECK-LABEL: tmulf
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into vector<16x32xbf16>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into vector<16x16xf32>
+// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, vector<16x16xf32>
+func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
+  %0 = constant 0 : index
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
+  %3 = amx.tile_mulf %1, %1, %2 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+  amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, vector<16x16xf32>
+  return
+}
+
+// Round-trips an i8 x i8 -> i32 tile multiply, including its zext flags.
+// CHECK-LABEL: tmuli
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
+// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into vector<16x16xi32>
+// CHECK: %[[m:.*]] = amx.tile_muli %[[x]], %[[y]], %[[z]] [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, vector<16x16xi32>
+func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
+  %0 = constant 0 : index
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
+  %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
+  return
+}

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
new file mode 100644
index 000000000000..ec6038bf427b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
@@ -0,0 +1,15 @@
+import sys
+
+# AMX tests must be enabled via build flag.
+if config.mlir_run_amx_tests != 'ON':
+    config.unsupported = True
+
+# No JIT on win32.
+if sys.platform == 'win32':
+    config.unsupported = True
+
+if config.intel_sde_executable:
+    # Run test in emulator (Intel SDE): AMX needs Sapphire Rapids CPU.
+    config.substitutions.append(('%lli', config.intel_sde_executable + ' -spr -- lli'))
+else:
+    config.substitutions.append(('%lli', 'lli'))

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf.mlir
new file mode 100644
index 000000000000..73d866af972c
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm="enable-amx" -convert-std-to-llvm | \
+// RUN: mlir-translate -mlir-to-llvmir | \
+// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Note: To run this test, your CPU must support AMX.
+
+// Multiply into zeroed destination.
+// Loads the two bf16 source tiles, multiplies them into a freshly zeroed
+// f32 tile, and stores the result to %arg2.
+func @kernel1(%arg0: memref<2x4xbf16>,
+              %arg1: memref<2x4xbf16>,
+              %arg2: memref<2x2xf32>) {
+  %0 = constant 0 : index
+  %1 = amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
+  %3 = amx.tile_zero : vector<2x2xf32>
+  %4 = amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
+  amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32>
+  return
+}
+
+// Multiply and update into destination.
+// Like kernel1, but accumulates onto the current contents of %arg2.
+func @kernel2(%arg0: memref<2x4xbf16>,
+              %arg1: memref<2x4xbf16>,
+              %arg2: memref<2x2xf32>) {
+  %0 = constant 0 : index
+  %1 = amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<2x2xf32> into vector<2x2xf32>
+  %4 = amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
+  amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32>
+  return
+}
+
+// Driver: fills %a and %b with constants, runs kernel1 and then kernel2 on
+// the same buffers, and prints %c after each call.
+func @entry() {
+  %f0 = constant 0.0: f32
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c2 = constant 2: index
+
+  // Set up memory.
+  %a = alloc() : memref<2x4xbf16>
+  %b = alloc() : memref<2x4xbf16>
+  %c = alloc() : memref<2x2xf32>
+
+  %0 = std.constant dense<[[1.0, 2.0, 3.0, 4.0 ],
+                           [5.0, 6.0, 7.0, 8.0 ]]> : vector<2x4xbf16>
+  vector.transfer_write %0, %a[%c0, %c0] : vector<2x4xbf16>, memref<2x4xbf16>
+  %1 = std.constant dense<[[ 9.0, 10.0, 11.0, 12.0 ],
+                           [13.0, 14.0, 15.0, 16.0 ]]> : vector<2x4xbf16>
+  vector.transfer_write %1, %b[%c0, %c0] : vector<2x4xbf16>, memref<2x4xbf16>
+
+  // Call kernel.
+  call @kernel1(%a, %b, %c) : (memref<2x4xbf16>, memref<2x4xbf16>, memref<2x2xf32>) -> ()
+
+  // Print and verify.
+  // For example, 124 = 1*9 + 2*10 + 3*13 + 4*14.
+  //
+  // CHECK:      ( 124, 144 )
+  // CHECK-NEXT: ( 308, 360 )
+  scf.for %i = %c0 to %c2 step %c1 {
+    %av = vector.transfer_read %c[%i, %c0], %f0: memref<2x2xf32>, vector<2xf32>
+    vector.print %av : vector<2xf32>
+  }
+
+  // Call kernel.
+  call @kernel2(%a, %b, %c) : (memref<2x4xbf16>, memref<2x4xbf16>, memref<2x2xf32>) -> ()
+
+  // Print and verify.
+  // kernel2 adds the same product onto %c, so every value doubles.
+  //
+  // CHECK-NEXT: ( 248, 288 )
+  // CHECK-NEXT: ( 616, 720 )
+  //
+  scf.for %i = %c0 to %c2 step %c1 {
+    %cv = vector.transfer_read %c[%i, %c0], %f0: memref<2x2xf32>, vector<2xf32>
+    vector.print %cv : vector<2xf32>
+  }
+
+  // Release resources.
+  dealloc %a : memref<2x4xbf16>
+  dealloc %b : memref<2x4xbf16>
+  dealloc %c : memref<2x2xf32>
+
+  return
+}

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir
new file mode 100644
index 000000000000..59eff35d33cf
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm="enable-amx" -convert-std-to-llvm | \
+// RUN: mlir-translate -mlir-to-llvmir | \
+// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Note: To run this test, your CPU must support AMX.
+
+// Multiply into zeroed destination.
+// Loads the two i8 source tiles, multiplies them (both zero-extended) into a
+// freshly zeroed i32 tile, and stores the result to %arg2.
+func @kernel1(%arg0: memref<2x8xi8>,
+              %arg1: memref<2x8xi8>,
+              %arg2: memref<2x2xi32>) {
+  %0 = constant 0 : index
+  %1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
+  %3 = amx.tile_zero : vector<2x2xi32>
+  %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
+  amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
+  return
+}
+
+// Multiply and update into destination.
+// Like kernel1, but accumulates onto the current contents of %arg2.
+func @kernel2(%arg0: memref<2x8xi8>,
+              %arg1: memref<2x8xi8>,
+              %arg2: memref<2x2xi32>) {
+  %0 = constant 0 : index
+  %1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32>
+  %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
+  amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
+  return
+}
+
+// Driver: fills %a and %b with constants, runs kernel1 and then kernel2 on
+// the same buffers, and prints %c after each call.
+func @entry() {
+  %i0 = constant 0: i32
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c2 = constant 2: index
+
+  // Set up memory.
+  %a = alloc() : memref<2x8xi8>
+  %b = alloc() : memref<2x8xi8>
+  %c = alloc() : memref<2x2xi32>
+
+  %0 = std.constant dense<[[1 , 2,  3 , 4 , 5,  6,  7,  8],
+                           [9, 10, 11, 12, 13, 14, 15, 16]]> : vector<2x8xi8>
+  vector.transfer_write %0, %a[%c0, %c0] : vector<2x8xi8>, memref<2x8xi8>
+  %1 = std.constant dense<[[17, 18, 19, 20, 21, 22, 23, 24],
+                           [25, 26, 27, 28, 29, 30, 31, 32]]> : vector<2x8xi8>
+  vector.transfer_write %1, %b[%c0, %c0] : vector<2x8xi8>, memref<2x8xi8>
+
+  // Call kernel.
+  call @kernel1(%a, %b, %c) : (memref<2x8xi8>, memref<2x8xi8>, memref<2x2xi32>) -> ()
+
+  // Print and verify.
+  // For example, 884 = (1*17 + 2*18 + 3*19 + 4*20) + (5*25 + 6*26 + 7*27 + 8*28).
+  //
+  // CHECK:      ( 884, 1028 )
+  // CHECK-NEXT: ( 2324, 2724 )
+  scf.for %i = %c0 to %c2 step %c1 {
+    %av = vector.transfer_read %c[%i, %c0], %i0: memref<2x2xi32>, vector<2xi32>
+    vector.print %av : vector<2xi32>
+  }
+
+  // Call kernel.
+  call @kernel2(%a, %b, %c) : (memref<2x8xi8>, memref<2x8xi8>, memref<2x2xi32>) -> ()
+
+  // Print and verify.
+  // kernel2 adds the same product onto %c, so every value doubles.
+  //
+  // CHECK-NEXT: ( 1768, 2056 )
+  // CHECK-NEXT: ( 4648, 5448 )
+  //
+  scf.for %i = %c0 to %c2 step %c1 {
+    %cv = vector.transfer_read %c[%i, %c0], %i0: memref<2x2xi32>, vector<2xi32>
+    vector.print %cv : vector<2xi32>
+  }
+
+  // Release resources.
+  dealloc %a : memref<2x8xi8>
+  dealloc %b : memref<2x8xi8>
+  dealloc %c : memref<2x2xi32>
+
+  return
+}

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-tilezero.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-tilezero.mlir
new file mode 100644
index 000000000000..f49c66e4ce4b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-tilezero.mlir
@@ -0,0 +1,96 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm="enable-amx" -convert-std-to-llvm | \
+// RUN: mlir-translate -mlir-to-llvmir | \
+// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Note: To run this test, your CPU must support AMX.
+
+func @tilezero(%buf: memref<?x?xi32>, %i: index, %j: index) {  // Zero the 16x16 window of %buf whose top-left corner is (%i, %j).
+  %zeros = amx.tile_zero : vector<16x16xi32>
+  amx.tile_store %buf[%i, %j], %zeros : memref<?x?xi32>, vector<16x16xi32>
+  return  // nothing is returned; the effect is the store into %buf
+}
+
+func @entry() {  // Entry point for lli: zero two overlapping 16x16 windows and print the buffer after each.
+  %i0 = constant 0: i32  // padding value for vector.transfer_read
+  %i1 = constant 1: i32  // initial fill value of the buffer
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c3 = constant 3: index
+  %c19 = constant 19: index
+
+  // Set up memory: a 19x19 i32 buffer with every element set to 1.
+  %a = alloc(%c19, %c19) : memref<?x?xi32>
+  scf.for %i = %c0 to %c19 step %c1 {
+    scf.for %j = %c0 to %c19 step %c1 {
+      store %i1, %a[%i, %j] : memref<?x?xi32>
+    }
+  }
+
+  // Call kernel: zero the 16x16 window whose top-left corner is (1, 1).
+  call @tilezero(%a, %c1, %c1) : (memref<?x?xi32>, index, index) -> ()
+
+  // Print and verify that the zeroed window (rows 1-16, columns 1-16) is
+  // correctly strided within the enveloping 19x19 buffer of ones.
+  //
+  // CHECK:      ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+  // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+  //
+  scf.for %i = %c0 to %c19 step %c1 {
+    %av = vector.transfer_read %a[%i, %c0], %i0: memref<?x?xi32>, vector<19xi32>
+    vector.print %av : vector<19xi32>
+  }
+
+  // Call kernel with different indices: zero the 16x16 window at (0, 3).
+  call @tilezero(%a, %c0, %c3) : (memref<?x?xi32>, index, index) -> ()
+
+  // Print and verify the second window (rows 0-15, columns 3-18); zeros left
+  // by the first call (e.g. row 16, columns 1-16) remain visible as well.
+  //
+  // CHECK-NEXT: ( 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+  // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 )
+  // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+  // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+  //
+  scf.for %i = %c0 to %c19 step %c1 {
+    %av = vector.transfer_read %a[%i, %c0], %i0: memref<?x?xi32>, vector<19xi32>
+    vector.print %av : vector<19xi32>
+  }
+
+  // Release resources.
+  dealloc %a : memref<?x?xi32>
+
+  return
+}

diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir
new file mode 100644
index 000000000000..d1f3cd6ce30a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/amx.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @target(i8* %0)
+// CHECK: %[[c:.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 16)
+// CHECK: call void @llvm.x86.tilestored64.internal(i16 16, i16 16, i8* %0, i64 32, x86_amx %[[c]]
+llvm.func @target(%base: !llvm.ptr<i8>) {  // amx.tilezero / amx.tilestored64 must translate to the intrinsic calls checked above.
+  %dim = llvm.mlir.constant(16 : i16) : i16  // used for both i16 shape operands
+  %stride = llvm.mlir.constant(32 : i64) : i64
+  %tile = "amx.tilezero"(%dim, %dim) : (i16, i16) -> !llvm.array<16 x vector<16xbf16>>
+  "amx.tilestored64"(%dim, %dim, %base, %stride, %tile) : (i16, i16, !llvm.ptr<i8>, i64, !llvm.array<16 x vector<16xbf16>>) -> ()
+  llvm.return
+}
+

diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in
index 1d98d94d4b8f..0015c1369d7a 100644
--- a/mlir/test/lit.site.cfg.py.in
+++ b/mlir/test/lit.site.cfg.py.in
@@ -48,6 +48,7 @@ config.enable_vulkan_runner = @MLIR_VULKAN_RUNNER_ENABLED@
 config.enable_bindings_python = @MLIR_BINDINGS_PYTHON_ENABLED@
 config.mlir_integration_test_dir = "@MLIR_INTEGRATION_TEST_DIR@"
 config.intel_sde_executable = "@INTEL_SDE_EXECUTABLE@"
+config.mlir_run_amx_tests = "@MLIR_RUN_AMX_TESTS@"
 config.mlir_run_avx512_tests = "@MLIR_RUN_AVX512_TESTS@"
 config.mlir_include_integration_tests = "@MLIR_INCLUDE_INTEGRATION_TESTS@"
 

diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
index 508a22063339..a4d8835d7da0 100644
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -2,6 +2,7 @@
 // CHECK: Available Dialects:
 // CHECK-NEXT: acc
 // CHECK-NEXT: affine
+// CHECK-NEXT: amx
 // CHECK-NEXT: arm_neon
 // CHECK-NEXT: arm_sve
 // CHECK-NEXT: async


        


More information about the Mlir-commits mailing list