[Mlir-commits] [mlir] [mlir][x86] Move AMX dialect into X86 dialect (PR #183717)

Adam Siemieniuk llvmlistbot at llvm.org
Fri Feb 27 02:14:31 PST 2026


https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/183717

Unifies the two dialects that define x86 operations into a single one. The AMX dialect is moved into X86 in line with other x86 extensions.

The two dialects are simply merged together. X86 dialect refactoring will be addressed separately.

List of changes:
  - operations: 'amx.tile_*' => 'x86.amx.tile_*'
  - types: '!amx.tile' => '!x86.amx.tile'
  - namespace: 'mlir::amx' => 'mlir::x86::amx'
  - test define: 'MLIR_RUN_AMX_TESTS' => 'MLIR_RUN_X86_AMX_TESTS'
  - vector lowering: the separate 'enable-amx' option of 'convert-vector-to-llvm' is removed; AMX is enabled by default together with X86

The MLIR AMX tests are now nested under the X86 directory. To enable the AMX integration tests, 'MLIR_RUN_X86_TESTS' must be defined in addition to 'MLIR_RUN_X86_AMX_TESTS'.

>From f715caec449b655fba02736659b926dc4b48cf30 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 26 Feb 2026 17:30:42 +0100
Subject: [PATCH] [mlir][x86] Move AMX dialect into X86 dialect

Unifies the two dialects that define x86 operations into a single one.
The AMX dialect is moved into x86 in line with other x86 extensions.

The two dialects are simply merged together. X86 dialect refactoring
will be addressed separately.

List of changes:
  - operations: 'amx.tile_*' => 'x86.amx.tile_*'
  - types: '!amx.tile' => '!x86.amx.tile'
  - namespace: 'mlir::amx' => 'mlir::x86::amx'
  - test define: 'MLIR_RUN_AMX_TESTS' => 'MLIR_RUN_X86_AMX_TESTS'
  - vector lowering: AMX is enabled by default together with X86

The MLIR AMX tests are now nested under X86 directory. To enable AMX
integration tests, 'MLIR_RUN_X86_TESTS' must also be defined.
---
 mlir/Maintainers.md                           |   1 -
 mlir/docs/TargetLLVMIR.md                     |   4 +-
 mlir/include/mlir-c/Dialect/AMX.h             |  25 -
 mlir/include/mlir/Conversion/Passes.td        |  14 +-
 .../mlir/Conversion/VectorToAMX/VectorToAMX.h |   4 +-
 mlir/include/mlir/Dialect/AMX/AMX.td          | 440 ------------------
 mlir/include/mlir/Dialect/AMX/AMXDialect.h    |  34 --
 .../include/mlir/Dialect/AMX/AMXInterfaces.td |  31 --
 mlir/include/mlir/Dialect/AMX/CMakeLists.txt  |   5 -
 mlir/include/mlir/Dialect/AMX/Transforms.h    |  33 --
 mlir/include/mlir/Dialect/CMakeLists.txt      |   1 -
 .../Dialect/SparseTensor/Pipelines/Passes.h   |   5 -
 mlir/include/mlir/Dialect/X86/Transforms.h    |   7 +-
 mlir/include/mlir/Dialect/X86/X86.td          | 384 +++++++++++++++
 mlir/include/mlir/Dialect/X86/X86Dialect.h    |  13 +
 mlir/lib/CAPI/Dialect/AMX.cpp                 |  13 -
 mlir/lib/CAPI/Dialect/CMakeLists.txt          |   9 -
 .../lib/Conversion/VectorToAMX/CMakeLists.txt |   2 +-
 .../Conversion/VectorToAMX/VectorToAMX.cpp    |  62 +--
 .../Conversion/VectorToLLVM/CMakeLists.txt    |   2 -
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   8 -
 mlir/lib/Dialect/AMX/CMakeLists.txt           |   2 -
 mlir/lib/Dialect/AMX/IR/AMXDialect.cpp        | 318 -------------
 mlir/lib/Dialect/AMX/IR/CMakeLists.txt        |  15 -
 .../lib/Dialect/AMX/Transforms/CMakeLists.txt |   9 -
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  |  70 ---
 mlir/lib/Dialect/CMakeLists.txt               |   1 -
 mlir/lib/Dialect/X86/IR/X86Dialect.cpp        | 285 ++++++++++++
 .../X86/Transforms/LegalizeForLLVMExport.cpp  |  25 +-
 mlir/lib/RegisterAllDialects.cpp              |   2 -
 mlir/lib/RegisterAllExtensions.cpp            |   4 +-
 mlir/test/CMakeLists.txt                      |   4 +-
 .../VectorToAMX/contract-to-amx.mlir          |  28 +-
 .../VectorToAMX/transfer-to-amx.mlir          |  18 +-
 .../pass-option-serialization.mlir            |   1 -
 mlir/test/Dialect/AMX/invalid.mlir            | 158 -------
 mlir/test/Dialect/AMX/roundtrip.mlir          |  77 ---
 mlir/test/Dialect/AMX/side-effects.mlir       |  32 --
 mlir/test/Dialect/Linalg/invalid.mlir         |  18 +-
 mlir/test/Dialect/X86/AMX/invalid.mlir        | 158 +++++++
 .../{ => X86}/AMX/legalize-for-llvm.mlir      |  68 +--
 mlir/test/Dialect/X86/AMX/roundtrip.mlir      |  77 +++
 mlir/test/Dialect/X86/AMX/side-effects.mlir   |  32 ++
 .../Vector/CPU/{ => X86}/AMX/lit.local.cfg    |   2 +-
 .../Vector/CPU/{ => X86}/AMX/mulf-full.mlir   |  12 +-
 .../Vector/CPU/{ => X86}/AMX/mulf.mlir        |  22 +-
 .../Vector/CPU/{ => X86}/AMX/muli-ext.mlir    |  42 +-
 .../Vector/CPU/{ => X86}/AMX/muli-full.mlir   |  12 +-
 .../Vector/CPU/{ => X86}/AMX/muli.mlir        |  22 +-
 .../CPU/{ => X86}/AMX/tilezero-block.mlir     |   6 +-
 .../Vector/CPU/{ => X86}/AMX/tilezero.mlir    |   6 +-
 .../Dialect/Vector/CPU/X86/dot.mlir           |   7 +-
 .../Vector/CPU/X86/sparse-dot-product.mlir    |  24 +-
 mlir/test/Target/LLVMIR/amx.mlir              |  84 ++--
 mlir/test/lit.site.cfg.py.in                  |   2 +-
 mlir/test/mlir-opt/commandline.mlir           |   1 -
 56 files changed, 1210 insertions(+), 1531 deletions(-)
 delete mode 100644 mlir/include/mlir-c/Dialect/AMX.h
 delete mode 100644 mlir/include/mlir/Dialect/AMX/AMX.td
 delete mode 100644 mlir/include/mlir/Dialect/AMX/AMXDialect.h
 delete mode 100644 mlir/include/mlir/Dialect/AMX/AMXInterfaces.td
 delete mode 100644 mlir/include/mlir/Dialect/AMX/CMakeLists.txt
 delete mode 100644 mlir/include/mlir/Dialect/AMX/Transforms.h
 delete mode 100644 mlir/lib/CAPI/Dialect/AMX.cpp
 delete mode 100644 mlir/lib/Dialect/AMX/CMakeLists.txt
 delete mode 100644 mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
 delete mode 100644 mlir/lib/Dialect/AMX/IR/CMakeLists.txt
 delete mode 100644 mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
 delete mode 100644 mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
 delete mode 100644 mlir/test/Dialect/AMX/invalid.mlir
 delete mode 100644 mlir/test/Dialect/AMX/roundtrip.mlir
 delete mode 100644 mlir/test/Dialect/AMX/side-effects.mlir
 create mode 100644 mlir/test/Dialect/X86/AMX/invalid.mlir
 rename mlir/test/Dialect/{ => X86}/AMX/legalize-for-llvm.mlir (64%)
 create mode 100644 mlir/test/Dialect/X86/AMX/roundtrip.mlir
 create mode 100644 mlir/test/Dialect/X86/AMX/side-effects.mlir
 rename mlir/test/Integration/Dialect/Vector/CPU/{ => X86}/AMX/lit.local.cfg (91%)
 rename mlir/test/Integration/Dialect/Vector/CPU/{ => X86}/AMX/mulf-full.mlir (95%)
 rename mlir/test/Integration/Dialect/Vector/CPU/{ => X86}/AMX/mulf.mlir (74%)
 rename mlir/test/Integration/Dialect/Vector/CPU/{ => X86}/AMX/muli-ext.mlir (83%)
 rename mlir/test/Integration/Dialect/Vector/CPU/{ => X86}/AMX/muli-full.mlir (95%)
 rename mlir/test/Integration/Dialect/Vector/CPU/{ => X86}/AMX/muli.mlir (74%)
 rename mlir/test/Integration/Dialect/Vector/CPU/{ => X86}/AMX/tilezero-block.mlir (94%)
 rename mlir/test/Integration/Dialect/Vector/CPU/{ => X86}/AMX/tilezero.mlir (96%)

diff --git a/mlir/Maintainers.md b/mlir/Maintainers.md
index a023ee0ea1bba..181541a0f3a93 100644
--- a/mlir/Maintainers.md
+++ b/mlir/Maintainers.md
@@ -104,7 +104,6 @@ available, should be contacted first, as they're more active in those areas.
 * ‘arm_neon’ Dialect ([@banach-space](https://github.com/banach-space))
 * ‘arm_sve’ Dialect ([@banach-space](https://github.com/banach-space))
 * ‘ArmSME’ Dialect ([@banach-space](https://github.com/banach-space))
-* ‘amx’ Dialect ([@adam-smnk](https://github.com/adam-smnk))
 * ‘x86’ Dialect ([@adam-smnk](https://github.com/adam-smnk))
 * ‘vcix’ Dialect ([@mshockwave](https://github.com/mshockwave))
 
diff --git a/mlir/docs/TargetLLVMIR.md b/mlir/docs/TargetLLVMIR.md
index 2bdf400a7759f..2bcacbb4ee946 100644
--- a/mlir/docs/TargetLLVMIR.md
+++ b/mlir/docs/TargetLLVMIR.md
@@ -5,8 +5,8 @@ overall flow is two-stage:
 
 1.  **conversion** of the IR to a set of dialects translatable to LLVM IR, for
     example [LLVM Dialect](Dialects/LLVM.md) or one of the hardware-specific
-    dialects derived from LLVM IR intrinsics such as [AMX](Dialects/AMX.md),
-    [X86](Dialects/X86.md) or [ArmNeon](Dialects/ArmNeon.md);
+    dialects derived from LLVM IR intrinsics such as [X86](Dialects/X86.md)
+    or [ArmNeon](Dialects/ArmNeon.md);
 2.  **translation** of MLIR dialects to LLVM IR.
 
 This flow allows the non-trivial transformation to be performed within MLIR
diff --git a/mlir/include/mlir-c/Dialect/AMX.h b/mlir/include/mlir-c/Dialect/AMX.h
deleted file mode 100644
index ac4695a107ae6..0000000000000
--- a/mlir/include/mlir-c/Dialect/AMX.h
+++ /dev/null
@@ -1,25 +0,0 @@
-//===-- mlir-c/Dialect/AMX.h - C API for AMX Dialect --------*- C -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM
-// Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_C_DIALECT_AMX_H
-#define MLIR_C_DIALECT_AMX_H
-
-#include "mlir-c/IR.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(AMX, amx);
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif // MLIR_C_DIALECT_AMX_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index ecc22abb0f935..e77860897399f 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1521,8 +1521,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
     operations. The lowering pass provides several options to control
     the kinds of optimizations that are allowed. It also provides options
     that enable the use of one or more architectural-specific dialects
-    (AMX, X86, ArmNeon, ArmSVE, etc.) in combination with the
-    architectural-neutral vector dialect lowering.
+    (X86, ArmNeon, ArmSVE, etc.) in combination with the architectural-neutral
+    vector dialect lowering.
 
   }];
   // Override explicitly in C++ to allow conditional dialect dependence.
@@ -1544,10 +1544,6 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "vector access are naturally aligned. If operations have an "
            "alignment attribute set, the alignment attribute takes priority "
            "over this option ">,
-    Option<"amx", "enable-amx",
-           "bool", /*default=*/"false",
-           "Enables the use of AMX dialect while lowering the vector "
-	   "dialect.">,
     Option<"armNeon", "enable-arm-neon",
            "bool", /*default=*/"false",
            "Enables the use of ArmNeon dialect while lowering the vector "
@@ -1626,10 +1622,10 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
 //===----------------------------------------------------------------------===//
 
 def ConvertVectorToAMX : Pass<"convert-vector-to-amx"> {
-  let summary = "Lower the operations from the vector dialect into the AMX "
-                "dialect";
+  let summary = "Lower the operations from the vector dialect into the X86 "
+                "dialect AMX operations";
   let dependentDialects = [
-    "affine::AffineDialect", "amx::AMXDialect", "arith::ArithDialect",
+    "affine::AffineDialect", "x86::X86Dialect", "arith::ArithDialect",
     "memref::MemRefDialect", "scf::SCFDialect", "vector::VectorDialect"
   ];
 }
diff --git a/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
index b075ac92990a2..6b178e02684c0 100644
--- a/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
+++ b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
@@ -1,4 +1,4 @@
-//===- VectorToAMX.h - Convert vector to AMX dialect ------------*- C++ -*-===//
+//===- VectorToAMX.h - Convert vector to X86 dialect AMX ops ----*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -18,7 +18,7 @@ class RewritePatternSet;
 #define GEN_PASS_DECL_CONVERTVECTORTOAMX
 #include "mlir/Conversion/Passes.h.inc"
 
-/// Collect a set of patterns to convert from the vector to AMX ops.
+/// Collect a set of patterns to convert from the vector to X86 AMX ops.
 void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns);
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
deleted file mode 100644
index cace63d32fd80..0000000000000
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ /dev/null
@@ -1,440 +0,0 @@
-//===-- AMX.td - AMX dialect operation definitions *- tablegen -*----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the basic operations for the AMX dialect.
-//
-// The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
-// multiply unit (TMUL), a tile control register (TILECFG), and eight
-// tile registers TMM0 through TMM7 (TILEDATA).
-//
-// The AMX dialect provides a bridge between MLIR concepts, such as
-// 2-d vector, operations, and memrefs, and the lower level details
-// of Intel AMX, such as configuration setup, tile sizes, instructions,
-// and tile release.
-//
-// Note that since configuration changes (implicit at dialect level) are
-// costly, it is highly recommended to use the AMX dialect on same-shaped
-// vectors, at least within a single method.
-//
-// https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef AMX
-#define AMX
-
-include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
-include "mlir/Dialect/AMX/AMXInterfaces.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/IR/AttrTypeBase.td"
-include "mlir/IR/BuiltinTypes.td"
-
-//===----------------------------------------------------------------------===//
-// AMX dialect definition.
-//===----------------------------------------------------------------------===//
-
-def AMX_Dialect : Dialect {
-  let name = "amx";
-  let cppNamespace = "::mlir::amx";
-  let description = [{
-    The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
-    multiply unit (TMUL), a tile control register (TILECFG), and eight
-    tile registers TMM0 through TMM7 (TILEDATA).
-
-    This `AMX` dialect provides a bridge between MLIR concepts such as
-    vectors and memrefs and the lower level LLVM IR support of AMX.
-
-    Note that since configuration changes (implicit at dialect level) are
-    costly, it is highly recommended to use the AMX dialect on same-shaped
-    vectors, at least within a single method.
-
-    For details, see the Intel documentation:
-    https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
-  }];
-  let useDefaultTypePrinterParser = 1;
-}
-
-//===----------------------------------------------------------------------===//
-// AMX Tile definition.
-//===----------------------------------------------------------------------===//
-
-class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
-    : TypeDef<AMX_Dialect, typeName, traits> {
-  let mnemonic = typeMnemonic;
-}
-
-def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
-  let cppFunctionName = "isValidTileTypeElementType";
-}
-
-def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
-  let summary = "AMX 2D tile to be used by AMX opertaions.";
-
-  let description = [{
-    This type is used to represent values in AMX tile registers. All AMX operations
-    work on AMX tiles and these tiles cannot be used in other operations directly.
-    LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
-    element type for IR verification and lowering to LLVMIR dialect.
-  }];
-
-  let parameters = (ins
-    ArrayRefParameter<"int64_t">:$shape,
-    AMX_TileTypeElementType:$elementType
-  );
-
-  let builders = [
-    TypeBuilderWithInferredContext<(ins
-      "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
-      return $_get(elementType.getContext(), shape, elementType);
-    }]>
-  ];
-
-  let extraClassDeclaration = [{
-    /// Returns if this type is ranked (always true).
-    bool hasRank() const { return true; }
-
-    /// Clone this tile type with the given shape and element type. If the
-    /// provided shape is `std::nullopt`, the current shape of the type is used.
-    TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
-                       Type elementType) const {
-      return get(shape.value_or(getShape()), elementType);
-    }
-  }];
-
-  let hasCustomAssemblyFormat = 1;
-  let skipDefaultBuilders = 1;
-}
-
-def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
-  CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
-
-class AMXTileOf<list<Type> allowedTypes> :
-  ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
-                      "::mlir::amx::TileType">;
-
-def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
-
-def AMXTileF32 : AMXTileOf<[F32]>;
-
-def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
-
-def AMXTileI32 : AMXTileOf<[I32]>;
-
-def AMXTileI8 : AMXTileOf<[I8]>;
-
-//===----------------------------------------------------------------------===//
-// AMX Op and IntrOp definitions.
-//===----------------------------------------------------------------------===//
-
-class AMX_Op<string mnemonic, list<Trait> traits = []> :
-  Op<AMX_Dialect, mnemonic, traits> {}
-
-//===----------------------------------------------------------------------===//
-// AMX Op definitions
-//===----------------------------------------------------------------------===//
-
-//
-// Tile reset.
-//
-
-def TileZeroOp : AMX_Op<"tile_zero", [
-    AMXIntrinsicOpInterface,
-    MemoryEffects<[MemWrite]>
-  ]> {
-  let summary = "tile zero operation";
-  let description = [{
-    Zeroes the destination tile, with the shape defined by the 2-dim
-    vector type of the result.
-    
-    The operation is eventually lowered into the "tilezero" instruction
-    with the corresponding tile configuration.
-    
-    With the write memory effect, each `amx.tile_zero` operation serves as
-    a compilation hint to use a separate tile register.
-
-    Example:
-
-    ```mlir
-      %0 = amx.tile_zero : !amx.tile<16x16xbf16>
-    ```
-  }];
-  let results = (outs AnyAMXTile:$res);
-  let extraClassDeclaration = [{
-    TileType getTileType() {
-      return ::llvm::cast<TileType>(getRes().getType());
-    }
-
-    std::string getIntrinsicName() {
-      return "llvm.x86.tilezero.internal";
-    }
-    SmallVector<Value> getIntrinsicOperands(
-        ::mlir::ArrayRef<Value> operands,
-        const ::mlir::LLVMTypeConverter &typeConverter,
-        ::mlir::RewriterBase &rewriter);
-  }];
-  let assemblyFormat = "attr-dict `:` qualified(type($res))";
-  let hasVerifier = 1;
-}
-
-//
-// Tile memory operations.
-//
-
-def TileLoadOp : AMX_Op<"tile_load", [
-    AMXIntrinsicOpInterface,
-    MemoryEffects<[MemWrite]>,
-    AttrSizedOperandSegments
-  ]> {
-  let summary = "tile load operation";
-  let description = [{
-    Loads a tile from memory defined by a `base` and `indices`, with the
-    shape defined by the 2-dim vector type of the result.
-    The tile's rows are populated by reading contiguous elements starting
-    at the `base`. For each tile row, the `base` is incremented by `stride`
-    number of elements.
-
-    The tile is loaded using the following indexing scheme:
-
-    ```
-    for row in enumerate(tile_rows):
-      mem_row = base[i0, i1, ..., iN + row * stride]
-      for col in enumerate(tile_cols):
-        tile[row, col] = mem_row[col]
-    ```
-
-    If the `stride` is not provided, then the `base` buffer must be at least
-    2-dimensional, and the `stride` is automatically inferred and corresponds
-    to the stride of the buffer's second innermost dimension.
-
-    The operation is eventually lowered into the "tileloadd" instruction
-    with the corresponding tile configuration.
-
-    With the write memory effect, each `amx.tile_load` operation serves as
-    a compilation hint to use a separate tile register.
-
-    Example:
-
-    ```mlir
-      // Tile load from a 2-D memref with implicit stride.
-      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
-
-      // Tile load from a 1-D memref with explicit stride.
-      %0 = amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !amx.tile<16x64xi8>
-    ```
-  }];
-  let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
-                   Variadic<Index>:$indices,
-                   Optional<Index>:$stride);
-  let results = (outs AnyAMXTile:$res);
-  let builders = [
-    OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)>
-  ];
-  let extraClassDeclaration = [{
-    MemRefType getMemRefType() {
-      return ::llvm::cast<MemRefType>(getBase().getType());
-    }
-    TileType getTileType() {
-      return ::llvm::cast<TileType>(getRes().getType());
-    }
-
-    std::string getIntrinsicName() {
-      return "llvm.x86.tileloadd64.internal";
-    }
-    SmallVector<Value> getIntrinsicOperands(
-        ::mlir::ArrayRef<Value> operands,
-        const ::mlir::LLVMTypeConverter &typeConverter,
-        ::mlir::RewriterBase &rewriter);
-  }];
-  let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict"
-                       "`:` type($base) `into` qualified(type($res))";
-  let hasVerifier = 1;
-}
-
-def TileStoreOp : AMX_Op<"tile_store", [
-    AMXIntrinsicOpInterface,
-    AttrSizedOperandSegments
-  ]> {
-  let summary = "tile store operation";
-  let description = [{
-    Stores a tile to memory defined by a `base` and `indices`, with the
-    shape defined by the 2-dim vector type of the value.
-    The tile's rows are written contiguously to the buffer starting at
-    the `base`. For each tile row, the `base` is incremented by `stride`
-    number of elements.
-
-    The tile is stored using the following indexing scheme:
-
-    ```
-    for row in enumerate(tile_rows):
-      mem_row = base[i0, i1, ..., iN + row * stride]
-      for col in enumerate(tile_cols):
-        mem_row[col] = tile[row, col]
-    ```
-
-    If the `stride` is not provided, then the `base` buffer must be at least
-    2-dimensional, and the `stride` is automatically inferred and corresponds
-    to the stride of the buffer's second innermost dimension.
-
-    The operation is eventually lowered into the "tilestored" instruction
-    with the corresponding tile configuration.
-
-    Example:
-
-    ```mlir
-      // Tile store to a 2-D memref with implicit stride.
-      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
-
-      // Tile store to a 1-D memref with explicit stride.
-      amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !amx.tile<16x64xi8>
-    ```
-  }];
-  let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
-                   Variadic<Index>:$indices,
-                   AnyAMXTile:$val,
-                   Optional<Index>:$stride);
-  let builders = [
-    OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)>
-  ];
-  let extraClassDeclaration = [{
-    MemRefType getMemRefType() {
-      return ::llvm::cast<MemRefType>(getBase().getType());
-    }
-    TileType getTileType() {
-      return ::llvm::cast<TileType>(getVal().getType());
-    }
-
-    std::string getIntrinsicName() {
-      return "llvm.x86.tilestored64.internal";
-    }
-    SmallVector<Value> getIntrinsicOperands(
-        ::mlir::ArrayRef<Value> operands,
-        const ::mlir::LLVMTypeConverter &typeConverter,
-        ::mlir::RewriterBase &rewriter);
-  }];
-  let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?"
-                       "attr-dict `:` type($base) `,` qualified(type($val))";
-  let hasVerifier = 1;
-}
-
-//
-// Tile arithmetic operations.
-//
-
-def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
-    AMXIntrinsicOpInterface,
-    AllTypesMatch<["acc", "res"]>
-  ]> {
-  let summary = "tile multiplication operation (floating-point)";
-  let description = [{
-    Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
-    into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
-    pairs of "bf16").
-    
-    The operation is eventually lowered into the "tdpbf16ps" instruction with
-    the corresponding tile configuration.
-
-    Example:
-
-    ```mlir
-      %0 = amx.tile_mulf %a, %b, %c
-        : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
-    ```
-  }];
-  let arguments = (ins AMXTileF16OrBF16:$lhs,
-                       AMXTileF16OrBF16:$rhs,
-                       AMXTileF32:$acc);
-  let results = (outs AMXTileF32:$res);
-  let extraClassDeclaration = [{
-    TileType getLhsTileType() {
-      return ::llvm::cast<TileType>(getLhs().getType());
-    }
-    TileType getRhsTileType() {
-      return ::llvm::cast<TileType>(getRhs().getType());
-    }
-    TileType getTileType() {
-      return ::llvm::cast<TileType>(getRes().getType());
-    }
-
-    std::string getIntrinsicName() {
-      std::string intr = "llvm.x86.tdp";
-      auto elementType =
-        getLhsTileType().getElementType();
-      intr += elementType.isF16() ? "fp16" : "bf16";
-      intr += "ps.internal";
-      return intr;
-    }
-    SmallVector<Value> getIntrinsicOperands(
-        ::mlir::ArrayRef<Value> operands,
-        const ::mlir::LLVMTypeConverter &typeConverter,
-        ::mlir::RewriterBase &rewriter);
-  }];
-  let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
-                       "qualified(type($lhs)) `,` qualified(type($rhs))"
-                       " `,` qualified(type($acc)) ";
-  let hasVerifier = 1;
-}
-
-def TileMulIOp : AMX_Op<"tile_muli", [Pure,
-    AMXIntrinsicOpInterface,
-    AllTypesMatch<["acc", "res"]>
-  ]> {
-  let summary = "tile multiplication operation (integer)";
-  let description = [{
-    Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
-    into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
-    combinations (4 bytes packed into dwords in the columns of both the
-    source operand tiles; the zero or sign extension is specified with
-    the attributes and default to sign extended).
-    
-    The operation is eventually lowered into one of the "tdpbssd",
-    "tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding
-    tile configuration.
-
-    Example:
-
-    ```mlir
-      %0 = amx.tile_muli %a zext, %b zext, %c
-        : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
-    ```
-  }];
-  let arguments = (ins AMXTileI8:$lhs,
-                       AMXTileI8:$rhs,
-                       AMXTileI32:$acc,
-                       UnitAttr:$isZextLhs,
-                       UnitAttr:$isZextRhs
-                       );
-  let results = (outs AMXTileI32:$res);
-  let extraClassDeclaration = [{
-    TileType getLhsTileType() {
-      return ::llvm::cast<TileType>(getLhs().getType());
-    }
-    TileType getRhsTileType() {
-      return ::llvm::cast<TileType>(getRhs().getType());
-    }
-    TileType getTileType() {
-      return ::llvm::cast<TileType>(getRes().getType());
-    }
-
-    std::string getIntrinsicName() {
-      std::string intr = "llvm.x86.tdpb";
-      intr += getIsZextLhs() ? "u" : "s";
-      intr += getIsZextRhs() ? "u" : "s";
-      intr += "d.internal";
-      return intr;
-    }
-    SmallVector<Value> getIntrinsicOperands(
-        ::mlir::ArrayRef<Value> operands,
-        const ::mlir::LLVMTypeConverter &typeConverter,
-        ::mlir::RewriterBase &rewriter);
-  }];
-  let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
-                       "qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
-  let hasVerifier = 1;
-}
-
-#endif // AMX
diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
deleted file mode 100644
index c79f31d4c994a..0000000000000
--- a/mlir/include/mlir/Dialect/AMX/AMXDialect.h
+++ /dev/null
@@ -1,34 +0,0 @@
-//===- AMXDialect.h - MLIR Dialect for AMX ----------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file declares the Target dialect for AMX in MLIR.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_AMX_AMXDIALECT_H_
-#define MLIR_DIALECT_AMX_AMXDIALECT_H_
-
-#include "mlir/Bytecode/BytecodeOpInterface.h"
-#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
-
-/// Include the generated interface declarations.
-#include "mlir/Dialect/AMX/AMXInterfaces.h.inc"
-
-#include "mlir/Dialect/AMX/AMXDialect.h.inc"
-
-#define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/AMX/AMXTypes.h.inc"
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/AMX/AMX.h.inc"
-
-#endif // MLIR_DIALECT_AMX_AMXDIALECT_H_
diff --git a/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td b/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td
deleted file mode 100644
index 012d1ba7368f7..0000000000000
--- a/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td
+++ /dev/null
@@ -1,31 +0,0 @@
-//===- AMXInterfaces.td - AMX interfaces -------------------*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines interfaces for the AMX dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef AMX_INTERFACES
-#define AMX_INTERFACES
-
-include "mlir/IR/Interfaces.td"
-include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
-
-//===----------------------------------------------------------------------===//
-// AMX Intrinsic Interface
-//===----------------------------------------------------------------------===//
-
-def AMXIntrinsicOpInterface
-    : OpInterface<"AMXIntrinsicOp", [OneToOneIntrinsicOpInterface]> {
-  let description = [{
-    A wrapper interface for operations representing AMX LLVM intrinsics.
-  }];
-  let cppNamespace = "::mlir::amx";
-}
-
-#endif // AMX_INTERFACES
diff --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
deleted file mode 100644
index f875c78d240cc..0000000000000
--- a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-add_mlir_dialect(AMX amx)
-add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)
-
-add_mlir_interface(AMXInterfaces)
-add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen)
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
deleted file mode 100644
index 7391ec2ff6b14..0000000000000
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ /dev/null
@@ -1,33 +0,0 @@
-//===- Transforms.h - AMX Dialect Transformation Entrypoints ----*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_AMX_TRANSFORMS_H
-#define MLIR_DIALECT_AMX_TRANSFORMS_H
-
-namespace mlir {
-
-class LLVMConversionTarget;
-class LLVMTypeConverter;
-class RewritePatternSet;
-class DialectRegistry;
-
-/// Collect a set of patterns to lower AMX ops to ops that map to LLVM
-/// intrinsics.
-void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
-                                              RewritePatternSet &patterns);
-
-/// Configure the target to support lowering AMX ops to ops that map to LLVM
-/// intrinsics.
-void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
-
-/// Register LLVM conversion interface for AMX dialect.
-void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
-
-} // namespace mlir
-
-#endif // MLIR_DIALECT_AMX_TRANSFORMS_H
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index ae9a18046c101..d2505877e2dd0 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_subdirectory(Affine)
 add_subdirectory(AMDGPU)
-add_subdirectory(AMX)
 add_subdirectory(Arith)
 add_subdirectory(ArmNeon)
 add_subdirectory(ArmSME)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
index 6d1d630056627..2e76985e92e19 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -104,10 +104,6 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
       desc("Allows compiler to assume indices fit in 32-bit if that yields "
            "faster code"),
       init(true)};
-  PassOptions::Option<bool> amx{
-      *this, "enable-amx",
-      desc("Enables the use of AMX dialect while lowering the vector dialect"),
-      init(false)};
   PassOptions::Option<bool> armNeon{
       *this, "enable-arm-neon",
       desc("Enables the use of ArmNeon dialect while lowering the vector "
@@ -168,7 +164,6 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
     opts.force32BitVectorIndices = force32BitVectorIndices;
     opts.armNeon = armNeon;
     opts.armSVE = armSVE;
-    opts.amx = amx;
     opts.x86 = x86;
     return opts;
   }
diff --git a/mlir/include/mlir/Dialect/X86/Transforms.h b/mlir/include/mlir/Dialect/X86/Transforms.h
index 7ab3a0b0b5629..2862e83f06f79 100644
--- a/mlir/include/mlir/Dialect/X86/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86/Transforms.h
@@ -200,13 +200,16 @@ void populateSpecializedTransposeLoweringPatterns(
 
 /// Collect a set of patterns to lower X86 ops to ops that map to LLVM
 /// intrinsics.
-void populateX86LegalizeForLLVMExportPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+void populateX86LegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
+                                              RewritePatternSet &patterns);
 
 /// Configure the target to support lowering X86 ops to ops that map to
 /// LLVM intrinsics.
 void configureX86LegalizeForExportTarget(LLVMConversionTarget &target);
 
+/// Register LLVM conversion interface for X86 dialect.
+void registerConvertX86ToLLVMInterface(DialectRegistry &registry);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_X86_TRANSFORMS_H
diff --git a/mlir/include/mlir/Dialect/X86/X86.td b/mlir/include/mlir/Dialect/X86/X86.td
index 8b5973985a4b2..e8965d04c2145 100644
--- a/mlir/include/mlir/Dialect/X86/X86.td
+++ b/mlir/include/mlir/Dialect/X86/X86.td
@@ -17,6 +17,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Dialect/X86/X86Interfaces.td"
+include "mlir/IR/BuiltinTypes.td"
 
 //===----------------------------------------------------------------------===//
 // X86 dialect definition
@@ -25,6 +26,8 @@ include "mlir/Dialect/X86/X86Interfaces.td"
 def X86_Dialect : Dialect {
   let name = "x86";
   let cppNamespace = "::mlir::x86";
+
+  let useDefaultTypePrinterParser = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -673,4 +676,385 @@ def CvtPackedOddIndexedToF32Op
         ::mlir::RewriterBase &rewriter);
   }];
 }
+
+//===----------------------------------------------------------------------===//
+// AMX Tile definition
+//===----------------------------------------------------------------------===//
+
+class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<X86_Dialect, "AMX" # typeName, traits> {
+  let mnemonic = "amx." # typeMnemonic;
+}
+
+def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
+  let cppFunctionName = "isValidTileTypeElementType";
+}
+
+def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
+  let summary = "AMX 2D tile to be used by AMX operations.";
+
+  let description = [{
+    This type is used to represent values in AMX tile registers. All AMX operations
+    work on AMX tiles and these tiles cannot be used in other operations directly.
+    LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
+    element type for IR verification and lowering to LLVMIR dialect.
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$shape,
+    AMX_TileTypeElementType:$elementType
+  );
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
+      return $_get(elementType.getContext(), shape, elementType);
+    }]>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns if this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Clone this tile type with the given shape and element type. If the
+    /// provided shape is `std::nullopt`, the current shape of the type is used.
+    AMXTileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                       Type elementType) const {
+      return get(shape.value_or(getShape()), elementType);
+    }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let skipDefaultBuilders = 1;
+}
+
+def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::x86::AMXTileType>($_self)">,
+  CPred<[{::llvm::cast<::mlir::x86::AMXTileType>($_self).getRank() == 2}]>]>;
+
+class AMXTileOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
+                      "::mlir::x86::AMXTileType">;
+
+def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
+
+def AMXTileF32 : AMXTileOf<[F32]>;
+
+def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
+
+def AMXTileI32 : AMXTileOf<[I32]>;
+
+def AMXTileI8 : AMXTileOf<[I8]>;
+
+//===----------------------------------------------------------------------===//
+// AMX Op definitions
+//===----------------------------------------------------------------------===//
+
+class AMX_Op<string mnemonic, list<Trait> traits = []>
+    : Op<X86_Dialect, "amx." # mnemonic, traits> {
+  let cppNamespace = X86_Dialect.cppNamespace # "::amx";
+}
+
+//===----------------------------------------------------------------------===//
+// AMX Tile Zero
+//===----------------------------------------------------------------------===//
+
+def TileZeroOp : AMX_Op<"tile_zero", [
+    X86IntrinsicOpInterface,
+    MemoryEffects<[MemWrite]>
+  ]> {
+  let summary = "tile zero operation";
+  let description = [{
+    Zeroes the destination tile, with the shape defined by the 2-dim
+    tile type of the result.
+    
+    The operation is eventually lowered into the "tilezero" instruction
+    with the corresponding tile configuration.
+    
+    With the write memory effect, each `x86.amx.tile_zero` operation serves as
+    a compilation hint to use a separate tile register.
+
+    Example:
+
+    ```mlir
+      %0 = x86.amx.tile_zero : !x86.amx.tile<16x16xbf16>
+    ```
+  }];
+  let results = (outs AnyAMXTile:$res);
+  let extraClassDeclaration = [{
+    AMXTileType getTileType() {
+      return ::llvm::cast<AMXTileType>(getRes().getType());
+    }
+
+    std::string getIntrinsicName() {
+      return "llvm.x86.tilezero.internal";
+    }
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
+  }];
+  let assemblyFormat = "attr-dict `:` qualified(type($res))";
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// AMX Tile Load
+//===----------------------------------------------------------------------===//
+
+def TileLoadOp : AMX_Op<"tile_load", [
+    X86IntrinsicOpInterface,
+    MemoryEffects<[MemWrite]>,
+    AttrSizedOperandSegments
+  ]> {
+  let summary = "tile load operation";
+  let description = [{
+    Loads a tile from memory defined by a `base` and `indices`, with the
+    shape defined by the 2-dim tile type of the result.
+    The tile's rows are populated by reading contiguous elements starting
+    at the `base`. For each tile row, the `base` is incremented by `stride`
+    number of elements.
+
+    The tile is loaded using the following indexing scheme:
+
+    ```
+    for row in enumerate(tile_rows):
+      mem_row = base[i0, i1, ..., iN + row * stride]
+      for col in enumerate(tile_cols):
+        tile[row, col] = mem_row[col]
+    ```
+
+    If the `stride` is not provided, then the `base` buffer must be at least
+    2-dimensional, and the `stride` is automatically inferred and corresponds
+    to the stride of the buffer's second innermost dimension.
+
+    The operation is eventually lowered into the "tileloadd" instruction
+    with the corresponding tile configuration.
+
+    With the write memory effect, each `x86.amx.tile_load` operation serves as
+    a compilation hint to use a separate tile register.
+
+    Example:
+
+    ```mlir
+      // Tile load from a 2-D memref with implicit stride.
+      %0 = x86.amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
+
+      // Tile load from a 1-D memref with explicit stride.
+      %0 = x86.amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !x86.amx.tile<16x64xi8>
+    ```
+  }];
+  let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
+                   Variadic<Index>:$indices,
+                   Optional<Index>:$stride);
+  let results = (outs AnyAMXTile:$res);
+  let builders = [
+    OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)>
+  ];
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return ::llvm::cast<MemRefType>(getBase().getType());
+    }
+    AMXTileType getTileType() {
+      return ::llvm::cast<AMXTileType>(getRes().getType());
+    }
+
+    std::string getIntrinsicName() {
+      return "llvm.x86.tileloadd64.internal";
+    }
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
+  }];
+  let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict"
+                       "`:` type($base) `into` qualified(type($res))";
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// AMX Tile Store
+//===----------------------------------------------------------------------===//
+
+def TileStoreOp : AMX_Op<"tile_store", [
+    X86IntrinsicOpInterface,
+    AttrSizedOperandSegments
+  ]> {
+  let summary = "tile store operation";
+  let description = [{
+    Stores a tile to memory defined by a `base` and `indices`, with the
+    shape defined by the 2-dim tile type of the value.
+    The tile's rows are written contiguously to the buffer starting at
+    the `base`. For each tile row, the `base` is incremented by `stride`
+    number of elements.
+
+    The tile is stored using the following indexing scheme:
+
+    ```
+    for row in enumerate(tile_rows):
+      mem_row = base[i0, i1, ..., iN + row * stride]
+      for col in enumerate(tile_cols):
+        mem_row[col] = tile[row, col]
+    ```
+
+    If the `stride` is not provided, then the `base` buffer must be at least
+    2-dimensional, and the `stride` is automatically inferred and corresponds
+    to the stride of the buffer's second innermost dimension.
+
+    The operation is eventually lowered into the "tilestored" instruction
+    with the corresponding tile configuration.
+
+    Example:
+
+    ```mlir
+      // Tile store to a 2-D memref with implicit stride.
+      x86.amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !x86.amx.tile<16x64xi8>
+
+      // Tile store to a 1-D memref with explicit stride.
+      x86.amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !x86.amx.tile<16x64xi8>
+    ```
+  }];
+  let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
+                   Variadic<Index>:$indices,
+                   AnyAMXTile:$val,
+                   Optional<Index>:$stride);
+  let builders = [
+    OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)>
+  ];
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return ::llvm::cast<MemRefType>(getBase().getType());
+    }
+    AMXTileType getTileType() {
+      return ::llvm::cast<AMXTileType>(getVal().getType());
+    }
+
+    std::string getIntrinsicName() {
+      return "llvm.x86.tilestored64.internal";
+    }
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
+  }];
+  let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?"
+                       "attr-dict `:` type($base) `,` qualified(type($val))";
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// AMX Tile Multiply
+//===----------------------------------------------------------------------===//
+
+def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
+    X86IntrinsicOpInterface,
+    AllTypesMatch<["acc", "res"]>
+  ]> {
+  let summary = "tile multiplication operation (floating-point)";
+  let description = [{
+    Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
+    into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" and
+    "f32 <- f16 x f16" (with pairs of "bf16" or "f16" elements).
+    
+    The operation is eventually lowered into the "tdpbf16ps" or "tdpfp16ps"
+    instruction with the corresponding tile configuration.
+
+    Example:
+
+    ```mlir
+      %0 = x86.amx.tile_mulf %a, %b, %c
+        : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
+    ```
+  }];
+  let arguments = (ins AMXTileF16OrBF16:$lhs,
+                       AMXTileF16OrBF16:$rhs,
+                       AMXTileF32:$acc);
+  let results = (outs AMXTileF32:$res);
+  let extraClassDeclaration = [{
+    AMXTileType getLhsTileType() {
+      return ::llvm::cast<AMXTileType>(getLhs().getType());
+    }
+    AMXTileType getRhsTileType() {
+      return ::llvm::cast<AMXTileType>(getRhs().getType());
+    }
+    AMXTileType getTileType() {
+      return ::llvm::cast<AMXTileType>(getRes().getType());
+    }
+
+    std::string getIntrinsicName() {
+      std::string intr = "llvm.x86.tdp";
+      auto elementType =
+        getLhsTileType().getElementType();
+      intr += elementType.isF16() ? "fp16" : "bf16";
+      intr += "ps.internal";
+      return intr;
+    }
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
+  }];
+  let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
+                       "qualified(type($lhs)) `,` qualified(type($rhs))"
+                       " `,` qualified(type($acc)) ";
+  let hasVerifier = 1;
+}
+
+def TileMulIOp : AMX_Op<"tile_muli", [Pure,
+    X86IntrinsicOpInterface,
+    AllTypesMatch<["acc", "res"]>
+  ]> {
+  let summary = "tile multiplication operation (integer)";
+  let description = [{
+    Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
+    into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
+    combinations (4 bytes packed into dwords in the columns of both the
+    source operand tiles; zero or sign extension is selected with the
+    attributes and defaults to sign extension).
+    
+    The operation is eventually lowered into one of the "tdpbssd",
+    "tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding
+    tile configuration.
+
+    Example:
+
+    ```mlir
+      %0 = x86.amx.tile_muli %a zext, %b zext, %c
+        : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+    ```
+  }];
+  let arguments = (ins AMXTileI8:$lhs,
+                       AMXTileI8:$rhs,
+                       AMXTileI32:$acc,
+                       UnitAttr:$isZextLhs,
+                       UnitAttr:$isZextRhs
+                       );
+  let results = (outs AMXTileI32:$res);
+  let extraClassDeclaration = [{
+    AMXTileType getLhsTileType() {
+      return ::llvm::cast<AMXTileType>(getLhs().getType());
+    }
+    AMXTileType getRhsTileType() {
+      return ::llvm::cast<AMXTileType>(getRhs().getType());
+    }
+    AMXTileType getTileType() {
+      return ::llvm::cast<AMXTileType>(getRes().getType());
+    }
+
+    std::string getIntrinsicName() {
+      std::string intr = "llvm.x86.tdpb";
+      intr += getIsZextLhs() ? "u" : "s";
+      intr += getIsZextRhs() ? "u" : "s";
+      intr += "d.internal";
+      return intr;
+    }
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
+  }];
+  let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
+                       "qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
+  let hasVerifier = 1;
+}
+
 #endif // X86_OPS
diff --git a/mlir/include/mlir/Dialect/X86/X86Dialect.h b/mlir/include/mlir/Dialect/X86/X86Dialect.h
index dbce51e641158..6b1358b31e666 100644
--- a/mlir/include/mlir/Dialect/X86/X86Dialect.h
+++ b/mlir/include/mlir/Dialect/X86/X86Dialect.h
@@ -29,6 +29,19 @@
 
 #include "mlir/Dialect/X86/X86Dialect.h.inc"
 
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/X86/X86Types.h.inc"
+
+namespace mlir {
+namespace x86 {
+namespace amx {
+// Alias to allow access to AMX type through nested namespaces
+// analogously to AMX operations.
+using TileType = mlir::x86::AMXTileType;
+} // namespace amx
+} // namespace x86
+} // namespace mlir
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/X86/X86.h.inc"
 
diff --git a/mlir/lib/CAPI/Dialect/AMX.cpp b/mlir/lib/CAPI/Dialect/AMX.cpp
deleted file mode 100644
index ed208c9b4b725..0000000000000
--- a/mlir/lib/CAPI/Dialect/AMX.cpp
+++ /dev/null
@@ -1,13 +0,0 @@
-//===- AMX.cpp - C Interface for AMX dialect ------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir-c/Dialect/AMX.h"
-#include "mlir/CAPI/Registration.h"
-#include "mlir/Dialect/AMX/AMXDialect.h"
-
-MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(AMX, amx, mlir::amx::AMXDialect)
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt
index 46b83b3d4f79f..551f5a5a3df77 100644
--- a/mlir/lib/CAPI/Dialect/CMakeLists.txt
+++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -26,15 +26,6 @@ add_mlir_upstream_c_api_library(MLIRCAPIAMDGPU
   MLIRAMDGPUTransforms
 )
 
-add_mlir_upstream_c_api_library(MLIRCAPIAMX
-  AMX.cpp
-
-  PARTIAL_SOURCES_INTENDED
-  LINK_LIBS PUBLIC
-  MLIRCAPIIR
-  MLIRAMXDialect
-)
-
 add_mlir_upstream_c_api_library(MLIRCAPIArith
   Arith.cpp
   ArithPasses.cpp
diff --git a/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
index 2d4b2b6e9283c..2ed864c519cb6 100644
--- a/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
@@ -8,7 +8,6 @@ add_mlir_conversion_library(MLIRVectorToAMX
   MLIRConversionPassIncGen
 
   LINK_LIBS PUBLIC
-  MLIRAMXDialect
   MLIRAffineUtils
   MLIRArithDialect
   MLIRLinalgUtils
@@ -16,4 +15,5 @@ add_mlir_conversion_library(MLIRVectorToAMX
   MLIRSCFDialect
   MLIRTransforms
   MLIRVectorDialect
+  MLIRX86Dialect
   )
diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
index 245a3efe98ecc..bce67b3e4748b 100644
--- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -1,4 +1,4 @@
-//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===//
+//===- VectorToAMX.cpp - Convert vector to X86 dialect AMX ops --*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -8,7 +8,6 @@
 
 #include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
 
-#include "mlir/Dialect/AMX/AMXDialect.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -16,6 +15,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86/X86Dialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -197,7 +197,7 @@ static TypedValue<MemRefType> collapseLastDim(PatternRewriter &rewriter,
 static Operation *
 loadStoreFromTransfer(PatternRewriter &rewriter,
                       VectorTransferOpInterface xferOp, bool isPacked,
-                      TypedValue<amx::TileType> tileToStore = nullptr) {
+                      TypedValue<x86::amx::TileType> tileToStore = nullptr) {
   if (!xferOp || !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp))
     return nullptr;
   if (xferOp.hasOutOfBoundsDim() ||
@@ -267,18 +267,18 @@ loadStoreFromTransfer(PatternRewriter &rewriter,
     src = collapseLastDim(rewriter, src);
   int64_t rows = vecShape[0];
   int64_t cols = llvm::product_of(vecShape.drop_front());
-  auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
+  auto tileType = x86::amx::TileType::get({rows, cols}, vecTy.getElementType());
 
   Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
   SmallVector<Value> tileIndicides(src.getType().getRank(), zeroIndex);
 
   Operation *amxTileOp = nullptr;
   if (isa<vector::TransferReadOp>(xferOp)) {
-    amxTileOp =
-        amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndicides);
+    amxTileOp = x86::amx::TileLoadOp::create(rewriter, loc, tileType, src,
+                                             tileIndicides);
   } else if (isa<vector::TransferWriteOp>(xferOp)) {
-    amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndicides,
-                                         tileToStore);
+    amxTileOp = x86::amx::TileStoreOp::create(rewriter, loc, src, tileIndicides,
+                                              tileToStore);
   } else {
     llvm_unreachable("unsupported vector transfer op");
   }
@@ -289,10 +289,10 @@ loadStoreFromTransfer(PatternRewriter &rewriter,
 /// Attempt to create an AMX tile load operation equivalent to the given
 /// vector transfer `readOp`.
 /// Returns loaded AMX tile if successful.
-static FailureOr<TypedValue<amx::TileType>>
+static FailureOr<TypedValue<x86::amx::TileType>>
 loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp,
                  bool isPacked) {
-  amx::TileLoadOp loadOp = dyn_cast_if_present<amx::TileLoadOp>(
+  x86::amx::TileLoadOp loadOp = dyn_cast_if_present<x86::amx::TileLoadOp>(
       loadStoreFromTransfer(rewriter, readOp, isPacked));
   if (!loadOp)
     return failure();
@@ -301,16 +301,16 @@ loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp,
 
 /// Attempt to create an AMX tile store operation equivalent to the given
 /// vector transfer `writeOp`.
-static LogicalResult storeFromTransfer(PatternRewriter &rewriter,
-                                       vector::TransferWriteOp writeOp,
-                                       TypedValue<amx::TileType> tileToStore) {
+static LogicalResult
+storeFromTransfer(PatternRewriter &rewriter, vector::TransferWriteOp writeOp,
+                  TypedValue<x86::amx::TileType> tileToStore) {
   return success(loadStoreFromTransfer(rewriter, writeOp, /*isPacked=*/false,
                                        tileToStore));
 }
 
 /// Load vector values to an AMX tile.
-static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
-                                          TypedValue<VectorType> vec) {
+static TypedValue<x86::amx::TileType> loadTile(PatternRewriter &rewriter,
+                                               TypedValue<VectorType> vec) {
   Location loc = vec.getLoc();
 
   VectorType vecTy = vec.getType();
@@ -318,7 +318,7 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
 
   // Try to load tile directly from vector producer's buffer.
   auto readOp = vec.getDefiningOp<vector::TransferReadOp>();
-  FailureOr<TypedValue<amx::TileType>> tile =
+  FailureOr<TypedValue<x86::amx::TileType>> tile =
       loadFromTransfer(rewriter, readOp, isPacked);
   if (succeeded(tile))
     return *tile;
@@ -337,25 +337,25 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
   ArrayRef<int64_t> shape = vecTy.getShape();
   int64_t rows = shape[0];
   int64_t cols = llvm::product_of(shape.drop_front());
-  auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
+  auto tileType = x86::amx::TileType::get({rows, cols}, vecTy.getElementType());
 
-  return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
-                                 {zeroIndex, zeroIndex});
+  return x86::amx::TileLoadOp::create(rewriter, loc, tileType, buf,
+                                      {zeroIndex, zeroIndex});
 }
 
 /// Store an AMX tile in a vector.
 static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
-                                        TypedValue<amx::TileType> tile) {
+                                        TypedValue<x86::amx::TileType> tile) {
   Location loc = tile.getLoc();
 
   // Transfer the tile to a vector through an intermediate buffer.
-  amx::TileType tileTy = tile.getType();
+  x86::amx::TileType tileTy = tile.getType();
   Value buf = memref::AllocaOp::create(
       rewriter, loc,
       MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
   Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
   SmallVector<Value> indices(2, zeroIndex);
-  amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
+  x86::amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
 
   auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
   return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {});
@@ -374,19 +374,21 @@ struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
     if (failed(validateOperands(rewriter, contractOp)))
       return failure();
 
-    TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs());
-    TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs());
+    TypedValue<x86::amx::TileType> lhsTile =
+        loadTile(rewriter, contractOp.getLhs());
+    TypedValue<x86::amx::TileType> rhsTile =
+        loadTile(rewriter, contractOp.getRhs());
     auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
     assert(acc && "Invalid accumulator type");
-    TypedValue<amx::TileType> accTile = loadTile(rewriter, acc);
+    TypedValue<x86::amx::TileType> accTile = loadTile(rewriter, acc);
 
-    TypedValue<amx::TileType> tileMul;
+    TypedValue<x86::amx::TileType> tileMul;
     if (acc.getType().getElementType().isFloat()) {
-      tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
-                                        lhsTile, rhsTile, accTile);
+      tileMul = x86::amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
+                                             lhsTile, rhsTile, accTile);
     } else {
-      tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
-                                        lhsTile, rhsTile, accTile);
+      tileMul = x86::amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
+                                             lhsTile, rhsTile, accTile);
     }
 
     // If the contraction result is only written back to memory, try to replace
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 4b1e72788becd..0d700ea65eb4e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -38,8 +38,6 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
-  MLIRAMXDialect
-  MLIRAMXTransforms
   MLIRX86Dialect
   MLIRX86Transforms
 )
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 19c42ed7e9ed8..4cc5704353382 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -10,8 +10,6 @@
 
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Dialect/AMX/AMXDialect.h"
-#include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmNeon/Transforms.h"
@@ -51,8 +49,6 @@ struct ConvertVectorToLLVMPass
       registry.insert<arm_neon::ArmNeonDialect>();
     if (armSVE)
       registry.insert<arm_sve::ArmSVEDialect>();
-    if (amx)
-      registry.insert<amx::AMXDialect>();
     if (x86)
       registry.insert<x86::X86Dialect>();
   }
@@ -136,10 +132,6 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     configureArmSVELegalizeForExportTarget(target);
     populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
   }
-  if (amx) {
-    configureAMXLegalizeForExportTarget(target);
-    populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
-  }
   if (x86) {
     configureX86LegalizeForExportTarget(target);
     populateX86LegalizeForLLVMExportPatterns(converter, patterns);
diff --git a/mlir/lib/Dialect/AMX/CMakeLists.txt b/mlir/lib/Dialect/AMX/CMakeLists.txt
deleted file mode 100644
index 9f57627c321fb..0000000000000
--- a/mlir/lib/Dialect/AMX/CMakeLists.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-add_subdirectory(IR)
-add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
deleted file mode 100644
index d9c097c9a3c6f..0000000000000
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ /dev/null
@@ -1,318 +0,0 @@
-//===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements the AMX dialect and its operations.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/AMX/AMXDialect.h"
-#include "mlir/Conversion/LLVMCommon/Pattern.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/TypeUtilities.h"
-
-#include "llvm/ADT/TypeSwitch.h"
-
-using namespace mlir;
-
-#include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc"
-
-#include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
-
-void amx::AMXDialect::initialize() {
-  addTypes<
-#define GET_TYPEDEF_LIST
-#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
-      >();
-
-  addOperations<
-#define GET_OP_LIST
-#include "mlir/Dialect/AMX/AMX.cpp.inc"
-      >();
-}
-
-/// Verify that AMX supports the implied tile shape.
-static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
-  const unsigned kMaxRows = 16;
-  const unsigned kBitsPerRow = 64 * 8;
-  unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
-  if (tp.getDimSize(0) > kMaxRows)
-    return op->emitOpError("bad row height: ") << tp.getDimSize(0);
-  if (col > kBitsPerRow || col & 0x1f)
-    return op->emitOpError("bad column width: ") << (col >> 3);
-  return success();
-}
-
-/// Verify that AMX supports the multiplication.
-static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
-                                     amx::TileType btp, amx::TileType ctp,
-                                     unsigned scale) {
-  unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
-  unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
-  unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
-  if (cm != am || cn != bn || ak != bk)
-    return op->emitOpError("bad mult shape: ")
-           << cm << " x " << cn << " x " << ak;
-  return success();
-}
-
-/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
-/// dimension directly translates into the number of rows of the tiles.
-/// The second dimensions needs to be scaled by the number of bytes.
-static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
-                                       RewriterBase &rewriter) {
-  Type llvmInt16Type = rewriter.getIntegerType(16);
-  unsigned width = tType.getElementType().getIntOrFloatBitWidth();
-  assert(llvm::isPowerOf2_64(width) && width >= 8);
-  unsigned bytes = width >> 3;
-  auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
-  auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
-  return SmallVector<Value>{
-      LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
-      LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
-}
-
-/// Returns stride expressed in number of bytes for the given `elementStride`
-/// stride encoded in number of elements of the type `mType`.
-static Value computeStrideInBytes(Location loc, MemRefType mType,
-                                  Value elementStride, RewriterBase &rewriter) {
-  Type llvmInt64Type = rewriter.getIntegerType(64);
-  unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
-  auto attr = rewriter.getI64IntegerAttr(bytes);
-  Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
-  return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
-      .getResult();
-}
-
-/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
-/// shape may "envelop" the actual tile shape, and may be dynamically sized.
-static Value inferStride(Location loc, MemRefType mType, Value base,
-                         RewriterBase &rewriter) {
-  assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
-  int64_t preLast = mType.getRank() - 2;
-  Type llvmInt64Type = rewriter.getIntegerType(64);
-  unsigned width = mType.getElementType().getIntOrFloatBitWidth();
-  assert(llvm::isPowerOf2_64(width) && width >= 8);
-  unsigned bytes = width >> 3;
-  auto [strides, offset] = mType.getStridesAndOffset();
-  if (strides[preLast] == ShapedType::kDynamic) {
-    // Dynamic stride needs code to compute the stride at runtime.
-    MemRefDescriptor memrefDescriptor(base);
-    return computeStrideInBytes(
-        loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
-  }
-  // Use direct constant for static stride.
-  auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
-  return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
-      .getResult();
-}
-
-LogicalResult amx::TileZeroOp::verify() {
-  return verifyTileSize(*this, getTileType());
-}
-
-SmallVector<Value>
-amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
-                                      const LLVMTypeConverter &typeConverter,
-                                      RewriterBase &rewriter) {
-  return getTileSizes(getLoc(), getTileType(), rewriter);
-}
-
-template <typename OpTy,
-          typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> ||
-                                      std::is_same_v<OpTy, amx::TileStoreOp>>>
-static LogicalResult tileTransferVerifier(OpTy op) {
-  MemRefType memrefTy = op.getMemRefType();
-  unsigned rank = memrefTy.getRank();
-  if (op.getIndices().size() != rank)
-    return op.emitOpError("requires ") << rank << " indices";
-
-  if (failed(verifyTileSize(op, op.getTileType())))
-    return failure();
-
-  // Validate basic buffer properties when the stride is implicit.
-  if (!op.getStride()) {
-    if (rank < 2)
-      return op.emitOpError("requires at least 2D memref");
-    SmallVector<int64_t> strides;
-    int64_t offset;
-    if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
-        strides.back() != 1)
-      return op.emitOpError("requires memref with unit innermost stride");
-  }
-
-  return success();
-}
-
-void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res,
-                            Value base, ValueRange indices) {
-  build(builder, state, res, base, indices, /*stride=*/nullptr);
-}
-
-LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); }
-
-SmallVector<Value>
-amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
-                                      const LLVMTypeConverter &typeConverter,
-                                      RewriterBase &rewriter) {
-  auto loc = getLoc();
-  Adaptor adaptor(operands, *this);
-
-  SmallVector<Value> intrinsicOperands;
-  intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
-  intrinsicOperands.push_back(
-      LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
-                                 adaptor.getBase(), adaptor.getIndices()));
-  if (Value stride = adaptor.getStride())
-    intrinsicOperands.push_back(
-        computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
-  else
-    intrinsicOperands.push_back(
-        inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
-
-  return intrinsicOperands;
-}
-
-void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
-                             Value base, ValueRange indices, Value val) {
-  build(builder, state, base, indices, val, /*stride=*/nullptr);
-}
-
-LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); }
-
-SmallVector<Value>
-amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
-                                       const LLVMTypeConverter &typeConverter,
-                                       RewriterBase &rewriter) {
-  auto loc = getLoc();
-  Adaptor adaptor(operands, *this);
-
-  SmallVector<Value> intrinsicOperands;
-  intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
-  intrinsicOperands.push_back(
-      LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
-                                 adaptor.getBase(), adaptor.getIndices()));
-  if (Value stride = adaptor.getStride())
-    intrinsicOperands.push_back(
-        computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
-  else
-    intrinsicOperands.push_back(
-        inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
-  intrinsicOperands.push_back(adaptor.getVal());
-
-  return intrinsicOperands;
-}
-
-LogicalResult amx::TileMulFOp::verify() {
-  amx::TileType aType = getLhsTileType();
-  amx::TileType bType = getRhsTileType();
-  amx::TileType cType = getTileType();
-  if (failed(verifyTileSize(*this, aType)) ||
-      failed(verifyTileSize(*this, bType)) ||
-      failed(verifyTileSize(*this, cType)) ||
-      failed(verifyMultShape(*this, aType, bType, cType, 1)))
-    return failure();
-  Type ta = aType.getElementType();
-  Type tb = bType.getElementType();
-  Type tc = cType.getElementType();
-  if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
-    return emitOpError("unsupported type combination");
-  return success();
-}
-
-SmallVector<Value>
-amx::TileMulFOp::getIntrinsicOperands(ArrayRef<Value> operands,
-                                      const LLVMTypeConverter &typeConverter,
-                                      RewriterBase &rewriter) {
-  auto loc = getLoc();
-  Adaptor adaptor(operands, *this);
-
-  amx::TileType aType = getLhsTileType();
-  amx::TileType bType = getRhsTileType();
-  SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
-  SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
-
-  SmallVector<Value> intrinsicOperands = {tsza[0],          tszb[1],
-                                          tsza[1],          adaptor.getAcc(),
-                                          adaptor.getLhs(), adaptor.getRhs()};
-
-  return intrinsicOperands;
-}
-
-LogicalResult amx::TileMulIOp::verify() {
-  amx::TileType aType = getLhsTileType();
-  amx::TileType bType = getRhsTileType();
-  amx::TileType cType = getTileType();
-  if (failed(verifyTileSize(*this, aType)) ||
-      failed(verifyTileSize(*this, bType)) ||
-      failed(verifyTileSize(*this, cType)) ||
-      failed(verifyMultShape(*this, aType, bType, cType, 2)))
-    return failure();
-  Type ta = aType.getElementType();
-  Type tb = bType.getElementType();
-  Type tc = cType.getElementType();
-  if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
-    return emitOpError("unsupported type combination");
-  return success();
-}
-
-SmallVector<Value>
-amx::TileMulIOp::getIntrinsicOperands(ArrayRef<Value> operands,
-                                      const LLVMTypeConverter &typeConverter,
-                                      RewriterBase &rewriter) {
-  auto loc = getLoc();
-  Adaptor adaptor(operands, *this);
-
-  amx::TileType aType = getLhsTileType();
-  amx::TileType bType = getRhsTileType();
-  SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
-  SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
-
-  SmallVector<Value> intrinsicOperands = {tsza[0],          tszb[1],
-                                          tsza[1],          adaptor.getAcc(),
-                                          adaptor.getLhs(), adaptor.getRhs()};
-
-  return intrinsicOperands;
-}
-
-Type amx::TileType::parse(AsmParser &parser) {
-  if (parser.parseLess())
-    return nullptr;
-
-  SmallVector<int64_t, 2> shape;
-  if (parser.parseDimensionList(shape, false, true))
-    return nullptr;
-
-  Type elementType;
-  if (parser.parseType(elementType))
-    return nullptr;
-
-  if (parser.parseGreater())
-    return nullptr;
-
-  return TileType::getChecked(
-      [&] { return parser.emitError(parser.getNameLoc()); }, shape,
-      elementType);
-}
-
-void amx::TileType::print(AsmPrinter &os) const {
-  os << "<";
-  os.printDimensionList(getShape());
-  os << 'x';
-  os.printType(getElementType());
-  os << '>';
-}
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/AMX/AMX.cpp.inc"
-
-#define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
diff --git a/mlir/lib/Dialect/AMX/IR/CMakeLists.txt b/mlir/lib/Dialect/AMX/IR/CMakeLists.txt
deleted file mode 100644
index b6e2759843d5e..0000000000000
--- a/mlir/lib/Dialect/AMX/IR/CMakeLists.txt
+++ /dev/null
@@ -1,15 +0,0 @@
-add_mlir_dialect_library(MLIRAMXDialect
-  AMXDialect.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMX
-
-  DEPENDS
-  MLIRAMXIncGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRLLVMCommonConversion
-  MLIRLLVMDialect
-  MLIRSideEffectInterfaces
-  )
diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
deleted file mode 100644
index e827bc475e930..0000000000000
--- a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
+++ /dev/null
@@ -1,9 +0,0 @@
-add_mlir_dialect_library(MLIRAMXTransforms
-  LegalizeForLLVMExport.cpp
-
-  LINK_LIBS PUBLIC
-  MLIRAMXDialect
-  MLIRIR
-  MLIRLLVMCommonConversion
-  MLIRLLVMDialect
-  )
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
deleted file mode 100644
index 6483af222e91b..0000000000000
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ /dev/null
@@ -1,70 +0,0 @@
-//===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/AMX/Transforms.h"
-
-#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
-#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
-#include "mlir/Conversion/LLVMCommon/Pattern.h"
-#include "mlir/Dialect/AMX/AMXDialect.h"
-#include "mlir/IR/PatternMatch.h"
-
-using namespace mlir;
-using namespace mlir::amx;
-
-namespace {
-
-/// Generic one-to-one conversion of simply mappable operations into calls
-/// to their respective LLVM intrinsics.
-struct AMXIntrinsicOpConversion
-    : public ConvertOpInterfaceToLLVMPattern<amx::AMXIntrinsicOp> {
-  using ConvertOpInterfaceToLLVMPattern::ConvertOpInterfaceToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(amx::AMXIntrinsicOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    const LLVMTypeConverter &typeConverter = *getTypeConverter();
-    return LLVM::detail::intrinsicRewrite(
-        op, rewriter.getStringAttr(op.getIntrinsicName()),
-        op.getIntrinsicOperands(operands, typeConverter, rewriter),
-        typeConverter, rewriter);
-  }
-};
-
-} // namespace
-
-void mlir::populateAMXLegalizeForLLVMExportPatterns(
-    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<AMXIntrinsicOpConversion>(converter);
-  converter.addConversion([&](amx::TileType type) {
-    return LLVM::LLVMX86AMXType::get(&converter.getContext());
-  });
-}
-
-void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
-  target.addIllegalDialect<AMXDialect>();
-}
-
-namespace {
-/// Implement the interface to convert AMX to LLVM.
-struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
-  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
-
-  void populateConvertToLLVMConversionPatterns(
-      ConversionTarget &target, LLVMTypeConverter &typeConverter,
-      RewritePatternSet &patterns) const final {
-    populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
-  }
-};
-} // namespace
-
-void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
-  registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
-    dialect->addInterfaces<AMXToLLVMDialectInterface>();
-  });
-}
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 65dada6ac4bfc..66f68c369f81f 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_subdirectory(Affine)
 add_subdirectory(AMDGPU)
-add_subdirectory(AMX)
 add_subdirectory(Arith)
 add_subdirectory(ArmNeon)
 add_subdirectory(ArmSME)
diff --git a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
index e1714bdc8dc17..47ee5d272a890 100644
--- a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
+++ b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
@@ -11,10 +11,16 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/X86/X86Dialect.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TypeUtilities.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+
 using namespace mlir;
 
 #include "mlir/Dialect/X86/X86Interfaces.cpp.inc"
@@ -22,6 +28,11 @@ using namespace mlir;
 #include "mlir/Dialect/X86/X86Dialect.cpp.inc"
 
 void x86::X86Dialect::initialize() {
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/X86/X86Types.cpp.inc"
+      >();
+
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/X86/X86.cpp.inc"
@@ -107,5 +118,279 @@ SmallVector<Value> x86::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
                            typeConverter, rewriter)};
 }
 
+/// Verify that AMX supports the implied tile shape.
+static LogicalResult verifyTileSize(Operation *op, x86::amx::TileType tp) {
+  const unsigned kMaxRows = 16;
+  const unsigned kBitsPerRow = 64 * 8;
+  unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
+  if (tp.getDimSize(0) > kMaxRows)
+    return op->emitOpError("bad row height: ") << tp.getDimSize(0);
+  if (col > kBitsPerRow || col & 0x1f)
+    return op->emitOpError("bad column width: ") << (col >> 3);
+  return success();
+}
+
+/// Verify that AMX supports the multiplication.
+static LogicalResult verifyMultShape(Operation *op, x86::amx::TileType atp,
+                                     x86::amx::TileType btp,
+                                     x86::amx::TileType ctp, unsigned scale) {
+  unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
+  unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
+  unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
+  if (cm != am || cn != bn || ak != bk)
+    return op->emitOpError("bad mult shape: ")
+           << cm << " x " << cn << " x " << ak;
+  return success();
+}
+
+/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
+/// dimension directly translates into the number of rows of the tiles.
+/// The second dimensions needs to be scaled by the number of bytes.
+static SmallVector<Value> getTileSizes(Location loc, x86::amx::TileType tType,
+                                       RewriterBase &rewriter) {
+  Type llvmInt16Type = rewriter.getIntegerType(16);
+  unsigned width = tType.getElementType().getIntOrFloatBitWidth();
+  assert(llvm::isPowerOf2_64(width) && width >= 8);
+  unsigned bytes = width >> 3;
+  auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
+  auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
+  return SmallVector<Value>{
+      LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
+      LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
+}
+
+/// Returns stride expressed in number of bytes for the given `elementStride`
+/// stride encoded in number of elements of the type `mType`.
+static Value computeStrideInBytes(Location loc, MemRefType mType,
+                                  Value elementStride, RewriterBase &rewriter) {
+  Type llvmInt64Type = rewriter.getIntegerType(64);
+  unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
+  auto attr = rewriter.getI64IntegerAttr(bytes);
+  Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
+  return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
+      .getResult();
+}
+
+/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
+/// shape may "envelop" the actual tile shape, and may be dynamically sized.
+static Value inferStride(Location loc, MemRefType mType, Value base,
+                         RewriterBase &rewriter) {
+  assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
+  int64_t preLast = mType.getRank() - 2;
+  Type llvmInt64Type = rewriter.getIntegerType(64);
+  unsigned width = mType.getElementType().getIntOrFloatBitWidth();
+  assert(llvm::isPowerOf2_64(width) && width >= 8);
+  unsigned bytes = width >> 3;
+  auto [strides, offset] = mType.getStridesAndOffset();
+  if (strides[preLast] == ShapedType::kDynamic) {
+    // Dynamic stride needs code to compute the stride at runtime.
+    MemRefDescriptor memrefDescriptor(base);
+    return computeStrideInBytes(
+        loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
+  }
+  // Use direct constant for static stride.
+  auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
+  return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
+      .getResult();
+}
+
+LogicalResult x86::amx::TileZeroOp::verify() {
+  return verifyTileSize(*this, getTileType());
+}
+
+SmallVector<Value> x86::amx::TileZeroOp::getIntrinsicOperands(
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  return getTileSizes(getLoc(), getTileType(), rewriter);
+}
+
+template <typename OpTy, typename = std::enable_if_t<
+                             std::is_same_v<OpTy, x86::amx::TileLoadOp> ||
+                             std::is_same_v<OpTy, x86::amx::TileStoreOp>>>
+static LogicalResult tileTransferVerifier(OpTy op) {
+  MemRefType memrefTy = op.getMemRefType();
+  unsigned rank = memrefTy.getRank();
+  if (op.getIndices().size() != rank)
+    return op.emitOpError("requires ") << rank << " indices";
+  // The transferred tile must satisfy the row/column limits of verifyTileSize.
+  if (failed(verifyTileSize(op, op.getTileType())))
+    return failure();
+
+  // Validate basic buffer properties when the stride is implicit.
+  if (!op.getStride()) {
+    if (rank < 2)
+      return op.emitOpError("requires at least 2D memref");
+    SmallVector<int64_t> strides;
+    int64_t offset;
+    if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
+        strides.back() != 1)
+      return op.emitOpError("requires memref with unit innermost stride");
+  }
+
+  return success();
+}
+
+void x86::amx::TileLoadOp::build(OpBuilder &builder, OperationState &state,
+                                 Type res, Value base, ValueRange indices) {
+  build(builder, state, res, base, indices, /*stride=*/nullptr);
+}
+
+LogicalResult x86::amx::TileLoadOp::verify() {
+  return tileTransferVerifier(*this);
+}
+
+SmallVector<Value> x86::amx::TileLoadOp::getIntrinsicOperands(
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  auto loc = getLoc();
+  Adaptor adaptor(operands, *this);
+  // Operand order: tile sizes, element pointer, stride in bytes.
+  SmallVector<Value> intrinsicOperands;
+  intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
+  intrinsicOperands.push_back(
+      LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
+                                 adaptor.getBase(), adaptor.getIndices()));
+  if (Value stride = adaptor.getStride())
+    intrinsicOperands.push_back(
+        computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
+  else
+    intrinsicOperands.push_back(
+        inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
+
+  return intrinsicOperands;
+}
+
+void x86::amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
+                                  Value base, ValueRange indices, Value val) {
+  build(builder, state, base, indices, val, /*stride=*/nullptr);
+}
+
+LogicalResult x86::amx::TileStoreOp::verify() {
+  return tileTransferVerifier(*this);
+}
+
+SmallVector<Value> x86::amx::TileStoreOp::getIntrinsicOperands(
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  auto loc = getLoc();
+  Adaptor adaptor(operands, *this);
+  // Operand order: tile sizes, element pointer, stride in bytes, tile value.
+  SmallVector<Value> intrinsicOperands;
+  intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
+  intrinsicOperands.push_back(
+      LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
+                                 adaptor.getBase(), adaptor.getIndices()));
+  if (Value stride = adaptor.getStride())
+    intrinsicOperands.push_back(
+        computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
+  else
+    intrinsicOperands.push_back(
+        inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
+  intrinsicOperands.push_back(adaptor.getVal());
+
+  return intrinsicOperands;
+}
+
+LogicalResult x86::amx::TileMulFOp::verify() {
+  x86::amx::TileType aType = getLhsTileType();
+  x86::amx::TileType bType = getRhsTileType();
+  x86::amx::TileType cType = getTileType();
+  if (failed(verifyTileSize(*this, aType)) ||
+      failed(verifyTileSize(*this, bType)) ||
+      failed(verifyTileSize(*this, cType)) ||
+      failed(verifyMultShape(*this, aType, bType, cType, 1)))
+    return failure();
+  Type ta = aType.getElementType();
+  Type tb = bType.getElementType();
+  Type tc = cType.getElementType();
+  if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
+    return emitOpError("unsupported type combination");
+  return success();
+}
+
+SmallVector<Value> x86::amx::TileMulFOp::getIntrinsicOperands(
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  auto loc = getLoc();
+  Adaptor adaptor(operands, *this);
+
+  x86::amx::TileType aType = getLhsTileType();
+  x86::amx::TileType bType = getRhsTileType();
+  SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
+  SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
+  // Sizes passed: lhs rows, rhs row bytes, lhs row bytes; then acc, lhs, rhs.
+  SmallVector<Value> intrinsicOperands = {tsza[0],          tszb[1],
+                                          tsza[1],          adaptor.getAcc(),
+                                          adaptor.getLhs(), adaptor.getRhs()};
+
+  return intrinsicOperands;
+}
+
+LogicalResult x86::amx::TileMulIOp::verify() {
+  x86::amx::TileType aType = getLhsTileType();
+  x86::amx::TileType bType = getRhsTileType();
+  x86::amx::TileType cType = getTileType();
+  if (failed(verifyTileSize(*this, aType)) ||
+      failed(verifyTileSize(*this, bType)) ||
+      failed(verifyTileSize(*this, cType)) ||
+      failed(verifyMultShape(*this, aType, bType, cType, 2)))
+    return failure();
+  Type ta = aType.getElementType();
+  Type tb = bType.getElementType();
+  Type tc = cType.getElementType();
+  if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
+    return emitOpError("unsupported type combination");
+  return success();
+}
+
+SmallVector<Value> x86::amx::TileMulIOp::getIntrinsicOperands(
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  auto loc = getLoc();
+  Adaptor adaptor(operands, *this);
+
+  x86::amx::TileType aType = getLhsTileType();
+  x86::amx::TileType bType = getRhsTileType();
+  SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
+  SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
+  // Sizes passed: lhs rows, rhs row bytes, lhs row bytes; then acc, lhs, rhs.
+  SmallVector<Value> intrinsicOperands = {tsza[0],          tszb[1],
+                                          tsza[1],          adaptor.getAcc(),
+                                          adaptor.getLhs(), adaptor.getRhs()};
+
+  return intrinsicOperands;
+}
+
+Type x86::amx::TileType::parse(AsmParser &parser) {
+  if (parser.parseLess())
+    return nullptr;
+  // Remaining syntax mirrors print(): dimension list, `x`, element type, `>`.
+  SmallVector<int64_t, 2> shape;
+  if (parser.parseDimensionList(shape, false, true))
+    return nullptr;
+
+  Type elementType;
+  if (parser.parseType(elementType))
+    return nullptr;
+
+  if (parser.parseGreater())
+    return nullptr;
+  // Diagnostics from getChecked are emitted at the type's name location.
+  return x86::amx::TileType::getChecked(
+      [&] { return parser.emitError(parser.getNameLoc()); }, shape,
+      elementType);
+}
+
+void x86::amx::TileType::print(AsmPrinter &os) const {
+  os << "<";
+  os.printDimensionList(getShape());
+  os << 'x';
+  os.printType(getElementType());
+  os << '>';
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/X86/X86.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/X86/X86Types.cpp.inc"
diff --git a/mlir/lib/Dialect/X86/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86/Transforms/LegalizeForLLVMExport.cpp
index c07559dc295fa..8907b5f482e9c 100644
--- a/mlir/lib/Dialect/X86/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/LegalizeForLLVMExport.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/X86/Transforms.h"
 
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/X86/X86Dialect.h"
@@ -39,10 +40,32 @@ struct X86IntrinsicOpConversion
 
 /// Populate the given list with patterns that convert from X86 to LLVM.
 void mlir::populateX86LegalizeForLLVMExportPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add<X86IntrinsicOpConversion>(converter);
+  converter.addConversion([&](x86::amx::TileType type) {
+    return LLVM::LLVMX86AMXType::get(&converter.getContext());
+  });
 }
 
 void mlir::configureX86LegalizeForExportTarget(LLVMConversionTarget &target) {
   target.addIllegalDialect<X86Dialect>();
 }
+
+namespace {
+/// Implement the interface to convert X86 to LLVM.
+struct X86ToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateX86LegalizeForLLVMExportPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+void mlir::registerConvertX86ToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, x86::X86Dialect *dialect) {
+    dialect->addInterfaces<X86ToLLVMDialectInterface>();
+  });
+}
diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp
index 10944f72aa3c9..ea5698f39c0b0 100644
--- a/mlir/lib/RegisterAllDialects.cpp
+++ b/mlir/lib/RegisterAllDialects.cpp
@@ -14,7 +14,6 @@
 #include "mlir/InitAllDialects.h"
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
-#include "mlir/Dialect/AMX/AMXDialect.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -111,7 +110,6 @@ void mlir::registerAllDialects(DialectRegistry &registry) {
   registry.insert<acc::OpenACCDialect,
                   affine::AffineDialect,
                   amdgpu::AMDGPUDialect,
-                  amx::AMXDialect,
                   arith::ArithDialect,
                   arm_neon::ArmNeonDialect,
                   arm_sme::ArmSMEDialect,
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 14b583484ba38..27a89ef8712da 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -32,7 +32,6 @@
 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
-#include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
 #include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
@@ -57,6 +56,7 @@
 #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
 #include "mlir/Dialect/X86/TransformOps/X86TransformOps.h"
+#include "mlir/Dialect/X86/Transforms.h"
 #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
@@ -90,10 +90,10 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
   registerConvertOpenMPToLLVMInterface(registry);
   registerConvertSCFToEmitCInterface(registry);
   ub::registerConvertUBToLLVMInterface(registry);
-  registerConvertAMXToLLVMInterface(registry);
   gpu::registerConvertGpuToLLVMInterface(registry);
   NVVM::registerConvertGpuToNVVMInterface(registry);
   vector::registerConvertVectorToLLVMInterface(registry);
+  registerConvertX86ToLLVMInterface(registry);
 
   // Register all transform dialect extensions.
   affine::registerTransformDialectExtension(registry);
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index dfa6a9943543b..6078c08002476 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -32,8 +32,8 @@ if (MLIR_INCLUDE_INTEGRATION_TESTS)
       "The GPU compilation format used by the tests.")
   set(ARM_SME_ABI_ROUTINES_SHLIB "" CACHE STRING
       "Path to a shared library containing Arm SME ABI routines, required for Arm SME integration tests.")
-  option(MLIR_RUN_AMX_TESTS "Run AMX tests.")
   option(MLIR_RUN_X86_TESTS "Run X86 tests.")
+  option(MLIR_RUN_X86_AMX_TESTS "Run X86 AMX tests.")
   option(MLIR_RUN_CUDA_TENSOR_CORE_TESTS "Run CUDA Tensor core WMMA tests.")
   option(MLIR_RUN_CUDA_SM80_TESTS "Run CUDA A100 tests.")
   option(MLIR_RUN_CUDA_SM80_LT_TESTS "Run CUDA A100 structured sparsity tests.")
@@ -76,7 +76,7 @@ llvm_canonicalize_cmake_booleans(
   MLIR_ENABLE_SPIRV_CPU_RUNNER
   MLIR_ENABLE_VULKAN_RUNNER
   MLIR_INCLUDE_INTEGRATION_TESTS
-  MLIR_RUN_AMX_TESTS
+  MLIR_RUN_X86_AMX_TESTS
   MLIR_RUN_CUDA_TENSOR_CORE_TESTS
   MLIR_RUN_X86_TESTS
   MLIR_RUN_ARM_SVE_TESTS
diff --git a/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir b/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
index 4fb88dd165126..745cb088ef3f2 100644
--- a/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
+++ b/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
@@ -34,27 +34,27 @@ func.func @contract_vnni_f16(%A: vector<4x8x2xf16>, %B: vector<8x16x2xf16>,
 // CHECK:       vector.transfer_write %[[A]], %[[A_BUF]]
 // CHECK:       %[[A_BUF_2D:.+]] = memref.collapse_shape %[[A_BUF]]
 // CHECK-SAME:    {{\[}}[0], [1, 2]] : memref<4x8x2xf16> into memref<4x16xf16>
-// CHECK:       %[[A_TILE:.+]] = amx.tile_load %[[A_BUF_2D]]
+// CHECK:       %[[A_TILE:.+]] = x86.amx.tile_load %[[A_BUF_2D]]
 
 /// Load B vector into an AMX tile
 // CHECK:       %[[B_BUF:.+]] = memref.alloca() : memref<8x16x2xf16>
 // CHECK:       vector.transfer_write %[[B]], %[[B_BUF]]
 // CHECK:       %[[B_BUF_2D:.+]] = memref.collapse_shape %[[B_BUF]]
 // CHECK-SAME:    {{\[}}[0], [1, 2]] : memref<8x16x2xf16> into memref<8x32xf16>
-// CHECK:       %[[B_TILE:.+]] = amx.tile_load %[[B_BUF_2D]]
+// CHECK:       %[[B_TILE:.+]] = x86.amx.tile_load %[[B_BUF_2D]]
 
 /// Load C vector into an AMX tile
 // CHECK:       %[[C_BUF:.+]] = memref.alloca() : memref<4x16xf32>
 // CHECK:       vector.transfer_write %[[C]], %[[C_BUF]]
-// CHECK:       %[[C_TILE:.+]] = amx.tile_load %[[C_BUF]]
+// CHECK:       %[[C_TILE:.+]] = x86.amx.tile_load %[[C_BUF]]
 
 /// Perform tile multiplication
-// CHECK:       %[[RES:.+]] = amx.tile_mulf
+// CHECK:       %[[RES:.+]] = x86.amx.tile_mulf
 // CHECK-SAME:    %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
 
 /// Load the result back into a vector
 // CHECK:       %[[RES_BUF:.+]] = memref.alloca() : memref<4x16xf32>
-// CHECK:       amx.tile_store %[[RES_BUF]]{{.*}}, %[[RES]]
+// CHECK:       x86.amx.tile_store %[[RES_BUF]]{{.*}}, %[[RES]]
 // CHECK:       %[[RES_VEC:.+]] = vector.transfer_read %[[RES_BUF]]
 
 // CHECK:       return %[[RES_VEC]]
@@ -75,9 +75,9 @@ func.func @contract_vnni_bf16(%A: vector<4x8x2xbf16>, %B: vector<8x16x2xbf16>,
 }
 
 // CHECK-LABEL: @contract_vnni_bf16(
-// CHECK-COUNT-3: amx.tile_load
-// CHECK: amx.tile_mulf
-// CHECK: amx.tile_store
+// CHECK-COUNT-3: x86.amx.tile_load
+// CHECK: x86.amx.tile_mulf
+// CHECK: x86.amx.tile_store
 
 // -----
 
@@ -95,9 +95,9 @@ func.func @contract_vnni_i8(%A: vector<4x16x4xi8>, %B: vector<16x8x4xi8>,
 }
 
 // CHECK-LABEL: @contract_vnni_i8(
-// CHECK-COUNT-3: amx.tile_load
-// CHECK: amx.tile_muli
-// CHECK: amx.tile_store
+// CHECK-COUNT-3: x86.amx.tile_load
+// CHECK: x86.amx.tile_muli
+// CHECK: x86.amx.tile_store
 
 // -----
 
@@ -115,9 +115,9 @@ func.func @contract_shuffled_iterators(%A: vector<4x16x4xi8>, %B: vector<16x8x4x
 }
 
 // CHECK-LABEL: @contract_shuffled_iterators(
-// CHECK-COUNT-3: amx.tile_load
-// CHECK: amx.tile_muli
-// CHECK: amx.tile_store
+// CHECK-COUNT-3: x86.amx.tile_load
+// CHECK: x86.amx.tile_muli
+// CHECK: x86.amx.tile_store
 
 // -----
 
diff --git a/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir b/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir
index 8fab4cf1f7ed1..120f13bd5d876 100644
--- a/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir
+++ b/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir
@@ -38,7 +38,7 @@ func.func @transfers_static_dims(%A: memref<64x32x16x2xf16>,
 // CHECK-SAME:    {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
 // CHECK:       %[[A_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[A_SUBVIEW]]
 // CHECK-SAME:    {{\[}}[0], [1], [2, 3]] : memref<1x4x8x2xf16{{.*}}into memref<1x4x16xf16
-// CHECK:       %[[A_TILE:.+]] = amx.tile_load %[[A_PACKED_DIM_COLLAPSE]]
+// CHECK:       %[[A_TILE:.+]] = x86.amx.tile_load %[[A_PACKED_DIM_COLLAPSE]]
 // CHECK-SAME:    {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
 // CHECK-NOT:   vector.transfer_read %[[A]]
 
@@ -47,25 +47,25 @@ func.func @transfers_static_dims(%A: memref<64x32x16x2xf16>,
 // CHECK-SAME:    {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
 // CHECK:       %[[B_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[B_SUBVIEW]]
 // CHECK-SAME:    {{\[}}[0], [1], [2, 3]] : memref<1x8x16x2xf16{{.*}}into memref<1x8x32xf16
-// CHECK:       %[[B_TILE:.+]] = amx.tile_load %[[B_PACKED_DIM_COLLAPSE]]
+// CHECK:       %[[B_TILE:.+]] = x86.amx.tile_load %[[B_PACKED_DIM_COLLAPSE]]
 // CHECK-SAME:    {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
 // CHECK-NOT:   vector.transfer_read %[[B]]
 
 /// Load C into an AMX tile
 // CHECK:       %[[C_SUBVIEW:.+]] = memref.subview %[[C]]
 // CHECK-SAME:    {{\[}}%[[IDX]], %[[IDX]]{{\]}}
-// CHECK:       %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]]
+// CHECK:       %[[C_TILE:.+]] = x86.amx.tile_load %[[C_SUBVIEW]]
 // CHECK-SAME:    {{\[}}%[[C0]], %[[C0]]{{\]}}
 // CHECK-NOT:   vector.transfer_read %[[C]]
 
 /// Perform tile multiplication
-// CHECK:       %[[RES:.+]] = amx.tile_mulf
+// CHECK:       %[[RES:.+]] = x86.amx.tile_mulf
 // CHECK-SAME:    %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
 
 /// Store the result back
 // CHECK:       %[[RES_SUBVIEW:.+]] = memref.subview %[[C]]
 // CHECK-SAME:    {{\[}}%[[IDX]], %[[IDX]]{{\]}}
-// CHECK:       amx.tile_store %[[RES_SUBVIEW]]{{\[}}%[[C0]], %[[C0]]{{\]}}, %[[RES]]
+// CHECK:       x86.amx.tile_store %[[RES_SUBVIEW]]{{\[}}%[[C0]], %[[C0]]{{\]}}, %[[RES]]
 // CHECK-NOT:   vector.transfer_write{{.*}}%[[C]]
 
 // -----
@@ -130,17 +130,17 @@ func.func @transfer_read_multiple_users(%C: memref<64x64xf32>,
 
 /// Load to AMX tile directly from buffer.
 // CHECK: %[[C_SUBVIEW:.+]] = memref.subview %[[C]]
-// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]]
+// CHECK: %[[C_TILE:.+]] = x86.amx.tile_load %[[C_SUBVIEW]]
 
 /// Vector read remains to load data for the other non-AMX consumer.
 // CHECK: %[[C_VEC:.+]] = vector.transfer_read %[[C]]
 
 /// Contraction uses the directly loaded tile.
-// CHECK: %[[TILE_MUL:.+]] = amx.tile_mulf{{.*}}%[[C_TILE]]
+// CHECK: %[[TILE_MUL:.+]] = x86.amx.tile_mulf{{.*}}%[[C_TILE]]
 
 /// Consumer uses original C value and the updated one after contraction.
 // CHECK: %[[RES_BUF:.+]] = memref.alloca
-// CHECK: amx.tile_store %[[RES_BUF]]
+// CHECK: x86.amx.tile_store %[[RES_BUF]]
 // CHECK: %[[RES_VEC:.+]] = vector.transfer_read %[[RES_BUF]]
 // CHECK: %[[VEC_MUL:.+]] = arith.mulf %[[C_VEC]], %[[RES_VEC]]
 
@@ -168,7 +168,7 @@ func.func @negative_contract_multiple_users(%C: memref<64x64xf32>,
 
 // CHECK-LABEL: @negative_contract_multiple_users(
 // CHECK-SAME:    %[[C:.+]]: memref<64x64xf32>
-// CHECK:     %[[TILE_MUL:.+]] = amx.tile_mulf
+// CHECK:     %[[TILE_MUL:.+]] = x86.amx.tile_mulf
 // CHECK: vector.transfer_write{{.*}}%[[C]]
 
 // -----
diff --git a/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir b/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir
index 8070aee19f946..e457e318b0780 100644
--- a/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir
@@ -18,7 +18,6 @@
 
 // CHECK: builtin.module(
 // CHECK-SAME: convert-vector-to-llvm{
-// CHECK-SAME: enable-amx={{[aA-zZ0-9]+}}
 // CHECK-SAME: enable-arm-neon={{[aA-zZ0-9]+}}
 // CHECK-SAME: enable-arm-sve={{[aA-zZ0-9]+}}
 // CHECK-SAME: enable-x86={{[aA-zZ0-9]+}}
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
deleted file mode 100644
index 5de9b3f82a868..0000000000000
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ /dev/null
@@ -1,158 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics
-
-func.func @tile_row_height() {
-  // expected-error@+1 {{'amx.tile_zero' op bad row height: 17}}
-  %0 = amx.tile_zero : !amx.tile<17x16xbf16>
-  return
-}
-
-// -----
-
-func.func @tile_col_width() {
-  // expected-error@+1 {{'amx.tile_zero' op bad column width: 65}}
-  %0 = amx.tile_zero : !amx.tile<16x65xi8>
-  return
-}
-
-// -----
-
-func.func @tile_element_type() {
-  // expected-error@+1 {{failed to verify 'elementType'}}
-  %0 = amx.tile_zero : !amx.tile<8x8xi16>
-  return
-}
-
-// -----
-
-func.func @tile_rank() {
-  // expected-error@+1 {{'amx.tile_zero' op result #0 must be tile of}}
-  %0 = amx.tile_zero : !amx.tile<32xi8>
-  return
-}
-
-// -----
-
-func.func @tile_col_4_byte_multiple() {
-  // expected-error@+1 {{'amx.tile_zero' op bad column width: 5}}
-  %0 = amx.tile_zero : !amx.tile<16x5xi8>
-  return
-}
-
-// -----
-
-func.func @load_base_tile_size(%arg0: memref<?x?xf32>) {
-  %0 = arith.constant 0 : index
-  // expected-error@+1 {{'amx.tile_load' op bad column width: 68}}
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
-  return
-}
-
-// -----
-
-func.func @store_base_tile_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
-  %0 = arith.constant 0 : index
-  // expected-error@+1 {{'amx.tile_store' op bad column width: 68}}
-  amx.tile_store %arg0[%0, %0], %arg1 : memref<?x?xf32>, !amx.tile<16x17xf32>
-  return
-}
-
-// -----
-
-func.func @load_base_index_size(%arg0: memref<?x?xf32>) {
-  %0 = arith.constant 0 : index
-  // expected-error@+1 {{'amx.tile_load' op requires 2 indices}}
-  %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
-  return
-}
-
-// -----
-
-func.func @store_base_index_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
-  %0 = arith.constant 0 : index
-  // expected-error@+1 {{'amx.tile_store' op requires 2 indices}}
-  amx.tile_store %arg0[%0], %arg1 : memref<?x?xf32>, !amx.tile<16x16xf32>
-  return
-}
-
-// -----
-
-func.func @load_base_rank(%arg0: memref<?xf32>) {
-  %0 = arith.constant 0 : index
-  // expected-error@+1 {{'amx.tile_load' op requires at least 2D memref}}
-  %1 = amx.tile_load %arg0[%0] : memref<?xf32> into !amx.tile<16x16xf32>
-  return
-}
-
-// -----
-
-func.func @store_base_rank(%arg0: memref<?xf32>, %arg1: !amx.tile<16x16xf32>) {
-  %0 = arith.constant 0 : index
-  // expected-error@+1 {{'amx.tile_store' op requires at least 2D memref}}
-  amx.tile_store %arg0[%0], %arg1 : memref<?xf32>, !amx.tile<16x16xf32>
-  return
-}
-
-// -----
-
-func.func @load_base_non_unit_stride(%arg0: memref<?x?xf32, strided<[?, ?]>>) {
-  %0 = arith.constant 0 : index
-  // expected-error@+1 {{'amx.tile_load' op requires memref with unit innermost stride}}
-  %1 = amx.tile_load %arg0[%0, %0]
-    : memref<?x?xf32, strided<[?, ?]>> into !amx.tile<16x16xf32>
-  return
-}
-
-// -----
-
-func.func @store_base_non_unit_stride(%arg0: memref<?x?xf32, strided<[?, ?]>>,
-    %arg1: !amx.tile<16x16xf32>) {
-  %0 = arith.constant 0 : index
-  // expected-error@+1 {{'amx.tile_store' op requires memref with unit innermost stride}}
-  amx.tile_store %arg0[%0, %0], %arg1
-    : memref<?x?xf32, strided<[?, ?]>>, !amx.tile<16x16xf32>
-  return
-}
-
-// -----
-
-func.func @mulf_shape() {
-  %0 = amx.tile_zero : !amx.tile<8x8xbf16>
-  %1 = amx.tile_zero : !amx.tile<8x8xbf16>
-  %2 = amx.tile_zero : !amx.tile<4x4xf32>
-  // expected-error@+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
-  %3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<8x8xbf16>, !amx.tile<4x4xf32>
-  return
-}
-
-// -----
-
-func.func @mulf_type_combination() {
-  %0 = amx.tile_zero : !amx.tile<8x8xbf16>
-  %1 = amx.tile_zero : !amx.tile<4x8xf16>
-  %2 = amx.tile_zero : !amx.tile<8x4xf32>
-  // expected-error@+1 {{'amx.tile_mulf' op unsupported type combination}}
-  %3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<4x8xf16>, !amx.tile<8x4xf32>
-  return
-}
-
-// -----
-
-func.func @muli_shape() {
-  %0 = amx.tile_zero : !amx.tile<8x8xi8>
-  %1 = amx.tile_zero : !amx.tile<8x8xi8>
-  %2 = amx.tile_zero : !amx.tile<4x4xi32>
-  // expected-error@+1 {{'amx.tile_muli' op bad mult shape: 4 x 4 x 2}}
-  %3 = amx.tile_muli %0, %1, %2 : !amx.tile<8x8xi8>, !amx.tile<8x8xi8>, !amx.tile<4x4xi32>
-  return
-}
-
-// -----
-
-func.func @muli_type_combination() {
-  %0 = amx.tile_zero : !amx.tile<8x16xi8>
-  %1 = amx.tile_zero : !amx.tile<8x16xi32>
-  %2 = amx.tile_zero : !amx.tile<2x2xi32>
-  // expected-error@+1 {{'amx.tile_muli' op operand #1 must be tile of 8-bit signless integer values}}
-  %3 = amx.tile_muli %0, %1, %2 : !amx.tile<8x16xi8>, !amx.tile<8x16xi32>, !amx.tile<2x2xi32>
-  return
-}
diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir
deleted file mode 100644
index 3d0f276df6a26..0000000000000
--- a/mlir/test/Dialect/AMX/roundtrip.mlir
+++ /dev/null
@@ -1,77 +0,0 @@
-// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
-
-// CHECK-LABEL: tloadstore
-// CHECK:      %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}], %{{.*}} :
-// CHECK-SAME:   memref<?xbf16> into !amx.tile<16x32xbf16>
-// CHECK:      %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} :
-// CHECK-SAME:   memref<?x?xbf16> into !amx.tile<16x32xbf16>
-// CHECK:      %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] :
-// CHECK-SAME:   memref<?x?xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
-// CHECK:      amx.tile_store %{{.*}}[%{{.*}}], %[[z]], %{{.*}} :
-// CHECK-SAME:   memref<?xbf16>, !amx.tile<16x32xbf16>
-// CHECK:      amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[x]], %{{.*}} :
-// CHECK-SAME:   memref<?x?xbf16>, !amx.tile<16x32xbf16>
-// CHECK:      amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[y]] :
-// CHECK-SAME:   memref<?x?xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
-func.func @tloadstore(%stride: index,
-    %arg0: memref<?xbf16>,
-    %arg1: memref<?x?xbf16>,
-    %arg2: memref<?x?xbf16, strided<[64, 1]>>) {
-  %0 = arith.constant 0 : index
-  %c64 = arith.constant 64 : index
-  %1 = amx.tile_load %arg0[%0], %stride : memref<?xbf16> into !amx.tile<16x32xbf16>
-  %2 = amx.tile_load %arg1[%0, %0], %stride : memref<?x?xbf16> into !amx.tile<16x32xbf16>
-  %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
-  amx.tile_store %arg0[%0], %3, %stride : memref<?xbf16>, !amx.tile<16x32xbf16>
-  amx.tile_store %arg1[%0, %0], %1, %stride : memref<?x?xbf16>, !amx.tile<16x32xbf16>
-  amx.tile_store %arg2[%0, %0], %2 : memref<?x?xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
-  return
-}
-
-// CHECK-LABEL: tzero
-// CHECK: amx.tile_zero : !amx.tile<16x16xbf16>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, !amx.tile<16x16xbf16>
-func.func @tzero(%arg0: memref<?x?xbf16>) {
-  %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : !amx.tile<16x16xbf16>
-  amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, !amx.tile<16x16xbf16>
-  return
-}
-
-// CHECK-LABEL: tmulf
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into !amx.tile<16x16xf32>
-// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, !amx.tile<16x16xf32>
-func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
-  %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
-  %3 = amx.tile_mulf %1, %1, %2 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
-  amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, !amx.tile<16x16xf32>
-  return
-}
-
-// CHECK-LABEL: tmuli
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
-// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into !amx.tile<16x16xi32>
-// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, !amx.tile<16x16xi32>
-// Verify the parsing/printing of the sign-extension annotation.
-// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
-// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
-// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
-func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
-  %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
-  %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
-  // Verify the various `zext` combinations.
-  %5 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
-  %6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
-  %7 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
-  return
-}
diff --git a/mlir/test/Dialect/AMX/side-effects.mlir b/mlir/test/Dialect/AMX/side-effects.mlir
deleted file mode 100644
index 22c76d98c6996..0000000000000
--- a/mlir/test/Dialect/AMX/side-effects.mlir
+++ /dev/null
@@ -1,32 +0,0 @@
-// RUN: mlir-opt %s -cse -convert-vector-to-llvm="enable-amx" | FileCheck %s
-
-// With inclusion of memory side-effects, it is expected CSE not to fold multiple 
-// "tileload" and "tilezero".
-// CHECK-LABEL: do_not_fold_tiles(
-// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
-// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
-// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
-// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
-// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
-// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
-func.func @do_not_fold_tiles(%arg0: memref<2x32x32xbf16>, %arg1: memref<2x16x32xbf16>) -> memref<16x32xf32> {
-  %c1 = arith.constant 1 : index
-  %c0 = arith.constant 0 : index
-  %c2 = arith.constant 2 : index
-  %c16 = arith.constant 16 : index
-  %alloca = memref.alloca() : memref<16x32xf32>
-  %0 = amx.tile_zero : !amx.tile<16x16xf32>
-  %1 = amx.tile_zero : !amx.tile<16x16xf32>
-  %2:2 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %0, %arg4 = %1) -> (!amx.tile<16x16xf32>, !amx.tile<16x16xf32>) {
-    %3 = amx.tile_load %arg0[%arg2, %c0, %c0] : memref<2x32x32xbf16> into !amx.tile<16x32xbf16>
-    %4 = amx.tile_load %arg0[%arg2, %c16, %c0] : memref<2x32x32xbf16> into !amx.tile<16x32xbf16>
-    %5 = amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !amx.tile<16x32xbf16>
-    %6 = amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !amx.tile<16x32xbf16>
-    %7 = amx.tile_mulf %3, %5, %arg3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
-    %8 = amx.tile_mulf %4, %6, %arg4 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
-    scf.yield %7, %8 : !amx.tile<16x16xf32>, !amx.tile<16x16xf32>
-  }
-  amx.tile_store %alloca[%c0, %c0], %2#0 : memref<16x32xf32>, !amx.tile<16x16xf32>
-  amx.tile_store %alloca[%c0, %c16], %2#1 : memref<16x32xf32>, !amx.tile<16x16xf32>
-  return %alloca : memref<16x32xf32>
-}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index bb868afe08cbf..6d6ebe0b5e60e 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -635,10 +635,10 @@ func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: t
 
 // -----
 
-func.func @invalid_type_matmul(%arg0 : !amx.tile<16x16xbf16>)
+func.func @invalid_type_matmul(%arg0 : !x86.amx.tile<16x16xbf16>)
 {
-  // expected-error @below {{custom op 'linalg.matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
-  %0 = linalg.matmul ins(%arg0, %arg0 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%arg0 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
+  // expected-error @below {{custom op 'linalg.matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!x86.amx.tile<16x16xbf16>' and '!x86.amx.tile<16x16xbf16>'}}
+  %0 = linalg.matmul ins(%arg0, %arg0 : !x86.amx.tile<16x16xbf16>, !x86.amx.tile<16x16xbf16>) outs(%arg0 : !x86.amx.tile<16x16xbf16>) -> !x86.amx.tile<16x16xbf16>
   return
 }
 
@@ -1582,10 +1582,10 @@ func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
 
 // -----
 
-func.func @invalid_type_batch_matmul(%arg0 : !amx.tile<16x16xbf16>)
+func.func @invalid_type_batch_matmul(%arg0 : !x86.amx.tile<16x16xbf16>)
 {
-  // expected-error @below {{custom op 'linalg.batch_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
-  %0 = linalg.batch_matmul ins(%arg0, %arg0 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%arg0 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
+  // expected-error @below {{custom op 'linalg.batch_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!x86.amx.tile<16x16xbf16>' and '!x86.amx.tile<16x16xbf16>'}}
+  %0 = linalg.batch_matmul ins(%arg0, %arg0 : !x86.amx.tile<16x16xbf16>, !x86.amx.tile<16x16xbf16>) outs(%arg0 : !x86.amx.tile<16x16xbf16>) -> !x86.amx.tile<16x16xbf16>
   return
 }
 
@@ -1790,10 +1790,10 @@ func.func @invalid_C_map_result_dim(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>
 
 // -----
 
-func.func @batch_reduce_matmul_invalid_type(%arg0 : !amx.tile<16x16xbf16>)
+func.func @batch_reduce_matmul_invalid_type(%arg0 : !x86.amx.tile<16x16xbf16>)
 {
-  // expected-error @below {{custom op 'linalg.batch_reduce_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
-  %0 = linalg.batch_reduce_matmul ins(%arg0, %arg0 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%arg0 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
+  // expected-error @below {{custom op 'linalg.batch_reduce_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!x86.amx.tile<16x16xbf16>' and '!x86.amx.tile<16x16xbf16>'}}
+  %0 = linalg.batch_reduce_matmul ins(%arg0, %arg0 : !x86.amx.tile<16x16xbf16>, !x86.amx.tile<16x16xbf16>) outs(%arg0 : !x86.amx.tile<16x16xbf16>) -> !x86.amx.tile<16x16xbf16>
   return
 }
 
diff --git a/mlir/test/Dialect/X86/AMX/invalid.mlir b/mlir/test/Dialect/X86/AMX/invalid.mlir
new file mode 100644
index 0000000000000..25033090808b4
--- /dev/null
+++ b/mlir/test/Dialect/X86/AMX/invalid.mlir
@@ -0,0 +1,158 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+func.func @tile_row_height() {
+  // expected-error@+1 {{'x86.amx.tile_zero' op bad row height: 17}}
+  %0 = x86.amx.tile_zero : !x86.amx.tile<17x16xbf16>
+  return
+}
+
+// -----
+
+func.func @tile_col_width() {
+  // expected-error@+1 {{'x86.amx.tile_zero' op bad column width: 65}}
+  %0 = x86.amx.tile_zero : !x86.amx.tile<16x65xi8>
+  return
+}
+
+// -----
+
+func.func @tile_element_type() {
+  // expected-error@+1 {{failed to verify 'elementType'}}
+  %0 = x86.amx.tile_zero : !x86.amx.tile<8x8xi16>
+  return
+}
+
+// -----
+
+func.func @tile_rank() {
+  // expected-error@+1 {{'x86.amx.tile_zero' op result #0 must be tile of}}
+  %0 = x86.amx.tile_zero : !x86.amx.tile<32xi8>
+  return
+}
+
+// -----
+
+func.func @tile_col_4_byte_multiple() {
+  // expected-error@+1 {{'x86.amx.tile_zero' op bad column width: 5}}
+  %0 = x86.amx.tile_zero : !x86.amx.tile<16x5xi8>
+  return
+}
+
+// -----
+
+func.func @load_base_tile_size(%arg0: memref<?x?xf32>) {
+  %0 = arith.constant 0 : index
+  // expected-error@+1 {{'x86.amx.tile_load' op bad column width: 68}}
+  %1 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x17xf32>
+  return
+}
+
+// -----
+
+func.func @store_base_tile_size(%arg0: memref<?x?xf32>, %arg1: !x86.amx.tile<16x17xf32>) {
+  %0 = arith.constant 0 : index
+  // expected-error@+1 {{'x86.amx.tile_store' op bad column width: 68}}
+  x86.amx.tile_store %arg0[%0, %0], %arg1 : memref<?x?xf32>, !x86.amx.tile<16x17xf32>
+  return
+}
+
+// -----
+
+func.func @load_base_index_size(%arg0: memref<?x?xf32>) {
+  %0 = arith.constant 0 : index
+  // expected-error@+1 {{'x86.amx.tile_load' op requires 2 indices}}
+  %1 = x86.amx.tile_load %arg0[%0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+  return
+}
+
+// -----
+
+func.func @store_base_index_size(%arg0: memref<?x?xf32>, %arg1: !x86.amx.tile<16x16xf32>) {
+  %0 = arith.constant 0 : index
+  // expected-error@+1 {{'x86.amx.tile_store' op requires 2 indices}}
+  x86.amx.tile_store %arg0[%0], %arg1 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+  return
+}
+
+// -----
+
+func.func @load_base_rank(%arg0: memref<?xf32>) {
+  %0 = arith.constant 0 : index
+  // expected-error@+1 {{'x86.amx.tile_load' op requires at least 2D memref}}
+  %1 = x86.amx.tile_load %arg0[%0] : memref<?xf32> into !x86.amx.tile<16x16xf32>
+  return
+}
+
+// -----
+
+func.func @store_base_rank(%arg0: memref<?xf32>, %arg1: !x86.amx.tile<16x16xf32>) {
+  %0 = arith.constant 0 : index
+  // expected-error@+1 {{'x86.amx.tile_store' op requires at least 2D memref}}
+  x86.amx.tile_store %arg0[%0], %arg1 : memref<?xf32>, !x86.amx.tile<16x16xf32>
+  return
+}
+
+// -----
+
+func.func @load_base_non_unit_stride(%arg0: memref<?x?xf32, strided<[?, ?]>>) {
+  %0 = arith.constant 0 : index
+  // expected-error@+1 {{'x86.amx.tile_load' op requires memref with unit innermost stride}}
+  %1 = x86.amx.tile_load %arg0[%0, %0]
+    : memref<?x?xf32, strided<[?, ?]>> into !x86.amx.tile<16x16xf32>
+  return
+}
+
+// -----
+
+func.func @store_base_non_unit_stride(%arg0: memref<?x?xf32, strided<[?, ?]>>,
+    %arg1: !x86.amx.tile<16x16xf32>) {
+  %0 = arith.constant 0 : index
+  // expected-error@+1 {{'x86.amx.tile_store' op requires memref with unit innermost stride}}
+  x86.amx.tile_store %arg0[%0, %0], %arg1
+    : memref<?x?xf32, strided<[?, ?]>>, !x86.amx.tile<16x16xf32>
+  return
+}
+
+// -----
+
+func.func @mulf_shape() {
+  %0 = x86.amx.tile_zero : !x86.amx.tile<8x8xbf16>
+  %1 = x86.amx.tile_zero : !x86.amx.tile<8x8xbf16>
+  %2 = x86.amx.tile_zero : !x86.amx.tile<4x4xf32>
+  // expected-error@+1 {{'x86.amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
+  %3 = x86.amx.tile_mulf %0, %1, %2 : !x86.amx.tile<8x8xbf16>, !x86.amx.tile<8x8xbf16>, !x86.amx.tile<4x4xf32>
+  return
+}
+
+// -----
+
+func.func @mulf_type_combination() {
+  %0 = x86.amx.tile_zero : !x86.amx.tile<8x8xbf16>
+  %1 = x86.amx.tile_zero : !x86.amx.tile<4x8xf16>
+  %2 = x86.amx.tile_zero : !x86.amx.tile<8x4xf32>
+  // expected-error@+1 {{'x86.amx.tile_mulf' op unsupported type combination}}
+  %3 = x86.amx.tile_mulf %0, %1, %2 : !x86.amx.tile<8x8xbf16>, !x86.amx.tile<4x8xf16>, !x86.amx.tile<8x4xf32>
+  return
+}
+
+// -----
+
+func.func @muli_shape() {
+  %0 = x86.amx.tile_zero : !x86.amx.tile<8x8xi8>
+  %1 = x86.amx.tile_zero : !x86.amx.tile<8x8xi8>
+  %2 = x86.amx.tile_zero : !x86.amx.tile<4x4xi32>
+  // expected-error@+1 {{'x86.amx.tile_muli' op bad mult shape: 4 x 4 x 2}}
+  %3 = x86.amx.tile_muli %0, %1, %2 : !x86.amx.tile<8x8xi8>, !x86.amx.tile<8x8xi8>, !x86.amx.tile<4x4xi32>
+  return
+}
+
+// -----
+
+func.func @muli_type_combination() {
+  %0 = x86.amx.tile_zero : !x86.amx.tile<8x16xi8>
+  %1 = x86.amx.tile_zero : !x86.amx.tile<8x16xi32>
+  %2 = x86.amx.tile_zero : !x86.amx.tile<2x2xi32>
+  // expected-error@+1 {{'x86.amx.tile_muli' op operand #1 must be tile of 8-bit signless integer values}}
+  %3 = x86.amx.tile_muli %0, %1, %2 : !x86.amx.tile<8x16xi8>, !x86.amx.tile<8x16xi32>, !x86.amx.tile<2x2xi32>
+  return
+}
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir
similarity index 64%
rename from mlir/test/Dialect/AMX/legalize-for-llvm.mlir
rename to mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir
index a109f42e9dea3..eb12e20b699b3 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-amx" | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: muli(
 // CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
@@ -14,17 +14,17 @@
 // CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
 func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : !amx.tile<16x64xi8>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
-  %5 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, !amx.tile<16x16xi32>
-  %6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, !amx.tile<16x16xi32>
-  %7 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, !amx.tile<16x16xi32>
+  %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xi8>
+  %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
+  %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into !x86.amx.tile<16x16xi32>
+  %4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+  x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
+  %5 = x86.amx.tile_muli %1, %2, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+  x86.amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
+  %6 = x86.amx.tile_muli %1 zext, %2, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+  x86.amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
+  %7 = x86.amx.tile_muli %1, %2 zext, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+  x86.amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
   return
 }
 
@@ -36,11 +36,11 @@ func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
 // CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
 func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : !amx.tile<16x32xbf16>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
-  %4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
+  %1 = x86.amx.tile_zero : !x86.amx.tile<16x32xbf16>
+  %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
+  %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+  %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
   return
 }
 
@@ -52,11 +52,11 @@ func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
 // CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
 func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : !amx.tile<16x32xf16>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into !amx.tile<16x32xf16>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
-  %4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
+  %1 = x86.amx.tile_zero : !x86.amx.tile<16x32xf16>
+  %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into !x86.amx.tile<16x32xf16>
+  %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+  %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x32xf16>, !x86.amx.tile<16x32xf16>, !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
   return
 }
 
@@ -84,12 +84,12 @@ func.func @strides_implicit(%arg0: memref<16x32xi8>,
     %arg1: memref<32x32xbf16, strided<[64, 1]>>,
     %arg2: memref<16x32xf32, strided<[?, 1]>>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xi8> into !amx.tile<16x32xi8>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<32x32xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
-  %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xf32, strided<[?, 1]>> into !amx.tile<16x16xf32>
-  amx.tile_store %arg0[%0, %0], %1 : memref<16x32xi8>, !amx.tile<16x32xi8>
-  amx.tile_store %arg1[%0, %0], %2 : memref<32x32xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
-  amx.tile_store %arg2[%0, %0], %3 : memref<16x32xf32, strided<[?, 1]>>, !amx.tile<16x16xf32>
+  %1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x32xi8> into !x86.amx.tile<16x32xi8>
+  %2 = x86.amx.tile_load %arg1[%0, %0] : memref<32x32xbf16, strided<[64, 1]>> into !x86.amx.tile<16x32xbf16>
+  %3 = x86.amx.tile_load %arg2[%0, %0] : memref<16x32xf32, strided<[?, 1]>> into !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %arg0[%0, %0], %1 : memref<16x32xi8>, !x86.amx.tile<16x32xi8>
+  x86.amx.tile_store %arg1[%0, %0], %2 : memref<32x32xbf16, strided<[64, 1]>>, !x86.amx.tile<16x32xbf16>
+  x86.amx.tile_store %arg2[%0, %0], %3 : memref<16x32xf32, strided<[?, 1]>>, !x86.amx.tile<16x16xf32>
   return
 }
 
@@ -123,11 +123,11 @@ func.func @strides_explicit(%stride: index,
     %arg2: memref<32x32xf32, strided<[64, 1]>>) {
   %0 = arith.constant 0 : index
   %c64 = arith.constant 64 : index
-  %1 = amx.tile_load %arg0[%0], %stride : memref<?xi8> into !amx.tile<16x32xi8>
-  %2 = amx.tile_load %arg1[%0, %0], %stride : memref<16x32xbf16> into !amx.tile<16x32xbf16>
-  %3 = amx.tile_load %arg2[%0, %0], %c64 : memref<32x32xf32, strided<[64, 1]>> into !amx.tile<16x16xf32>
-  amx.tile_store %arg0[%0], %1, %stride : memref<?xi8>, !amx.tile<16x32xi8>
-  amx.tile_store %arg1[%0, %0], %2, %stride : memref<16x32xbf16>, !amx.tile<16x32xbf16>
-  amx.tile_store %arg2[%0, %0], %3, %c64 : memref<32x32xf32, strided<[64, 1]>>, !amx.tile<16x16xf32>
+  %1 = x86.amx.tile_load %arg0[%0], %stride : memref<?xi8> into !x86.amx.tile<16x32xi8>
+  %2 = x86.amx.tile_load %arg1[%0, %0], %stride : memref<16x32xbf16> into !x86.amx.tile<16x32xbf16>
+  %3 = x86.amx.tile_load %arg2[%0, %0], %c64 : memref<32x32xf32, strided<[64, 1]>> into !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %arg0[%0], %1, %stride : memref<?xi8>, !x86.amx.tile<16x32xi8>
+  x86.amx.tile_store %arg1[%0, %0], %2, %stride : memref<16x32xbf16>, !x86.amx.tile<16x32xbf16>
+  x86.amx.tile_store %arg2[%0, %0], %3, %c64 : memref<32x32xf32, strided<[64, 1]>>, !x86.amx.tile<16x16xf32>
   return
 }
diff --git a/mlir/test/Dialect/X86/AMX/roundtrip.mlir b/mlir/test/Dialect/X86/AMX/roundtrip.mlir
new file mode 100644
index 0000000000000..300c3aa054a70
--- /dev/null
+++ b/mlir/test/Dialect/X86/AMX/roundtrip.mlir
@@ -0,0 +1,77 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: tloadstore
+// CHECK:      %[[x:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}], %{{.*}} :
+// CHECK-SAME:   memref<?xbf16> into !x86.amx.tile<16x32xbf16>
+// CHECK:      %[[y:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} :
+// CHECK-SAME:   memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
+// CHECK:      %[[z:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] :
+// CHECK-SAME:   memref<?x?xbf16, strided<[64, 1]>> into !x86.amx.tile<16x32xbf16>
+// CHECK:      x86.amx.tile_store %{{.*}}[%{{.*}}], %[[z]], %{{.*}} :
+// CHECK-SAME:   memref<?xbf16>, !x86.amx.tile<16x32xbf16>
+// CHECK:      x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[x]], %{{.*}} :
+// CHECK-SAME:   memref<?x?xbf16>, !x86.amx.tile<16x32xbf16>
+// CHECK:      x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[y]] :
+// CHECK-SAME:   memref<?x?xbf16, strided<[64, 1]>>, !x86.amx.tile<16x32xbf16>
+func.func @tloadstore(%stride: index,
+    %arg0: memref<?xbf16>,
+    %arg1: memref<?x?xbf16>,
+    %arg2: memref<?x?xbf16, strided<[64, 1]>>) {
+  %0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %1 = x86.amx.tile_load %arg0[%0], %stride : memref<?xbf16> into !x86.amx.tile<16x32xbf16>
+  %2 = x86.amx.tile_load %arg1[%0, %0], %stride : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
+  %3 = x86.amx.tile_load %arg2[%0, %0] : memref<?x?xbf16, strided<[64, 1]>> into !x86.amx.tile<16x32xbf16>
+  x86.amx.tile_store %arg0[%0], %3, %stride : memref<?xbf16>, !x86.amx.tile<16x32xbf16>
+  x86.amx.tile_store %arg1[%0, %0], %1, %stride : memref<?x?xbf16>, !x86.amx.tile<16x32xbf16>
+  x86.amx.tile_store %arg2[%0, %0], %2 : memref<?x?xbf16, strided<[64, 1]>>, !x86.amx.tile<16x32xbf16>
+  return
+}
+
+// CHECK-LABEL: tzero
+// CHECK: x86.amx.tile_zero : !x86.amx.tile<16x16xbf16>
+// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, !x86.amx.tile<16x16xbf16>
+func.func @tzero(%arg0: memref<?x?xbf16>) {
+  %0 = arith.constant 0 : index
+  %1 = x86.amx.tile_zero : !x86.amx.tile<16x16xbf16>
+  x86.amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, !x86.amx.tile<16x16xbf16>
+  return
+}
+
+// CHECK-LABEL: tmulf
+// CHECK: %[[x:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
+// CHECK: %[[z:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+// CHECK: %[[m:.*]] = x86.amx.tile_mulf %[[x]], %[[x]], %[[z]] : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
+// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
+  %0 = arith.constant 0 : index
+  %1 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
+  %2 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+  %3 = x86.amx.tile_mulf %1, %1, %2 : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+  return
+}
+
+// CHECK-LABEL: tmuli
+// CHECK: %[[x:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
+// CHECK: %[[y:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
+// CHECK: %[[z:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into !x86.amx.tile<16x16xi32>
+// CHECK: %[[m:.*]] = x86.amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
+// Verify the parsing/printing of the sign-extension annotation.
+// CHECK: x86.amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
+// CHECK: x86.amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
+// CHECK: x86.amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
+func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
+  %0 = arith.constant 0 : index
+  %1 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
+  %2 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
+  %3 = x86.amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into !x86.amx.tile<16x16xi32>
+  %4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+  x86.amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
+  // Verify the various `zext` combinations.
+  %5 = x86.amx.tile_muli %1, %2 zext, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+  %6 = x86.amx.tile_muli %1 zext, %2, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+  %7 = x86.amx.tile_muli %1, %2, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+  return
+}
diff --git a/mlir/test/Dialect/X86/AMX/side-effects.mlir b/mlir/test/Dialect/X86/AMX/side-effects.mlir
new file mode 100644
index 0000000000000..fa475f34068e5
--- /dev/null
+++ b/mlir/test/Dialect/X86/AMX/side-effects.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -cse -convert-vector-to-llvm="enable-x86" | FileCheck %s
+
+// With memory side-effects included, CSE is expected not to fold multiple
+// "tileload" and "tilezero".
+// CHECK-LABEL: do_not_fold_tiles(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+func.func @do_not_fold_tiles(%arg0: memref<2x32x32xbf16>, %arg1: memref<2x16x32xbf16>) -> memref<16x32xf32> {
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c16 = arith.constant 16 : index
+  %alloca = memref.alloca() : memref<16x32xf32>
+  %0 = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+  %1 = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+  %2:2 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %0, %arg4 = %1) -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {
+    %3 = x86.amx.tile_load %arg0[%arg2, %c0, %c0] : memref<2x32x32xbf16> into !x86.amx.tile<16x32xbf16>
+    %4 = x86.amx.tile_load %arg0[%arg2, %c16, %c0] : memref<2x32x32xbf16> into !x86.amx.tile<16x32xbf16>
+    %5 = x86.amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !x86.amx.tile<16x32xbf16>
+    %6 = x86.amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !x86.amx.tile<16x32xbf16>
+    %7 = x86.amx.tile_mulf %3, %5, %arg3 : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
+    %8 = x86.amx.tile_mulf %4, %6, %arg4 : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
+    scf.yield %7, %8 : !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>
+  }
+  x86.amx.tile_store %alloca[%c0, %c0], %2#0 : memref<16x32xf32>, !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %alloca[%c0, %c16], %2#1 : memref<16x32xf32>, !x86.amx.tile<16x16xf32>
+  return %alloca : memref<16x32xf32>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/lit.local.cfg
similarity index 91%
rename from mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
rename to mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/lit.local.cfg
index 70b4b66f4378d..df9057d8933c8 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/lit.local.cfg
@@ -1,7 +1,7 @@
 import sys
 
 # AMX tests must be enabled via build flag.
-if not config.mlir_run_amx_tests:
+if not config.mlir_run_x86_amx_tests:
     config.unsupported = True
 
 # No JIT on win32.
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/mulf-full.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf-full.mlir
similarity index 95%
rename from mlir/test/Integration/Dialect/Vector/CPU/AMX/mulf-full.mlir
rename to mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf-full.mlir
index 8014bb7d2dcce..bd67a50cffb27 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/mulf-full.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf-full.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s -convert-vector-to-scf -lower-affine \
 // RUN: -one-shot-bufferize="bufferize-function-boundaries" \
-// RUN: -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" \
+// RUN: -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" \
 // RUN:  -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
 // RUN: mlir-translate -mlir-to-llvmir | \
 // RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" \
@@ -14,11 +14,11 @@ func.func @kernel(%arg0: memref<16x32xbf16>,
              %arg1: memref<16x32xbf16>,
              %arg2: memref<16x16xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16>  into vector<16x32xbf16>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16>  into vector<16x32xbf16>
-  %3 = amx.tile_zero : vector<16x16xf32>
-  %4 = amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<16x16xf32>, vector<16x16xf32>
+  %1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x32xbf16>  into vector<16x32xbf16>
+  %2 = x86.amx.tile_load %arg1[%0, %0] : memref<16x32xbf16>  into vector<16x32xbf16>
+  %3 = x86.amx.tile_zero : vector<16x16xf32>
+  %4 = x86.amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+  x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x16xf32>, vector<16x16xf32>
   return
 }
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/mulf.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf.mlir
similarity index 74%
rename from mlir/test/Integration/Dialect/Vector/CPU/AMX/mulf.mlir
rename to mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf.mlir
index 5f7250f4d4ccb..f1ff2bdf902f7 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/mulf.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
 // RUN: mlir-translate -mlir-to-llvmir | \
 // RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
 // RUN: FileCheck %s
@@ -10,11 +10,11 @@ func.func @kernel1(%arg0: memref<2x4xbf16>,
               %arg1: memref<2x4xbf16>,
               %arg2: memref<2x2xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<2x4xbf16>  into vector<2x4xbf16>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<2x4xbf16>  into vector<2x4xbf16>
-  %3 = amx.tile_zero : vector<2x2xf32>
-  %4 = amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32>
+  %1 = x86.amx.tile_load %arg0[%0, %0] : memref<2x4xbf16>  into vector<2x4xbf16>
+  %2 = x86.amx.tile_load %arg1[%0, %0] : memref<2x4xbf16>  into vector<2x4xbf16>
+  %3 = x86.amx.tile_zero : vector<2x2xf32>
+  %4 = x86.amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
+  x86.amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32>
   return
 }
 
@@ -23,11 +23,11 @@ func.func @kernel2(%arg0: memref<2x4xbf16>,
               %arg1: memref<2x4xbf16>,
               %arg2: memref<2x2xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<2x4xbf16>  into vector<2x4xbf16>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<2x4xbf16>  into vector<2x4xbf16>
-  %3 = amx.tile_load %arg2[%0, %0] : memref<2x2xf32> into vector<2x2xf32>
-  %4 = amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32>
+  %1 = x86.amx.tile_load %arg0[%0, %0] : memref<2x4xbf16>  into vector<2x4xbf16>
+  %2 = x86.amx.tile_load %arg1[%0, %0] : memref<2x4xbf16>  into vector<2x4xbf16>
+  %3 = x86.amx.tile_load %arg2[%0, %0] : memref<2x2xf32> into vector<2x2xf32>
+  %4 = x86.amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
+  x86.amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32>
   return
 }
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/muli-ext.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-ext.mlir
similarity index 83%
rename from mlir/test/Integration/Dialect/Vector/CPU/AMX/muli-ext.mlir
rename to mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-ext.mlir
index 5c0618c2e5e54..c572cff40f283 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/muli-ext.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-ext.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
 // RUN: mlir-translate -mlir-to-llvmir | \
 // RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
 // RUN: FileCheck %s
@@ -21,11 +21,11 @@ func.func @kernel1(%arg0: memref<16x16xi8>,
               %arg1: memref<4x16xi8>,
               %arg2: memref<16x4xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8>  into vector<16x16xi8>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8>  into vector<4x16xi8>
-  %3 = amx.tile_zero : vector<16x4xi32>
-  %4 = amx.tile_muli %1, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
+  %1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x16xi8>  into vector<16x16xi8>
+  %2 = x86.amx.tile_load %arg1[%0, %0] : memref<4x16xi8>  into vector<4x16xi8>
+  %3 = x86.amx.tile_zero : vector<16x4xi32>
+  %4 = x86.amx.tile_muli %1, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
+  x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
   return
 }
 
@@ -33,11 +33,11 @@ func.func @kernel2(%arg0: memref<16x16xi8>,
               %arg1: memref<4x16xi8>,
               %arg2: memref<16x4xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8>  into vector<16x16xi8>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8>  into vector<4x16xi8>
-  %3 = amx.tile_zero : vector<16x4xi32>
-  %4 = amx.tile_muli %1, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
+  %1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x16xi8>  into vector<16x16xi8>
+  %2 = x86.amx.tile_load %arg1[%0, %0] : memref<4x16xi8>  into vector<4x16xi8>
+  %3 = x86.amx.tile_zero : vector<16x4xi32>
+  %4 = x86.amx.tile_muli %1, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
+  x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
   return
 }
 
@@ -45,11 +45,11 @@ func.func @kernel3(%arg0: memref<16x16xi8>,
               %arg1: memref<4x16xi8>,
               %arg2: memref<16x4xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8>  into vector<16x16xi8>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8>  into vector<4x16xi8>
-  %3 = amx.tile_zero : vector<16x4xi32>
-  %4 = amx.tile_muli %1 zext, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
+  %1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x16xi8>  into vector<16x16xi8>
+  %2 = x86.amx.tile_load %arg1[%0, %0] : memref<4x16xi8>  into vector<4x16xi8>
+  %3 = x86.amx.tile_zero : vector<16x4xi32>
+  %4 = x86.amx.tile_muli %1 zext, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
+  x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
   return
 }
 
@@ -57,11 +57,11 @@ func.func @kernel4(%arg0: memref<16x16xi8>,
               %arg1: memref<4x16xi8>,
               %arg2: memref<16x4xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8>  into vector<16x16xi8>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8>  into vector<4x16xi8>
-  %3 = amx.tile_zero : vector<16x4xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
+  %1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x16xi8>  into vector<16x16xi8>
+  %2 = x86.amx.tile_load %arg1[%0, %0] : memref<4x16xi8>  into vector<4x16xi8>
+  %3 = x86.amx.tile_zero : vector<16x4xi32>
+  %4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
+  x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
   return
 }
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/muli-full.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-full.mlir
similarity index 95%
rename from mlir/test/Integration/Dialect/Vector/CPU/AMX/muli-full.mlir
rename to mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-full.mlir
index a0076db6660d7..7208389f4cbf3 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/muli-full.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-full.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -convert-vector-to-scf -lower-affine \
 // RUN: -one-shot-bufferize="bufferize-function-boundaries" \
 // RUN: -convert-scf-to-cf \
-// RUN:  -convert-vector-to-llvm="enable-amx" \
+// RUN:  -convert-vector-to-llvm="enable-x86" \
 // RUN:  -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
 // RUN: mlir-translate -mlir-to-llvmir | \
 // RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" \
@@ -15,11 +15,11 @@ func.func @kernel(%arg0: memref<16x64xi8>,
              %arg1: memref<16x64xi8>,
              %arg2: memref<16x16xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<16x64xi8>  into vector<16x64xi8>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<16x64xi8>  into vector<16x64xi8>
-  %3 = amx.tile_zero : vector<16x16xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<16x16xi32>, vector<16x16xi32>
+  %1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x64xi8>  into vector<16x64xi8>
+  %2 = x86.amx.tile_load %arg1[%0, %0] : memref<16x64xi8>  into vector<16x64xi8>
+  %3 = x86.amx.tile_zero : vector<16x16xi32>
+  %4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x16xi32>, vector<16x16xi32>
   return
 }
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/muli.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli.mlir
similarity index 74%
rename from mlir/test/Integration/Dialect/Vector/CPU/AMX/muli.mlir
rename to mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli.mlir
index 7b14df8dbd859..cd0b84c3a1886 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/muli.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
 // RUN: mlir-translate -mlir-to-llvmir | \
 // RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
 // RUN: FileCheck %s
@@ -10,11 +10,11 @@ func.func @kernel1(%arg0: memref<2x8xi8>,
               %arg1: memref<2x8xi8>,
               %arg2: memref<2x2xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8>  into vector<2x8xi8>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8>  into vector<2x8xi8>
-  %3 = amx.tile_zero : vector<2x2xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
+  %1 = x86.amx.tile_load %arg0[%0, %0] : memref<2x8xi8>  into vector<2x8xi8>
+  %2 = x86.amx.tile_load %arg1[%0, %0] : memref<2x8xi8>  into vector<2x8xi8>
+  %3 = x86.amx.tile_zero : vector<2x2xi32>
+  %4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
+  x86.amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
   return
 }
 
@@ -23,11 +23,11 @@ func.func @kernel2(%arg0: memref<2x8xi8>,
               %arg1: memref<2x8xi8>,
               %arg2: memref<2x2xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8>  into vector<2x8xi8>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8>  into vector<2x8xi8>
-  %3 = amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
+  %1 = x86.amx.tile_load %arg0[%0, %0] : memref<2x8xi8>  into vector<2x8xi8>
+  %2 = x86.amx.tile_load %arg1[%0, %0] : memref<2x8xi8>  into vector<2x8xi8>
+  %3 = x86.amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32>
+  %4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
+  x86.amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
   return
 }
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/tilezero-block.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero-block.mlir
similarity index 94%
rename from mlir/test/Integration/Dialect/Vector/CPU/AMX/tilezero-block.mlir
rename to mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero-block.mlir
index e35c555f0a85c..e6676c4411246 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/tilezero-block.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero-block.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
 // RUN: mlir-translate -mlir-to-llvmir | \
 // RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
 // RUN: FileCheck %s
@@ -25,8 +25,8 @@ func.func @kernel(%arg0: memref<4x32xf32>) {
   %c32 = arith.constant 32 : index
   scf.for %i = %c0 to %c4 step %c2 {
     scf.for %j = %c0 to %c32 step %c16 {
-      %0 = amx.tile_zero : vector<2x16xf32>
-      amx.tile_store %arg0[%i, %j], %0 : memref<4x32xf32>, vector<2x16xf32>
+      %0 = x86.amx.tile_zero : vector<2x16xf32>
+      x86.amx.tile_store %arg0[%i, %j], %0 : memref<4x32xf32>, vector<2x16xf32>
       func.call @print(%arg0) : (memref<4x32xf32>) -> ()
     }
   }
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/tilezero.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero.mlir
similarity index 96%
rename from mlir/test/Integration/Dialect/Vector/CPU/AMX/tilezero.mlir
rename to mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero.mlir
index 37db0333e3f5d..09ae8f1a95143 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/tilezero.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
 // RUN: mlir-translate -mlir-to-llvmir | \
 // RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
 // RUN: FileCheck %s
@@ -6,8 +6,8 @@
 // Note: To run this test, your CPU must support AMX.
 
 func.func @tilezero(%arg0: memref<?x?xi32>, %i: index, %j: index) {
-  %1 = amx.tile_zero : vector<16x16xi32>
-  amx.tile_store %arg0[%i, %j], %1 : memref<?x?xi32>, vector<16x16xi32>
+  %1 = x86.amx.tile_zero : vector<16x16xi32>
+  x86.amx.tile_store %arg0[%i, %j], %1 : memref<?x?xi32>, vector<16x16xi32>
   return
 }
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86/dot.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/dot.mlir
index 5570da8fe04b4..c375350de50d4 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/X86/dot.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/dot.mlir
@@ -5,14 +5,15 @@
 
 func.func @entry() -> i32 {
   %i0 = arith.constant 0 : i32
-  %i4 = arith.constant 4 : i32
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
 
   %a = arith.constant dense<[1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0]> : vector<8xf32>
   %b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xf32>
   %r = x86.avx.intr.dot %a, %b : vector<8xf32>
 
-  %1 = vector.extract %r[%i0] : f32 from vector<8xf32>
-  %2 = vector.extract %r[%i4] : f32 from vector<8xf32>
+  %1 = vector.extract %r[%c0] : f32 from vector<8xf32>
+  %2 = vector.extract %r[%c4] : f32 from vector<8xf32>
   %d = arith.addf %1, %2 : f32
 
   // CHECK: ( 110, 110, 110, 110, 382, 382, 382, 382 )
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86/sparse-dot-product.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/sparse-dot-product.mlir
index 4f3f70a45a507..7b0f505a47780 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/X86/sparse-dot-product.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/sparse-dot-product.mlir
@@ -180,8 +180,7 @@ func.func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
     -> f64 {
   // Helper constants for loops.
   %c0 = arith.constant 0 : index
-  %i0 = arith.constant 0 : i32
-  %i7 = arith.constant 7 : i32
+  %c7 = arith.constant 7 : index
   %c8 = arith.constant 8 : index
 
   %data_zero = arith.constant 0.0 : f64
@@ -196,13 +195,13 @@ func.func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
       iter_args(%sum0 = %data_zero, %b_start0 = %c0) -> (f64, index) {
     %v_A = vector.transfer_read %m_A[%a], %index_padding
         : memref<?xi64>, vector<8xi64>
-    %segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64>
+    %segA_min = vector.extract %v_A[%c0] : i64 from vector<8xi64>
 
     %r1, %next_b_start0 = scf.for %b = %b_start0 to %N step %c8
         iter_args(%sum1 = %sum0, %b_start1 = %b_start0) -> (f64, index) {
       %v_C = vector.transfer_read %m_C[%b], %index_padding
           : memref<?xi64>, vector<8xi64>
-      %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64>
+      %segB_max = vector.extract %v_C[%c7] : i64 from vector<8xi64>
       %seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64
 
       %r2, %next_b_start1 = scf.if %seg1_done -> (f64, index) {
@@ -251,8 +250,7 @@ func.func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
     -> f64 {
   // Helper constants for loops.
   %c0 = arith.constant 0 : index
-  %i0 = arith.constant 0 : i32
-  %i7 = arith.constant 7 : i32
+  %c7 = arith.constant 7 : index
   %c8 = arith.constant 8 : index
 
   %data_zero = arith.constant 0.0 : f64
@@ -273,10 +271,10 @@ func.func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
     %v_C = vector.transfer_read %m_C[%b1], %index_padding
         : memref<?xi64>, vector<8xi64>
 
-    %segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64>
-    %segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64>
-    %segB_min = vector.extract %v_C[%i0] : i64 from vector<8xi64>
-    %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64>
+    %segA_min = vector.extract %v_A[%c0] : i64 from vector<8xi64>
+    %segA_max = vector.extract %v_A[%c7] : i64 from vector<8xi64>
+    %segB_min = vector.extract %v_C[%c0] : i64 from vector<8xi64>
+    %segB_max = vector.extract %v_C[%c7] : i64 from vector<8xi64>
 
     %seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64
     %r2, %a2, %b2 = scf.if %seg1_done -> (f64, index, index) {
@@ -340,7 +338,7 @@ func.func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64
     -> f64 {
   // Helper constants for loops.
   %c0 = arith.constant 0 : index
-  %i7 = arith.constant 7 : i32
+  %c7 = arith.constant 7 : index
   %c8 = arith.constant 8 : index
 
   %data_zero = arith.constant 0.0 : f64
@@ -370,8 +368,8 @@ func.func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64
             -> f64
     %r2 = arith.addf %r1, %subresult : f64
 
-    %segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64>
-    %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64>
+    %segA_max = vector.extract %v_A[%c7] : i64 from vector<8xi64>
+    %segB_max = vector.extract %v_C[%c7] : i64 from vector<8xi64>
 
     %cond_a = arith.cmpi "sle", %segA_max, %segB_max : i64
     %cond_a_i64 = arith.extui %cond_a : i1 to i64
diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir
index 160a9ced46e21..4a4be24c2e3ab 100644
--- a/mlir/test/Target/LLVMIR/amx.mlir
+++ b/mlir/test/Target/LLVMIR/amx.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm="enable-amx" --convert-to-llvm -reconcile-unrealized-casts \
+// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86" --convert-to-llvm -reconcile-unrealized-casts \
 // RUN: | mlir-translate --mlir-to-llvmir \
 // RUN: | FileCheck %s
 
@@ -7,8 +7,8 @@ func.func @amx_tile_zero(%out: memref<?x?xf32>, %idx: index)
 {
   // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
   // CHECK: call void @llvm.x86.tilestored64.internal
-  %zero = amx.tile_zero : !amx.tile<16x16xf32>
-  amx.tile_store %out[%idx, %idx], %zero : memref<?x?xf32>, !amx.tile<16x16xf32>
+  %zero = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %out[%idx, %idx], %zero : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
   return
 }
 
@@ -18,8 +18,8 @@ func.func @amx_tile_load_store(%base: memref<?x?xi8>, %out: memref<?x?xi8>,
 {
   // CHECK: call x86_amx @llvm.x86.tileloadd64.internal
   // CHECK: call void @llvm.x86.tilestored64.internal
-  %val = amx.tile_load %base[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
-  amx.tile_store %out[%idx, %idx], %val : memref<?x?xi8>, !amx.tile<16x64xi8>
+  %val = x86.amx.tile_load %base[%idx, %idx] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
+  x86.amx.tile_store %out[%idx, %idx], %val : memref<?x?xi8>, !x86.amx.tile<16x64xi8>
   return
 }
 
@@ -29,10 +29,10 @@ func.func @amx_tile_load_store_strided(%base: memref<?xi8>, %out: memref<?xi8>,
 {
   // CHECK: call x86_amx @llvm.x86.tileloadd64.internal
   // CHECK: call void @llvm.x86.tilestored64.internal
-  %val = amx.tile_load %base[%idx], %stride
-    : memref<?xi8> into !amx.tile<16x64xi8>
-  amx.tile_store %out[%idx], %val, %stride
-    : memref<?xi8>, !amx.tile<16x64xi8>
+  %val = x86.amx.tile_load %base[%idx], %stride
+    : memref<?xi8> into !x86.amx.tile<16x64xi8>
+  x86.amx.tile_store %out[%idx], %val, %stride
+    : memref<?xi8>, !x86.amx.tile<16x64xi8>
   return
 }
 
@@ -42,15 +42,15 @@ func.func @amx_tile_mulf_bf16(
     %out: memref<?x?xf32>)
 {
   // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
-  %acc = amx.tile_zero : !amx.tile<16x16xf32>
+  %acc = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
   // CHECK-COUNT-2: call x86_amx @llvm.x86.tileloadd64.internal
-  %tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
-  %tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+  %tA = x86.amx.tile_load %matA[%idx, %idx] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
+  %tB = x86.amx.tile_load %matB[%idx, %idx] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
   // CHECK: call x86_amx @llvm.x86.tdpbf16ps.internal
-  %tRes = amx.tile_mulf %tA, %tB, %acc
-    : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+  %tRes = x86.amx.tile_mulf %tA, %tB, %acc
+    : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
   // CHECK: call void @llvm.x86.tilestored64.internal
-  amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !amx.tile<16x16xf32>
+  x86.amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
   return
 }
 
@@ -60,15 +60,15 @@ func.func @amx_tile_mulf_f16(
     %out: memref<?x?xf32>)
 {
   // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
-  %acc = amx.tile_zero : !amx.tile<16x16xf32>
+  %acc = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
   // CHECK-COUNT-2: call x86_amx @llvm.x86.tileloadd64.internal
-  %tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xf16> into !amx.tile<16x32xf16>
-  %tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xf16> into !amx.tile<16x32xf16>
+  %tA = x86.amx.tile_load %matA[%idx, %idx] : memref<?x?xf16> into !x86.amx.tile<16x32xf16>
+  %tB = x86.amx.tile_load %matB[%idx, %idx] : memref<?x?xf16> into !x86.amx.tile<16x32xf16>
   // CHECK: call x86_amx @llvm.x86.tdpfp16ps.internal
-  %tRes = amx.tile_mulf %tA, %tB, %acc
-    : !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32>
+  %tRes = x86.amx.tile_mulf %tA, %tB, %acc
+    : !x86.amx.tile<16x32xf16>, !x86.amx.tile<16x32xf16>, !x86.amx.tile<16x16xf32>
     // CHECK: call void @llvm.x86.tilestored64.internal
-  amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !amx.tile<16x16xf32>
+  x86.amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
   return
 }
 
@@ -79,26 +79,26 @@ func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>,
   %c0 = arith.constant 0 : index
   %c16 = arith.constant 16 : index
   // CHECK-COUNT-3: call x86_amx @llvm.x86.tileloadd64.internal
-  %tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
-  %tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
-  %acc = amx.tile_load %matC[%idx, %idx] : memref<?x?xi32> into !amx.tile<16x16xi32>
+  %tA = x86.amx.tile_load %matA[%idx, %idx] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
+  %tB = x86.amx.tile_load %matB[%idx, %idx] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
+  %acc = x86.amx.tile_load %matC[%idx, %idx] : memref<?x?xi32> into !x86.amx.tile<16x16xi32>
   // CHECK: call x86_amx @llvm.x86.tdpbuud.internal
   // CHECK: call x86_amx @llvm.x86.tdpbssd.internal
   // CHECK: call x86_amx @llvm.x86.tdpbusd.internal
   // CHECK: call x86_amx @llvm.x86.tdpbsud.internal
-  %res = amx.tile_muli %tA zext, %tB zext, %acc
-    : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
-  %res1 = amx.tile_muli %tA, %tB, %acc
-    : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
-  %res2 = amx.tile_muli %tA zext, %tB, %acc
-    : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
-  %res3 = amx.tile_muli %tA, %tB zext, %acc
-    : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  %res = x86.amx.tile_muli %tA zext, %tB zext, %acc
+    : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+  %res1 = x86.amx.tile_muli %tA, %tB, %acc
+    : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+  %res2 = x86.amx.tile_muli %tA zext, %tB, %acc
+    : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+  %res3 = x86.amx.tile_muli %tA, %tB zext, %acc
+    : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
   // CHECK-COUNT-4: call void @llvm.x86.tilestored64.internal
-  amx.tile_store %out[%c0, %c0], %res : memref<?x?xi8>, !amx.tile<16x16xi32>
-  amx.tile_store %out[%c0, %c16], %res1 : memref<?x?xi8>, !amx.tile<16x16xi32>
-  amx.tile_store %out[%c16, %c0], %res2 : memref<?x?xi8>, !amx.tile<16x16xi32>
-  amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi8>, !amx.tile<16x16xi32>
+  x86.amx.tile_store %out[%c0, %c0], %res : memref<?x?xi8>, !x86.amx.tile<16x16xi32>
+  x86.amx.tile_store %out[%c0, %c16], %res1 : memref<?x?xi8>, !x86.amx.tile<16x16xi32>
+  x86.amx.tile_store %out[%c16, %c0], %res2 : memref<?x?xi8>, !x86.amx.tile<16x16xi32>
+  x86.amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi8>, !x86.amx.tile<16x16xi32>
   return
 }
 
@@ -108,16 +108,16 @@ func.func @amx_tile_type_through_cf(%src: memref<?x?xi8>, %out: memref<?x?xi8>,
   cf.cond_br %cond, ^bb1, ^bb2
 ^bb1:  // pred: ^bb0
   // CHECK: call x86_amx @llvm.x86.tileloadd64.internal
-  %0 = amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
-  cf.br ^bb3(%0 : !amx.tile<16x64xi8>)
+  %0 = x86.amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
+  cf.br ^bb3(%0 : !x86.amx.tile<16x64xi8>)
 ^bb2:  // pred: ^bb0
   // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
-  %1 = amx.tile_zero : !amx.tile<16x64xi8>
-  cf.br ^bb3(%1 : !amx.tile<16x64xi8>)
-^bb3(%2: !amx.tile<16x64xi8>):  // 2 preds: ^bb1, ^bb2
+  %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xi8>
+  cf.br ^bb3(%1 : !x86.amx.tile<16x64xi8>)
+^bb3(%2: !x86.amx.tile<16x64xi8>):  // 2 preds: ^bb1, ^bb2
   cf.br ^bb4
 ^bb4:  // pred: ^bb3
   // CHECK: call void @llvm.x86.tilestored64.internal
-  amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !amx.tile<16x64xi8>
+  x86.amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !x86.amx.tile<16x64xi8>
   return
 }
diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in
index b14a11163c107..ba509b246866d 100644
--- a/mlir/test/lit.site.cfg.py.in
+++ b/mlir/test/lit.site.cfg.py.in
@@ -45,7 +45,7 @@ config.enable_spirv_cpu_runner = @MLIR_ENABLE_SPIRV_CPU_RUNNER@
 config.enable_vulkan_runner = @MLIR_ENABLE_VULKAN_RUNNER@
 config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@
 config.intel_sde_executable = "@INTEL_SDE_EXECUTABLE@"
-config.mlir_run_amx_tests = @MLIR_RUN_AMX_TESTS@
+config.mlir_run_x86_amx_tests = @MLIR_RUN_X86_AMX_TESTS@
 config.mlir_run_arm_sve_tests = @MLIR_RUN_ARM_SVE_TESTS@
 # This is a workaround for the fact that LIT's:
 #   %if <cond>
diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
index 9724eb42119a7..fb344880598c3 100644
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -3,7 +3,6 @@
 // CHECK-SAME: acc
 // CHECK-SAME: affine
 // CHECK-SAME: amdgpu
-// CHECK-SAME: amx
 // CHECK-SAME: arith
 // CHECK-SAME: arm_neon
 // CHECK-SAME: arm_sme



More information about the Mlir-commits mailing list