[Mlir-commits] [mlir] [mlir][amx] Simplify intrinsic generation (PR #140559)

Adam Siemieniuk llvmlistbot at llvm.org
Mon May 19 08:26:44 PDT 2025


https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/140559

Replaces separate amx named intrinsic operations with direct calls to LLVM intrinsic functions.
The existing amx tests are updated and expanded.

The separate conversion step translating amx intrinsics into LLVM IR is eliminated. Instead, this step is now performed by the existing llvm dialect infrastructure.

Related RFC: https://discourse.llvm.org/t/rfc-simplify-x86-intrinsic-generation/85581/7

>From 3e827e0934b98c151a4538c9b7d5667bf4d8991b Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 5 May 2025 14:03:46 +0200
Subject: [PATCH] [mlir][amx] Simplify intrinsic generation

Replaces separate amx named intrinsic operations with direct calls to
LLVM intrinsic functions.
The existing amx tests are updated and expanded.

The separate conversion step translating amx intrinsics into LLVM IR
is eliminated. Instead, this step is now performed by the existing
llvm dialect infrastructure.

Related RFC: https://discourse.llvm.org/t/rfc-simplify-x86-intrinsic-generation/85581
---
 mlir/include/mlir/Dialect/AMX/AMX.td          | 157 ++++++------
 mlir/include/mlir/Dialect/AMX/AMXDialect.h    |   4 +
 .../include/mlir/Dialect/AMX/AMXInterfaces.td |  31 +++
 mlir/include/mlir/Dialect/AMX/CMakeLists.txt  |   5 +-
 mlir/include/mlir/Dialect/AMX/Transforms.h    |   3 -
 mlir/include/mlir/InitAllExtensions.h         |   2 -
 .../Dialect/AMX/AMXToLLVMIRTranslation.h      |  31 ---
 mlir/include/mlir/Target/LLVMIR/Dialect/All.h |   2 -
 mlir/lib/Dialect/AMX/IR/AMXDialect.cpp        | 190 ++++++++++++++-
 mlir/lib/Dialect/AMX/IR/CMakeLists.txt        |   1 +
 .../lib/Dialect/AMX/Transforms/CMakeLists.txt |   3 -
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  | 224 ++----------------
 mlir/lib/Target/LLVMIR/CMakeLists.txt         |   1 -
 .../Dialect/AMX/AMXToLLVMIRTranslation.cpp    |  56 -----
 .../Target/LLVMIR/Dialect/AMX/CMakeLists.txt  |  16 --
 mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt |   1 -
 mlir/test/Dialect/AMX/legalize-for-llvm.mlir  |  54 ++---
 mlir/test/Target/LLVMIR/amx.mlir              |  97 +++++++-
 18 files changed, 432 insertions(+), 446 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/AMX/AMXInterfaces.td
 delete mode 100644 mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h
 delete mode 100644 mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp
 delete mode 100644 mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt

diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 8a51df1ea183f..a484f2ca009a2 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -25,10 +25,11 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef AMX
-#define AMX
+#ifndef AMX_OPS
+#define AMX_OPS
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Dialect/AMX/AMXInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinTypes.td"
@@ -47,8 +48,6 @@ def AMX_Dialect : Dialect {
 
     This `AMX` dialect provides a bridge between MLIR concepts such as
     vectors and memrefs and the lower level LLVM IR support of AMX.
-    The dialect is split into user-facing AMX ops (AMX_Op) and
-    backend-facing intrinsic ops (AMX_IntrOp).
 
     Note that since configuration changes (implicit at dialect level) are
     costly, it is highly recommended to use the AMX dialect on same-shaped
@@ -135,21 +134,17 @@ def AMXTileI8 : AMXTileOf<[I8]>;
 class AMX_Op<string mnemonic, list<Trait> traits = []> :
   Op<AMX_Dialect, mnemonic, traits> {}
 
-// The "internal" intrinsics are meant for compiler usage.
-class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
-  LLVM_IntrOpBase<AMX_Dialect, mnemonic,
-                  "x86_" # !subst(".", "_", mnemonic) # "_internal",
-                  [], [], traits, numResults>;
-
 //===----------------------------------------------------------------------===//
-// AMX Op definitions (user facing).
+// AMX Op definitions
 //===----------------------------------------------------------------------===//
 
 //
 // Tile reset.
 //
 
-def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
+def TileZeroOp : AMX_Op<"tile_zero", [Pure,
+    AMXIntrinsicOpInterface
+  ]> {
   let summary = "tile zero operation";
   let description = [{
     Zeroes the destination tile, with the shape defined by the 2-dim
@@ -167,6 +162,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
     TileType getTileType() {
       return ::llvm::cast<TileType>(getRes().getType());
     }
+
+    std::string getIntrinsicName() {
+      return "llvm.x86.tilezero.internal";
+    }
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
   let assemblyFormat = "attr-dict `:` qualified(type($res))";
   let hasVerifier = 1;
@@ -176,7 +179,9 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
 // Tile memory operations.
 //
 
-def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
+def TileLoadOp : AMX_Op<"tile_load", [Pure,
+    AMXIntrinsicOpInterface
+  ]> {
   let summary = "tile load operation";
   let description = [{
     Loads a tile from memory defined by a base and indices, with the
@@ -200,13 +205,23 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
     TileType getTileType() {
       return ::llvm::cast<TileType>(getRes().getType());
     }
+
+    std::string getIntrinsicName() {
+      return "llvm.x86.tileloadd64.internal";
+    }
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
   let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
                        "type($base) `into` qualified(type($res))";
   let hasVerifier = 1;
 }
 
-def TileStoreOp : AMX_Op<"tile_store"> {
+def TileStoreOp : AMX_Op<"tile_store", [
+    AMXIntrinsicOpInterface
+  ]> {
   let summary = "tile store operation";
   let description = [{
     Stores a tile to memory defined by a base and indices, with the
@@ -230,6 +245,14 @@ def TileStoreOp : AMX_Op<"tile_store"> {
     TileType getTileType() {
       return ::llvm::cast<TileType>(getVal().getType());
     }
+
+    std::string getIntrinsicName() {
+      return "llvm.x86.tilestored64.internal";
+    }
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
   let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
                        "type($base) `,` qualified(type($val))";
@@ -240,8 +263,10 @@ def TileStoreOp : AMX_Op<"tile_store"> {
 // Tile arithmetic operations.
 //
 
-def TileMulFOp : AMX_Op<"tile_mulf", [
-    Pure, AllTypesMatch<["acc", "res"]>]> {
+def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
+    AMXIntrinsicOpInterface,
+    AllTypesMatch<["acc", "res"]>
+  ]> {
   let summary = "tile multiplication operation (floating-point)";
   let description = [{
     Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
@@ -270,6 +295,19 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
     TileType getTileType() {
       return ::llvm::cast<TileType>(getRes().getType());
     }
+
+    std::string getIntrinsicName() {
+      std::string intr = "llvm.x86.tdp";
+      auto elementType =
+        getLhsTileType().getElementType();
+      intr += elementType.isF16() ? "fp16" : "bf16";
+      intr += "ps.internal";
+      return intr;
+    }
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
   let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
                        "qualified(type($lhs)) `,` qualified(type($rhs))"
@@ -277,8 +315,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
   let hasVerifier = 1;
 }
 
-def TileMulIOp : AMX_Op<"tile_muli", [
-    Pure, AllTypesMatch<["acc", "res"]>]> {
+def TileMulIOp : AMX_Op<"tile_muli", [Pure,
+    AMXIntrinsicOpInterface,
+    AllTypesMatch<["acc", "res"]>
+  ]> {
   let summary = "tile multiplication operation (integer)";
   let description = [{
     Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
@@ -313,77 +353,22 @@ def TileMulIOp : AMX_Op<"tile_muli", [
     TileType getTileType() {
       return ::llvm::cast<TileType>(getRes().getType());
     }
+
+    std::string getIntrinsicName() {
+      std::string intr = "llvm.x86.tdpb";
+      intr += getIsZextLhs() ? "u" : "s";
+      intr += getIsZextRhs() ? "u" : "s";
+      intr += "d.internal";
+      return intr;
+    }
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
   let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
                        "qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
   let hasVerifier = 1;
 }
 
-//===----------------------------------------------------------------------===//
-// AMX IntrOp definitions (LLVM compiler facing).
-//===----------------------------------------------------------------------===//
-
-//
-// Tile reset. Parameters define the tile size.
-//
-
-def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>,
-  Arguments<(ins AnyInteger, AnyInteger)>;
-
-//
-// Tile memory operations. Parameters define the tile size,
-// base address, and stride between consecutive rows for the
-// memory operation.
-//
-
-def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger, LLVM_AnyPointer, AnyInteger)>;
-
-def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger, LLVM_AnyPointer, AnyInteger, LLVM_Type)>;
-
-//
-// Tile multiplication operations (series of dot products). Parameters
-// define the tile sizes and source and destination tiles for the
-// operation. Note that the prefix "tdp" stands for tile dot product.
-//
-
-// Dot product of bf16 tiles into f32 tile.
-def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger,
-		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-// Dot product of f16 tiles into f32 tile.
-def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger,
-		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-// Dot product of i8 tiles into i32 tile (with sign/sign extension).
-def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger,
-		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-// Dot product of i8 tiles into i32 tile (with sign/zero extension).
-def LLVM_x86_amx_tdpbsud : AMX_IntrOp<"tdpbsud", 1>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger,
-		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-// Dot product of i8 tiles into i32 tile (with zero/sign extension).
-def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger,
-		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-// Dot product of i8 tiles into i32 tile (with zero/zero extension).
-def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
-  Arguments<(ins AnyInteger,
-                 AnyInteger,
-		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
-
-#endif // AMX
+#endif // AMX_OPS
diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
index c0553ad8733fd..c79f31d4c994a 100644
--- a/mlir/include/mlir/Dialect/AMX/AMXDialect.h
+++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
@@ -14,11 +14,15 @@
 #define MLIR_DIALECT_AMX_AMXDIALECT_H_
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+/// Include the generated interface declarations.
+#include "mlir/Dialect/AMX/AMXInterfaces.h.inc"
+
 #include "mlir/Dialect/AMX/AMXDialect.h.inc"
 
 #define GET_TYPEDEF_CLASSES
diff --git a/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td b/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td
new file mode 100644
index 0000000000000..012d1ba7368f7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td
@@ -0,0 +1,31 @@
+//===- AMXInterfaces.td - AMX interfaces -------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines interfaces for the AMX dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef AMX_INTERFACES
+#define AMX_INTERFACES
+
+include "mlir/IR/Interfaces.td"
+include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// AMX Intrinsic Interface
+//===----------------------------------------------------------------------===//
+
+def AMXIntrinsicOpInterface
+    : OpInterface<"AMXIntrinsicOp", [OneToOneIntrinsicOpInterface]> {
+  let description = [{
+    A wrapper interface for operations representing AMX LLVM intrinsics.
+  }];
+  let cppNamespace = "::mlir::amx";
+}
+
+#endif // AMX_INTERFACES
diff --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
index f3f1aff5a6360..f875c78d240cc 100644
--- a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_dialect(AMX amx)
 add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)
 
-set(LLVM_TARGET_DEFINITIONS AMX.td)
-mlir_tablegen(AMXConversions.inc -gen-llvmir-conversions)
-add_public_tablegen_target(MLIRAMXConversionsIncGen)
+add_mlir_interface(AMXInterfaces)
+add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen)
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 7391ec2ff6b14..4a751d99ceeee 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -25,9 +25,6 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
 /// intrinsics.
 void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
 
-/// Register LLVM conversion interface for AMX dialect.
-void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
-
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AMX_TRANSFORMS_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 37e4904cb48ed..1e3f7c649a8bd 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -32,7 +32,6 @@
 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
-#include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
 #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
@@ -84,7 +83,6 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   registerConvertOpenMPToLLVMInterface(registry);
   registerConvertSCFToEmitCInterface(registry);
   ub::registerConvertUBToLLVMInterface(registry);
-  registerConvertAMXToLLVMInterface(registry);
   gpu::registerConvertGpuToLLVMInterface(registry);
   NVVM::registerConvertGpuToNVVMInterface(registry);
   vector::registerConvertVectorToLLVMInterface(registry);
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h
deleted file mode 100644
index 4525ec3212196..0000000000000
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h
+++ /dev/null
@@ -1,31 +0,0 @@
-//===- AMXToLLVMIRTranslation.h - AMX to LLVM IR ----------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This provides registration calls for AMX dialect to LLVM IR translation.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H
-#define MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H
-
-namespace mlir {
-
-class DialectRegistry;
-class MLIRContext;
-
-/// Register the AMX dialect and the translation from it to the LLVM IR
-/// in the given registry;
-void registerAMXDialectTranslation(DialectRegistry &registry);
-
-/// Register the AMX dialect and the translation from it in the registry
-/// associated with the given context.
-void registerAMXDialectTranslation(MLIRContext &context);
-
-} // namespace mlir
-
-#endif // MLIR_TARGET_LLVMIR_DIALECT_AMX_AMXTOLLVMIRTRANSLATION_H
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
index e043ff2f6825c..60615cf601655 100644
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
@@ -14,7 +14,6 @@
 #ifndef MLIR_TARGET_LLVMIR_DIALECT_ALL_H
 #define MLIR_TARGET_LLVMIR_DIALECT_ALL_H
 
-#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h"
@@ -37,7 +36,6 @@ class DialectRegistry;
 /// corresponding translation interfaces.
 static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
   registerArmNeonDialectTranslation(registry);
-  registerAMXDialectTranslation(registry);
   registerArmSMEDialectTranslation(registry);
   registerArmSVEDialectTranslation(registry);
   registerBuiltinDialectTranslation(registry);
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 829f48e223383..69f524e1c311d 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -11,6 +11,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -21,6 +23,8 @@
 
 using namespace mlir;
 
+#include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc"
+
 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
 
 void amx::AMXDialect::initialize() {
@@ -60,24 +64,168 @@ static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
   return success();
 }
 
+/// Returns the buffer pointer held by a memref descriptor.
+/// If indices are given, the pointer is advanced by the linearized index
+/// computed from the indices and the memref strides.
+static Value getBufferPtr(Location loc, MemRefType type, Value buffer,
+                          ValueRange indices,
+                          const LLVMTypeConverter &typeConverter,
+                          RewriterBase &rewriter) {
+  auto [strides, offset] = type.getStridesAndOffset();
+
+  MemRefDescriptor memRefDescriptor(buffer);
+  Value base = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
+
+  int numIndices = indices.size();
+  if (numIndices == 0)
+    return base;
+
+  assert(type.getRank() == numIndices &&
+         "expects number of indices equal to memref rank");
+  Value index;
+  Type indexType = typeConverter.getIndexType();
+  for (int i = 0; i < numIndices; ++i) {
+    Value increment = indices[i];
+    if (strides[i] != 1) { // Skip if stride is 1.
+      Value stride =
+          ShapedType::isDynamic(strides[i])
+              ? memRefDescriptor.stride(rewriter, loc, i)
+              : rewriter.create<LLVM::ConstantOp>(
+                    loc, indexType, rewriter.getIndexAttr(strides[i]));
+      increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
+    }
+    index =
+        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+  }
+
+  Type elementPtrType = memRefDescriptor.getElementPtrType();
+  return rewriter.create<LLVM::GEPOp>(
+      loc, elementPtrType, typeConverter.convertType(type.getElementType()),
+      base, index);
+}
+
+/// Maps the 2-dim tile shape to the two 16-bit tile sizes. The first
+/// dimension directly translates into the number of rows of the tiles.
+/// The second dimension needs to be scaled by the element size in bytes.
+static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
+                                       RewriterBase &rewriter) {
+  Type llvmInt16Type = rewriter.getIntegerType(16);
+  unsigned width = tType.getElementType().getIntOrFloatBitWidth();
+  assert(llvm::isPowerOf2_64(width) && width >= 8);
+  unsigned bytes = width >> 3;
+  auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
+  auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
+  return SmallVector<Value>{
+      rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
+      rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)};
+}
+
+/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
+/// shape may "envelop" the actual tile shape, and may be dynamically sized.
+/// Expects rank >= 2 and unit innermost stride (checked by op verifiers).
+static Value getStride(Location loc, MemRefType mType, Value base,
+                       RewriterBase &rewriter) {
+  assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
+  int64_t preLast = mType.getRank() - 2;
+  Type llvmInt64Type = rewriter.getIntegerType(64);
+  unsigned width = mType.getElementType().getIntOrFloatBitWidth();
+  assert(llvm::isPowerOf2_64(width) && width >= 8);
+  unsigned bytes = width >> 3;
+  auto [strides, offset] = mType.getStridesAndOffset();
+  if (strides[preLast] == ShapedType::kDynamic) {
+    // Dynamic stride needs code to compute the stride at runtime.
+    MemRefDescriptor memrefDescriptor(base);
+    auto attr = rewriter.getI64IntegerAttr(bytes);
+    Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
+    return rewriter
+        .create<LLVM::MulOp>(loc, llvmInt64Type, scale,
+                             memrefDescriptor.stride(rewriter, loc, preLast))
+        .getResult();
+  }
+  // Use direct constant for static stride.
+  auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
+  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
+      .getResult();
+}
+
 LogicalResult amx::TileZeroOp::verify() {
   return verifyTileSize(*this, getTileType());
 }
 
+SmallVector<Value>
+amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
+                                      const LLVMTypeConverter &typeConverter,
+                                      RewriterBase &rewriter) {
+  return getTileSizes(getLoc(), getTileType(), rewriter);
+}
+
 LogicalResult amx::TileLoadOp::verify() {
-  unsigned rank = getMemRefType().getRank();
+  MemRefType memrefTy = getMemRefType();
+  unsigned rank = memrefTy.getRank();
+  if (rank < 2)
+    return emitOpError("requires at least 2D memref");
   if (getIndices().size() != rank)
     return emitOpError("requires ") << rank << " indices";
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
+      strides.back() != 1)
+    return emitOpError("requires memref with unit innermost stride");
   return verifyTileSize(*this, getTileType());
 }
 
+SmallVector<Value>
+amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
+                                      const LLVMTypeConverter &typeConverter,
+                                      RewriterBase &rewriter) {
+  auto loc = getLoc();
+  Adaptor adaptor(operands, *this);
+
+  SmallVector<Value> intrinsicOperands;
+  intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
+  intrinsicOperands.push_back(
+      getBufferPtr(loc, getMemRefType(), adaptor.getBase(),
+                   adaptor.getIndices(), typeConverter, rewriter));
+  intrinsicOperands.push_back(
+      getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
+
+  return intrinsicOperands;
+}
+
 LogicalResult amx::TileStoreOp::verify() {
-  unsigned rank = getMemRefType().getRank();
+  MemRefType memrefTy = getMemRefType();
+  unsigned rank = memrefTy.getRank();
+  if (rank < 2)
+    return emitOpError("requires at least 2D memref");
   if (getIndices().size() != rank)
     return emitOpError("requires ") << rank << " indices";
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
+      strides.back() != 1)
+    return emitOpError("requires memref with unit innermost stride");
   return verifyTileSize(*this, getTileType());
 }
 
+SmallVector<Value>
+amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
+                                       const LLVMTypeConverter &typeConverter,
+                                       RewriterBase &rewriter) {
+  auto loc = getLoc();
+  Adaptor adaptor(operands, *this);
+
+  SmallVector<Value> intrinsicOperands;
+  intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
+  intrinsicOperands.push_back(
+      getBufferPtr(loc, getMemRefType(), adaptor.getBase(),
+                   adaptor.getIndices(), typeConverter, rewriter));
+  intrinsicOperands.push_back(
+      getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
+  intrinsicOperands.push_back(adaptor.getVal());
+
+  return intrinsicOperands;
+}
+
 LogicalResult amx::TileMulFOp::verify() {
   amx::TileType aType = getLhsTileType();
   amx::TileType bType = getRhsTileType();
@@ -95,6 +243,25 @@ LogicalResult amx::TileMulFOp::verify() {
   return success();
 }
 
+SmallVector<Value>
+amx::TileMulFOp::getIntrinsicOperands(ArrayRef<Value> operands,
+                                      const LLVMTypeConverter &typeConverter,
+                                      RewriterBase &rewriter) {
+  auto loc = getLoc();
+  Adaptor adaptor(operands, *this);
+
+  amx::TileType aType = getLhsTileType();
+  amx::TileType bType = getRhsTileType();
+  SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
+  SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
+
+  SmallVector<Value> intrinsicOperands = {tsza[0],          tszb[1],
+                                          tsza[1],          adaptor.getAcc(),
+                                          adaptor.getLhs(), adaptor.getRhs()};
+
+  return intrinsicOperands;
+}
+
 LogicalResult amx::TileMulIOp::verify() {
   amx::TileType aType = getLhsTileType();
   amx::TileType bType = getRhsTileType();
@@ -112,6 +279,25 @@ LogicalResult amx::TileMulIOp::verify() {
   return success();
 }
 
+SmallVector<Value>
+amx::TileMulIOp::getIntrinsicOperands(ArrayRef<Value> operands,
+                                      const LLVMTypeConverter &typeConverter,
+                                      RewriterBase &rewriter) {
+  auto loc = getLoc();
+  Adaptor adaptor(operands, *this);
+
+  amx::TileType aType = getLhsTileType();
+  amx::TileType bType = getRhsTileType();
+  SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
+  SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
+
+  SmallVector<Value> intrinsicOperands = {tsza[0],          tszb[1],
+                                          tsza[1],          adaptor.getAcc(),
+                                          adaptor.getLhs(), adaptor.getRhs()};
+
+  return intrinsicOperands;
+}
+
 Type amx::TileType::parse(AsmParser &parser) {
   if (parser.parseLess())
     return nullptr;
diff --git a/mlir/lib/Dialect/AMX/IR/CMakeLists.txt b/mlir/lib/Dialect/AMX/IR/CMakeLists.txt
index d109547b2438b..b6e2759843d5e 100644
--- a/mlir/lib/Dialect/AMX/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/IR/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRAMXDialect
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRLLVMCommonConversion
   MLIRLLVMDialect
   MLIRSideEffectInterfaces
   )
diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
index 29340d4f45dd1..e827bc475e930 100644
--- a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
@@ -1,9 +1,6 @@
 add_mlir_dialect_library(MLIRAMXTransforms
   LegalizeForLLVMExport.cpp
 
-  DEPENDS
-  MLIRAMXConversionsIncGen
-
   LINK_LIBS PUBLIC
   MLIRAMXDialect
   MLIRIR
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 4cb777b03b196..7471dc797e0fc 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -21,224 +21,42 @@ using namespace mlir::amx;
 
 namespace {
 
-/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
-/// dimension directly translates into the number of rows of the tiles.
-/// The second dimensions needs to be scaled by the number of bytes.
-std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
-                                     const LLVMTypeConverter &typeConverter,
-                                     amx::TileType tType, Location loc) {
-  Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
-  unsigned width = tType.getElementType().getIntOrFloatBitWidth();
-  assert(llvm::isPowerOf2_64(width) && width >= 8);
-  unsigned bytes = width >> 3;
-  auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
-  auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
-  return std::make_pair(
-      rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
-      rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
-}
-
-/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
-/// shape may "envelop" the actual tile shape, and may be dynamically sized.
-/// Returns failure if proper stride couldn't be found.
-FailureOr<Value> getStride(ConversionPatternRewriter &rewriter,
-                           const LLVMTypeConverter &typeConverter,
-                           MemRefType mType, Value base, Location loc) {
-  if (mType.getRank() < 2)
-    return failure();
-  int64_t preLast = mType.getRank() - 2;
-  Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
-  unsigned width = mType.getElementType().getIntOrFloatBitWidth();
-  assert(llvm::isPowerOf2_64(width) && width >= 8);
-  unsigned bytes = width >> 3;
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  if (failed(mType.getStridesAndOffset(strides, offset)) || strides.back() != 1)
-    return failure();
-  if (strides[preLast] == ShapedType::kDynamic) {
-    // Dynamic stride needs code to compute the stride at runtime.
-    MemRefDescriptor memrefDescriptor(base);
-    auto attr = rewriter.getI64IntegerAttr(bytes);
-    Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
-    return rewriter
-        .create<LLVM::MulOp>(loc, llvmInt64Type, scale,
-                             memrefDescriptor.stride(rewriter, loc, preLast))
-        .getResult();
-  }
-  // Use direct constant for static stride.
-  auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
-  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
-      .getResult();
-}
-
-struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
-  using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
-  LogicalResult
-  matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    amx::TileType tType = op.getTileType();
-    // Determine m x n tile sizes.
-    std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
-    // Replace operation with intrinsic.
-    Type resType = typeConverter->convertType(tType);
-    rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
-                                                       tsz.second);
-    return success();
-  }
-};
-
-struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
-  using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    MemRefType mType = op.getMemRefType();
-    amx::TileType tType = op.getTileType();
-    // Determine m x n tile sizes.
-    std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
-    // Determine stride.
-    auto stride = getStride(rewriter, *getTypeConverter(), mType,
-                            adaptor.getBase(), op.getLoc());
-    if (failed(stride))
-      return failure();
-    // Replace operation with intrinsic.
-    Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
-    Type resType = typeConverter->convertType(tType);
-    rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
-        op, resType, tsz.first, tsz.second, ptr, stride.value());
-    return success();
-  }
-};
-
-struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
-  using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    MemRefType mType = op.getMemRefType();
-    amx::TileType tType = op.getTileType();
-    // Determine m x n tile sizes.
-    std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
-    // Determine stride.
-    auto stride = getStride(rewriter, *getTypeConverter(), mType,
-                            adaptor.getBase(), op.getLoc());
-    if (failed(stride))
-      return failure();
-    // Replace operation with intrinsic.
-    Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
-    rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
-        op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
-    return success();
-  }
-};
+/// Generic one-to-one conversion of operations implementing
+/// `amx::AMXIntrinsicOp` into calls to their respective LLVM intrinsics.
+struct AMXIntrinsicOpConversion
+    : public OpInterfaceConversionPattern<amx::AMXIntrinsicOp> {
+  // Note: base constructors are deliberately not inherited, as they could
+  // not bind the `llvmTypeConverter` reference member declared below.
+
+  explicit AMXIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
+                                    PatternBenefit benefit = 1)
+      : OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
+                                     benefit),
+        llvmTypeConverter(typeConverter) {}
 
-struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
-  using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
   LogicalResult
-  matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
+  matchAndRewrite(amx::AMXIntrinsicOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    amx::TileType aType = op.getLhsTileType();
-    amx::TileType bType = op.getRhsTileType();
-    amx::TileType cType = op.getTileType();
-    // Determine m x n x k tile sizes.
-    std::pair<Value, Value> tsza =
-        getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
-    std::pair<Value, Value> tszb =
-        getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
-    // Replace operation with intrinsic.
-    Type resType = typeConverter->convertType(cType);
-    if (aType.getElementType().isBF16())
-      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
-          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
-          adaptor.getLhs(), adaptor.getRhs());
-    else if (aType.getElementType().isF16())
-      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
-          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
-          adaptor.getLhs(), adaptor.getRhs());
-    else
-      llvm_unreachable("Unexpected element type for amx.mulf");
-    return success();
+    return LLVM::detail::intrinsicRewrite(
+        op, rewriter.getStringAttr(op.getIntrinsicName()),
+        op.getIntrinsicOperands(operands, llvmTypeConverter, rewriter),
+        llvmTypeConverter, rewriter);
   }
-};
 
-struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
-  using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;
-  LogicalResult
-  matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    amx::TileType aType = op.getLhsTileType();
-    amx::TileType bType = op.getRhsTileType();
-    amx::TileType cType = op.getTileType();
-    // Determine m x n x k tile sizes.
-    std::pair<Value, Value> tsza =
-        getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
-    std::pair<Value, Value> tszb =
-        getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
-    // Replace operation with intrinsic.
-    Type resType = typeConverter->convertType(cType);
-    bool zexta = op.getIsZextLhs();
-    bool zextb = op.getIsZextRhs();
-    if (zexta && zextb)
-      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
-          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
-          adaptor.getLhs(), adaptor.getRhs());
-    else if (zexta && !zextb)
-      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd>(
-          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
-          adaptor.getLhs(), adaptor.getRhs());
-    else if (!zexta && zextb)
-      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud>(
-          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
-          adaptor.getLhs(), adaptor.getRhs());
-    else
-      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd>(
-          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
-          adaptor.getLhs(), adaptor.getRhs());
-    return success();
-  }
+private:
+  const LLVMTypeConverter &llvmTypeConverter;
 };
 
 } // namespace
 
 void mlir::populateAMXLegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
-               TileMulFConversion, TileMulIConversion>(converter);
+  patterns.add<AMXIntrinsicOpConversion>(converter);
   converter.addConversion([&](amx::TileType type) {
     return LLVM::LLVMX86AMXType::get(&converter.getContext());
   });
 }
 
 void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
-  target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
-                    x86_amx_tdpbf16ps, x86_amx_tdpfp16ps, x86_amx_tdpbssd,
-                    x86_amx_tdpbsud, x86_amx_tdpbusd, x86_amx_tdpbuud>();
-  target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
-                      TileMulFOp>();
-}
-
-namespace {
-/// Implement the interface to convert AMX to LLVM.
-struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
-  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
-
-  void populateConvertToLLVMConversionPatterns(
-      ConversionTarget &target, LLVMTypeConverter &typeConverter,
-      RewritePatternSet &patterns) const final {
-    populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
-  }
-};
-} // namespace
-
-void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
-  registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
-    dialect->addInterfaces<AMXToLLVMDialectInterface>();
-  });
+  target.addIllegalDialect<AMXDialect>();
 }
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index 4ace3964e8ae0..af22a7ff04bf0 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -51,7 +51,6 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
   MLIRArmNeonToLLVMIRTranslation
   MLIRArmSMEToLLVMIRTranslation
   MLIRArmSVEToLLVMIRTranslation
-  MLIRAMXToLLVMIRTranslation
   MLIRBuiltinToLLVMIRTranslation
   MLIRGPUToLLVMIRTranslation
   MLIRLLVMToLLVMIRTranslation
diff --git a/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp
deleted file mode 100644
index 044462d33cfd1..0000000000000
--- a/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp
+++ /dev/null
@@ -1,56 +0,0 @@
-//===- AMXToLLVMIRTranslation.cpp - Translate AMX to LLVM IR --------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements a translation between the AMX dialect and LLVM IR.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
-#include "mlir/Dialect/AMX/AMXDialect.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/Target/LLVMIR/ModuleTranslation.h"
-
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/IntrinsicsX86.h"
-
-using namespace mlir;
-using namespace mlir::LLVM;
-
-namespace {
-/// Implementation of the dialect interface that converts operations belonging
-/// to the AMX dialect to LLVM IR.
-class AMXDialectLLVMIRTranslationInterface
-    : public LLVMTranslationDialectInterface {
-public:
-  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
-
-  /// Translates the given operation to LLVM IR using the provided IR builder
-  /// and saving the state in `moduleTranslation`.
-  LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
-                   LLVM::ModuleTranslation &moduleTranslation) const final {
-    Operation &opInst = *op;
-#include "mlir/Dialect/AMX/AMXConversions.inc"
-
-    return failure();
-  }
-};
-} // namespace
-
-void mlir::registerAMXDialectTranslation(DialectRegistry &registry) {
-  registry.insert<amx::AMXDialect>();
-  registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
-    dialect->addInterfaces<AMXDialectLLVMIRTranslationInterface>();
-  });
-}
-
-void mlir::registerAMXDialectTranslation(MLIRContext &context) {
-  DialectRegistry registry;
-  registerAMXDialectTranslation(registry);
-  context.appendDialectRegistry(registry);
-}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt
deleted file mode 100644
index 733b4c2e31b80..0000000000000
--- a/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt
+++ /dev/null
@@ -1,16 +0,0 @@
-add_mlir_translation_library(MLIRAMXToLLVMIRTranslation
-  AMXToLLVMIRTranslation.cpp
-
-  DEPENDS
-  MLIRAMXConversionsIncGen
-
-  LINK_COMPONENTS
-  Core
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRAMXDialect
-  MLIRLLVMDialect
-  MLIRSupport
-  MLIRTargetLLVMIRExport
-  )
diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
index 40df6e3f4b642..f030fa78942d5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -1,7 +1,6 @@
 add_subdirectory(ArmNeon)
 add_subdirectory(ArmSME)
 add_subdirectory(ArmSVE)
-add_subdirectory(AMX)
 add_subdirectory(Builtin)
 add_subdirectory(GPU)
 add_subdirectory(LLVMIR)
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 8085f5f59fcaf..7e562b00a46a9 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -1,17 +1,17 @@
 // RUN: mlir-opt %s -convert-vector-to-llvm="enable-amx" | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: muli(
-// CHECK: amx.tilezero
-// CHECK: amx.tileloadd64
-// CHECK: amx.tileloadd64
-// CHECK: amx.tdpbuud
-// CHECK: amx.tilestored64
-// CHECK: amx.tdpbssd
-// CHECK: amx.tilestored64
-// CHECK: amx.tdpbusd
-// CHECK: amx.tilestored64
-// CHECK: amx.tdpbsud
-// CHECK: amx.tilestored64
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdpbuud.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdpbssd.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdpbusd.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdpbsud.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
 func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
   %0 = arith.constant 0 : index
   %1 = amx.tile_zero : !amx.tile<16x64xi8>
@@ -29,11 +29,11 @@ func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
 }
 
 // CHECK-LABEL: mulbf16(
-// CHECK: amx.tilezero
-// CHECK: amx.tileloadd64
-// CHECK: amx.tileloadd64
-// CHECK: amx.tdpbf16ps
-// CHECK: amx.tilestored64
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdpbf16ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
 func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   %1 = amx.tile_zero : !amx.tile<16x32xbf16>
@@ -45,11 +45,11 @@ func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
 }
 
 // CHECK-LABEL: mulfp16(
-// CHECK: amx.tilezero
-// CHECK: amx.tileloadd64
-// CHECK: amx.tileloadd64
-// CHECK: amx.tdpfp16ps
-// CHECK: amx.tilestored64
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdpfp16ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
 func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   %1 = amx.tile_zero : !amx.tile<16x32xf16>
@@ -62,21 +62,21 @@ func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
 
 // CHECK-LABEL: strides(
 // CHECK: %[[CST_64_1:.+]] = llvm.mlir.constant(64 : i64) : i64
-// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]]
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]]
 // CHECK: %[[CST_128_1:.+]] = llvm.mlir.constant(128 : i64) : i64
-// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]]
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]]
 // CHECK: llvm.mlir.constant(2 : i64) : i64
 // CHECK: llvm.extractvalue %{{.+}}[4, 0]
 // CHECK: %[[STRIDE_1:.+]] = llvm.mul
-// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]]
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]]
 // CHECK: %[[CST_64_2:.+]] = llvm.mlir.constant(64 : i64) : i64
-// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]]
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]]
 // CHECK: %[[CST_128_2:.+]] = llvm.mlir.constant(128 : i64) : i64
-// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]]
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]]
 // CHECK: llvm.mlir.constant(2 : i64) : i64
 // CHECK: llvm.extractvalue %{{.+}}[4, 0]
 // CHECK: %[[STRIDE_2:.+]] = llvm.mul
-// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
 func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) {
   %0 = arith.constant 0 : index
   %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into !amx.tile<16x32xbf16>
diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir
index 0281dfcd6ad69..094475040436d 100644
--- a/mlir/test/Target/LLVMIR/amx.mlir
+++ b/mlir/test/Target/LLVMIR/amx.mlir
@@ -1,13 +1,90 @@
-// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+// RUN: mlir-opt %s --convert-vector-to-llvm="enable-amx" --convert-to-llvm -reconcile-unrealized-casts \
+// RUN: | mlir-translate --mlir-to-llvmir \
+// RUN: | FileCheck %s
 
-// CHECK-LABEL: define void @target(ptr %0)
-// CHECK: %[[c:.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 16)
-// CHECK: call void @llvm.x86.tilestored64.internal(i16 16, i16 16, ptr %0, i64 32, x86_amx %[[c]]
-llvm.func @target(%ptr: !llvm.ptr) {
-  %c = llvm.mlir.constant(16 : i16) : i16
-  %s = llvm.mlir.constant(32 : i64) : i64
-  %0 = "amx.tilezero"(%c, %c) : (i16, i16) -> !llvm.array<16 x vector<16xbf16>>
-  "amx.tilestored64"(%c, %c, %ptr, %s, %0) : (i16, i16, !llvm.ptr, i64, !llvm.array<16 x vector<16xbf16>>) -> ()
-  llvm.return
+// CHECK-LABEL: define void @amx_tile_zero
+func.func @amx_tile_zero(%out: memref<?x?xf32>, %idx: index)
+{
+  // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
+  // CHECK: call void @llvm.x86.tilestored64.internal
+  %zero = amx.tile_zero : !amx.tile<16x16xf32>
+  amx.tile_store %out[%idx, %idx], %zero : memref<?x?xf32>, !amx.tile<16x16xf32>
+  return
 }
 
+// CHECK-LABEL: define void @amx_tile_load_store
+func.func @amx_tile_load_store(%base: memref<?x?xi8>, %out: memref<?x?xi8>,
+    %idx: index)
+{
+  // CHECK: call x86_amx @llvm.x86.tileloadd64.internal
+  // CHECK: call void @llvm.x86.tilestored64.internal
+  %val = amx.tile_load %base[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
+  amx.tile_store %out[%idx, %idx], %val : memref<?x?xi8>, !amx.tile<16x64xi8>
+  return
+}
+
+// CHECK-LABEL: define void @amx_tile_mulf_bf16
+func.func @amx_tile_mulf_bf16(
+    %matA: memref<?x?xbf16>, %matB: memref<?x?xbf16>, %idx: index,
+    %out: memref<?x?xf32>)
+{
+  // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
+  %acc = amx.tile_zero : !amx.tile<16x16xf32>
+  // CHECK-COUNT-2: call x86_amx @llvm.x86.tileloadd64.internal
+  %tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+  %tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+  // CHECK: call x86_amx @llvm.x86.tdpbf16ps.internal
+  %tRes = amx.tile_mulf %tA, %tB, %acc
+    : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+  // CHECK: call void @llvm.x86.tilestored64.internal
+  amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !amx.tile<16x16xf32>
+  return
+}
+
+// CHECK-LABEL: define void @amx_tile_mulf_f16
+func.func @amx_tile_mulf_f16(
+    %matA: memref<?x?xf16>, %matB: memref<?x?xf16>, %idx: index,
+    %out: memref<?x?xf32>)
+{
+  // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
+  %acc = amx.tile_zero : !amx.tile<16x16xf32>
+  // CHECK-COUNT-2: call x86_amx @llvm.x86.tileloadd64.internal
+  %tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xf16> into !amx.tile<16x32xf16>
+  %tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xf16> into !amx.tile<16x32xf16>
+  // CHECK: call x86_amx @llvm.x86.tdpfp16ps.internal
+  %tRes = amx.tile_mulf %tA, %tB, %acc
+    : !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32>
+  // CHECK: call void @llvm.x86.tilestored64.internal
+  amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !amx.tile<16x16xf32>
+  return
+}
+
+// CHECK-LABEL: define void @amx_tile_muli
+func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>,
+    %matC: memref<?x?xi32>, %idx: index, %out: memref<?x?xi32>)
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  // CHECK-COUNT-3: call x86_amx @llvm.x86.tileloadd64.internal
+  %tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
+  %tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
+  %acc = amx.tile_load %matC[%idx, %idx] : memref<?x?xi32> into !amx.tile<16x16xi32>
+  // CHECK: call x86_amx @llvm.x86.tdpbuud.internal
+  // CHECK: call x86_amx @llvm.x86.tdpbssd.internal
+  // CHECK: call x86_amx @llvm.x86.tdpbusd.internal
+  // CHECK: call x86_amx @llvm.x86.tdpbsud.internal
+  %res = amx.tile_muli %tA zext, %tB zext, %acc
+    : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  %res1 = amx.tile_muli %tA, %tB, %acc
+    : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  %res2 = amx.tile_muli %tA zext, %tB, %acc
+    : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  %res3 = amx.tile_muli %tA, %tB zext, %acc
+    : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  // CHECK-COUNT-4: call void @llvm.x86.tilestored64.internal
+  amx.tile_store %out[%c0, %c0], %res : memref<?x?xi32>, !amx.tile<16x16xi32>
+  amx.tile_store %out[%c0, %c16], %res1 : memref<?x?xi32>, !amx.tile<16x16xi32>
+  amx.tile_store %out[%c16, %c0], %res2 : memref<?x?xi32>, !amx.tile<16x16xi32>
+  amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi32>, !amx.tile<16x16xi32>
+  return
+}



More information about the Mlir-commits mailing list