[Mlir-commits] [mlir] [MLIR] [AMX] Utilize x86_amx type for AMX dialect in MLIR. (PR #111197)

Ilya Enkovich llvmlistbot at llvm.org
Mon Nov 4 12:13:47 PST 2024


https://github.com/ienkovich updated https://github.com/llvm/llvm-project/pull/111197

>From 2a76b2d63ae39b5714ba927c3ae3a71790f47af9 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Wed, 2 Oct 2024 22:21:32 +0000
Subject: [PATCH 1/4] Utilize x86_amx type for AMX dialect in MLIR.

Signed-off-by: Ilya Enkovich <ilya.enkovich at intel.com>
---
 .../Conversion/LLVMCommon/TypeConverter.h     |   4 +
 mlir/include/mlir/Dialect/AMX/AMX.td          | 139 +++++++++++++-----
 mlir/include/mlir/Dialect/AMX/AMXDialect.h    |   3 +
 mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td |  12 ++
 mlir/lib/Conversion/LLVMCommon/CMakeLists.txt |   1 +
 .../Conversion/LLVMCommon/TypeConverter.cpp   |   7 +
 mlir/lib/Dialect/AMX/IR/AMXDialect.cpp        |  63 ++++++--
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  |  51 ++++---
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp |   2 +
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp      |   6 +-
 mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp       |   2 +
 mlir/lib/Target/LLVMIR/TypeToLLVM.cpp         |   3 +
 mlir/test/Dialect/AMX/invalid.mlir            |  18 +--
 mlir/test/Dialect/AMX/legalize-for-llvm.mlir  |  52 ++++---
 mlir/test/Dialect/AMX/roundtrip.mlir          |  50 +++----
 15 files changed, 287 insertions(+), 126 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index d79b90f840ce836..bd4b3e73f074108 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -15,6 +15,7 @@
 #define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
 
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -258,6 +259,9 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a 1D vector type into an LLVM vector type.
   FailureOr<Type> convertVectorType(VectorType type) const;
 
+  /// Convert an AMX tile type to the LLVM x86_amx type.
+  Type convertAMXTileType(amx::TileType type) const;
+
   /// Options for customizing the llvm lowering.
   LowerToLLVMOptions options;
 
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index fcc8d169eab5ac0..8ef5ac25fbbddff 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -30,6 +30,8 @@
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypes.td"
 
 //===----------------------------------------------------------------------===//
 // AMX dialect definition.
@@ -55,8 +57,69 @@ def AMX_Dialect : Dialect {
     For details, see the Intel documentation:
     https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
   }];
+  let useDefaultTypePrinterParser = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// AMX Tile definition.
+//===----------------------------------------------------------------------===//
+
+class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<AMX_Dialect, typeName, traits> {
+  let mnemonic = typeMnemonic;
+}
+
+def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
+  let cppFunctionName = "isValidTileTypeElementType";
+}
+
+def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
+  let summary = "AMX 2D tile to be used by AMX operations.";
+
+  let description = [{
+    This type is used to represent values in AMX tile registers. All AMX operations
+    work on AMX tiles and these tiles cannot be used in other operations directly.
+    LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
+    element type for IR verification and lowering to LLVMIR dialect.
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$shape,
+    AMX_TileTypeElementType:$elementType
+  );
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
+      return $_get(elementType.getContext(), shape, elementType);
+    }]>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns whether this type is ranked (always true for AMX tiles).
+    bool hasRank() const { return true; }
+
+    /// Clone this tile type with the given shape and element type. If the
+    /// provided shape is `std::nullopt`, the current shape of the type is used.
+    TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                       Type elementType) const {
+      return get(shape.value_or(getShape()), elementType);
+    }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let skipDefaultBuilders = 1;
+}
+
+def IsAMXTilePred : CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">;
+
+def IsAMX2DTilePred : And<[IsAMXTilePred,
+  CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
+
+class AMX2DTileOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsAMX2DTilePred, "tile",
+                      "::mlir::amx::TileType">;
+
 //===----------------------------------------------------------------------===//
 // AMX Op and IntrOp definitions.
 //===----------------------------------------------------------------------===//
@@ -88,14 +151,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
     Example:
 
     ```mlir
-      %0 = amx.tile_zero : vector<16x16xbf16>
+      %0 = amx.tile_zero : <16x16xbf16>
     ```
   }];
   let results = (outs
-    VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+    AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
   let extraClassDeclaration = [{
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getRes().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
   let assemblyFormat = "attr-dict `:` type($res)";
@@ -117,19 +180,19 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
     Example:
 
     ```mlir
-      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
+      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into <16x64xi8>
     ```
   }];
   let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
                    Variadic<Index>:$indices);
   let results = (outs
-    VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+    AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
     }
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getRes().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
   let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
@@ -148,18 +211,18 @@ def TileStoreOp : AMX_Op<"tile_store"> {
     Example:
 
     ```mlir
-      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
+      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, <16x64xi8>
     ```
   }];
   let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
                    Variadic<Index>:$indices,
-                   VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
+                   AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$val);
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
     }
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getVal().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getVal().getType());
     }
   }];
   let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
@@ -183,23 +246,22 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
     Example:
 
     ```mlir
-      %0 = amx.tile_mulf %a, %b, %c
-        : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+      %0 = amx.tile_mulf %a, %b, %c : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
     ```
   }];
-  let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
-                       VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
-                       VectorOfRankAndType<[2], [F32, BF16]>:$acc);
-  let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
+  let arguments = (ins AMX2DTileOf<[F16, BF16]>:$lhs,
+                       AMX2DTileOf<[F16, BF16]>:$rhs,
+                       AMX2DTileOf<[F32]>:$acc);
+  let results = (outs AMX2DTileOf<[F32]>:$res);
   let extraClassDeclaration = [{
-    VectorType getLhsVectorType() {
-      return ::llvm::cast<VectorType>(getLhs().getType());
+    TileType getLhsTileType() {
+      return ::llvm::cast<TileType>(getLhs().getType());
     }
-    VectorType getRhsVectorType() {
-      return ::llvm::cast<VectorType>(getRhs().getType());
+    TileType getRhsTileType() {
+      return ::llvm::cast<TileType>(getRhs().getType());
     }
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getRes().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
   let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
@@ -222,26 +284,25 @@ def TileMulIOp : AMX_Op<"tile_muli", [
     Example:
 
     ```mlir
-      %0 = amx.tile_muli %a zext, %b zext, %c
-        : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+      %0 = amx.tile_muli %a zext, %b zext, %c : <16x64xi8>, <16x64xi8>, <16x16xi32>
     ```
   }];
-  let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
-                       VectorOfRankAndType<[2], [I32, I8]>:$rhs,
-                       VectorOfRankAndType<[2], [I32, I8]>:$acc,
+  let arguments = (ins AMX2DTileOf<[I8]>:$lhs,
+                       AMX2DTileOf<[I8]>:$rhs,
+                       AMX2DTileOf<[I32]>:$acc,
                        UnitAttr:$isZextLhs,
                        UnitAttr:$isZextRhs
                        );
-  let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
+  let results = (outs AMX2DTileOf<[I32]>:$res);
   let extraClassDeclaration = [{
-    VectorType getLhsVectorType() {
-      return ::llvm::cast<VectorType>(getLhs().getType());
+    TileType getLhsTileType() {
+      return ::llvm::cast<TileType>(getLhs().getType());
     }
-    VectorType getRhsVectorType() {
-      return ::llvm::cast<VectorType>(getRhs().getType());
+    TileType getRhsTileType() {
+      return ::llvm::cast<TileType>(getRhs().getType());
     }
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getRes().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
   let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
@@ -286,6 +347,12 @@ def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
                  AnyInteger,
 		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
 
+// Dot product of f16 tiles into f32 tile.
+def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
+  Arguments<(ins AnyInteger,
+                 AnyInteger,
+		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
 // Dot product of i8 tiles into i32 tile (with sign/sign extension).
 def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
   Arguments<(ins AnyInteger,
diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
index 47c92479814dea8..c0553ad8733fd45 100644
--- a/mlir/include/mlir/Dialect/AMX/AMXDialect.h
+++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
@@ -21,6 +21,9 @@
 
 #include "mlir/Dialect/AMX/AMXDialect.h.inc"
 
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/AMX/AMXTypes.h.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/AMX/AMX.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 8f9c2f2f8a0b441..09dd0919c318fb1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -250,4 +250,16 @@ def LLVMTargetExtType : LLVMType<"LLVMTargetExt", "target"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// LLVMX86AMXType
+//===----------------------------------------------------------------------===//
+
+def LLVMX86AMXType : LLVMType<"LLVMX86AMX", "x86_amx"> {
+  let summary = "LLVM x86_amx type.";
+  let description = [{
+    The x86_amx type represents a value held in an AMX tile register on an x86
+    machine. It can only be used as an operand or result of AMX intrinsic calls.
+  }];
+}
+
 #endif // LLVMTYPES_TD
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 568d9339aaabcb4..39199e4affccfac 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
   Core
 
   LINK_LIBS PUBLIC
+  MLIRAMXDialect
   MLIRIR
   MLIRLLVMDialect
   MLIRSupport
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 4e7758bf46d9cfc..07e6dd55f855191 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -67,6 +67,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
       return std::nullopt;
     return llvmType;
   });
+  addConversion([&](amx::TileType type) { return convertAMXTileType(type); });
 
   // LLVM-compatible types are legal, so add a pass-through conversion. Do this
   // before the conversions below since conversions are attempted in reverse
@@ -594,6 +595,12 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   return vectorType;
 }
 
+/// Convert an AMX tile type to LLVM x86_amx type.
+/// Shape and element type of the tile are ignored.
+Type LLVMTypeConverter::convertAMXTileType(amx::TileType type) const {
+  return LLVM::LLVMX86AMXType::get(&getContext());
+}
+
 /// Convert a type in the context of the default or bare pointer calling
 /// convention. Calling convention sensitive types, such as MemRefType and
 /// UnrankedMemRefType, are converted following the specific rules for the
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index f0e434407c8a2d0..829f48e223383e9 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -13,14 +13,22 @@
 #include "mlir/Dialect/AMX/AMXDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TypeUtilities.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+
 using namespace mlir;
 
 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
 
 void amx::AMXDialect::initialize() {
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
+      >();
+
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/AMX/AMX.cpp.inc"
@@ -28,7 +36,7 @@ void amx::AMXDialect::initialize() {
 }
 
 /// Verify that AMX supports the implied tile shape.
-static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
+static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
   const unsigned kMaxRows = 16;
   const unsigned kBitsPerRow = 64 * 8;
   unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
@@ -40,8 +48,8 @@ static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
 }
 
 /// Verify that AMX supports the multiplication.
-static LogicalResult verifyMultShape(Operation *op, VectorType atp,
-                                     VectorType btp, VectorType ctp,
+static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
+                                     amx::TileType btp, amx::TileType ctp,
                                      unsigned scale) {
   unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
   unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
@@ -53,27 +61,27 @@ static LogicalResult verifyMultShape(Operation *op, VectorType atp,
 }
 
 LogicalResult amx::TileZeroOp::verify() {
-  return verifyTileSize(*this, getVectorType());
+  return verifyTileSize(*this, getTileType());
 }
 
 LogicalResult amx::TileLoadOp::verify() {
   unsigned rank = getMemRefType().getRank();
   if (getIndices().size() != rank)
     return emitOpError("requires ") << rank << " indices";
-  return verifyTileSize(*this, getVectorType());
+  return verifyTileSize(*this, getTileType());
 }
 
 LogicalResult amx::TileStoreOp::verify() {
   unsigned rank = getMemRefType().getRank();
   if (getIndices().size() != rank)
     return emitOpError("requires ") << rank << " indices";
-  return verifyTileSize(*this, getVectorType());
+  return verifyTileSize(*this, getTileType());
 }
 
 LogicalResult amx::TileMulFOp::verify() {
-  VectorType aType = getLhsVectorType();
-  VectorType bType = getRhsVectorType();
-  VectorType cType = getVectorType();
+  amx::TileType aType = getLhsTileType();
+  amx::TileType bType = getRhsTileType();
+  amx::TileType cType = getTileType();
   if (failed(verifyTileSize(*this, aType)) ||
       failed(verifyTileSize(*this, bType)) ||
       failed(verifyTileSize(*this, cType)) ||
@@ -82,15 +90,15 @@ LogicalResult amx::TileMulFOp::verify() {
   Type ta = aType.getElementType();
   Type tb = bType.getElementType();
   Type tc = cType.getElementType();
-  if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
+  if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
     return emitOpError("unsupported type combination");
   return success();
 }
 
 LogicalResult amx::TileMulIOp::verify() {
-  VectorType aType = getLhsVectorType();
-  VectorType bType = getRhsVectorType();
-  VectorType cType = getVectorType();
+  amx::TileType aType = getLhsTileType();
+  amx::TileType bType = getRhsTileType();
+  amx::TileType cType = getTileType();
   if (failed(verifyTileSize(*this, aType)) ||
       failed(verifyTileSize(*this, bType)) ||
       failed(verifyTileSize(*this, cType)) ||
@@ -104,5 +112,34 @@ LogicalResult amx::TileMulIOp::verify() {
   return success();
 }
 
+Type amx::TileType::parse(AsmParser &parser) {
+  if (parser.parseLess())
+    return nullptr;
+
+  SmallVector<int64_t, 2> shape;
+  if (parser.parseDimensionList(shape, false, true))
+    return nullptr;
+
+  Type elementType;
+  if (parser.parseType(elementType))
+    return nullptr;
+
+  if (parser.parseGreater())
+    return nullptr;
+
+  return TileType::get(shape, elementType);
+}
+
+void amx::TileType::print(AsmPrinter &os) const {
+  os << "<";
+  os.printDimensionList(getShape());
+  os << 'x';
+  os.printType(getElementType());
+  os << '>';
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/AMX/AMX.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 46c7bfbf3ffcc2e..3172d93ef65bc99 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -25,13 +25,13 @@ namespace {
 /// The second dimensions needs to be scaled by the number of bytes.
 std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
                                      const LLVMTypeConverter &typeConverter,
-                                     VectorType vType, Location loc) {
+                                     amx::TileType tType, Location loc) {
   Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
-  unsigned width = vType.getElementType().getIntOrFloatBitWidth();
+  unsigned width = tType.getElementType().getIntOrFloatBitWidth();
   assert(llvm::isPowerOf2_64(width) && width >= 8);
   unsigned bytes = width >> 3;
-  auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0));
-  auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes);
+  auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
+  auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
   return std::make_pair(
       rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
       rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
@@ -76,12 +76,12 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
   LogicalResult
   matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType vType = op.getVectorType();
+    amx::TileType tType = op.getTileType();
     // Determine m x n tile sizes.
     std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
     // Replace operation with intrinsic.
-    Type resType = typeConverter->convertType(vType);
+    Type resType = typeConverter->convertType(tType);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
                                                        tsz.second);
     return success();
@@ -95,10 +95,10 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
   matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MemRefType mType = op.getMemRefType();
-    VectorType vType = op.getVectorType();
+    amx::TileType tType = op.getTileType();
     // Determine m x n tile sizes.
     std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
     // Determine stride.
     auto stride = getStride(rewriter, *getTypeConverter(), mType,
                             adaptor.getBase(), op.getLoc());
@@ -107,7 +107,7 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
     // Replace operation with intrinsic.
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    Type resType = typeConverter->convertType(vType);
+    Type resType = typeConverter->convertType(tType);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
         op, resType, tsz.first, tsz.second, ptr, stride.value());
     return success();
@@ -121,10 +121,10 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
   matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MemRefType mType = op.getMemRefType();
-    VectorType vType = op.getVectorType();
+    amx::TileType tType = op.getTileType();
     // Determine m x n tile sizes.
     std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
     // Determine stride.
     auto stride = getStride(rewriter, *getTypeConverter(), mType,
                             adaptor.getBase(), op.getLoc());
@@ -144,9 +144,9 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
   LogicalResult
   matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType aType = op.getLhsVectorType();
-    VectorType bType = op.getRhsVectorType();
-    VectorType cType = op.getVectorType();
+    amx::TileType aType = op.getLhsTileType();
+    amx::TileType bType = op.getRhsTileType();
+    amx::TileType cType = op.getTileType();
     // Determine m x n x k tile sizes.
     std::pair<Value, Value> tsza =
         getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
@@ -154,9 +154,14 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
         getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
     // Replace operation with intrinsic.
     Type resType = typeConverter->convertType(cType);
-    rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
-        op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
-        adaptor.getLhs(), adaptor.getRhs());
+    if (aType.getElementType().isBF16())
+      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
+          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
+          adaptor.getLhs(), adaptor.getRhs());
+    else
+      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
+          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
+          adaptor.getLhs(), adaptor.getRhs());
     return success();
   }
 };
@@ -166,9 +171,9 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
   LogicalResult
   matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType aType = op.getLhsVectorType();
-    VectorType bType = op.getRhsVectorType();
-    VectorType cType = op.getVectorType();
+    amx::TileType aType = op.getLhsTileType();
+    amx::TileType bType = op.getRhsTileType();
+    amx::TileType cType = op.getTileType();
     // Determine m x n x k tile sizes.
     std::pair<Value, Value> tsza =
         getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
@@ -208,8 +213,8 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns(
 
 void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
   target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
-                    x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
-                    x86_amx_tdpbusd, x86_amx_tdpbuud>();
+                    x86_amx_tdpbf16ps, x86_amx_tdpfp16ps, x86_amx_tdpbssd,
+                    x86_amx_tdpbsud, x86_amx_tdpbusd, x86_amx_tdpbuud>();
   target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
                       TileMulFOp>();
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index c4708d826f2b38b..9537f7c40dd4bef 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -45,6 +45,7 @@ static StringRef getTypeKeyword(Type type) {
       .Case<LLVMArrayType>([&](Type) { return "array"; })
       .Case<LLVMStructType>([&](Type) { return "struct"; })
       .Case<LLVMTargetExtType>([&](Type) { return "target"; })
+      .Case<LLVMX86AMXType>([&](Type) { return "x86_amx"; })
       .Default([](Type) -> StringRef {
         llvm_unreachable("unexpected 'llvm' type kind");
       });
@@ -317,6 +318,7 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
       .Case("array", [&] { return LLVMArrayType::parse(parser); })
       .Case("struct", [&] { return parseStructType(parser); })
       .Case("target", [&] { return LLVMTargetExtType::parse(parser); })
+      .Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); })
       .Default([&] {
         parser.emitError(keyLoc) << "unknown LLVM type: " << key;
         return Type();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 7f10a15ff31ff94..1bed3fa48b30d7a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -780,7 +780,8 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
       LLVMFixedVectorType,
       LLVMScalableVectorType,
       LLVMTargetExtType,
-      LLVMVoidType
+      LLVMVoidType,
+      LLVMX86AMXType
     >(type)) {
     // clang-format on
     return true;
@@ -842,7 +843,8 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
             LLVMMetadataType,
             LLVMPPCFP128Type,
             LLVMTokenType,
-            LLVMVoidType
+            LLVMVoidType,
+            LLVMX86AMXType
           >([](Type) { return true; })
           // clang-format on
           .Default([](Type) { return false; });
diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
index db184ae8e6e8331..ea990ca7aefbe03 100644
--- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
@@ -63,6 +63,8 @@ class TypeFromLLVMIRTranslatorImpl {
       return Float128Type::get(&context);
     if (type->isX86_FP80Ty())
       return Float80Type::get(&context);
+    if (type->isX86_AMXTy())
+      return LLVM::LLVMX86AMXType::get(&context);
     if (type->isPPC_FP128Ty())
       return LLVM::LLVMPPCFP128Type::get(&context);
     if (type->isLabelTy())
diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
index 659150272388010..c7a533eddce84be 100644
--- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
@@ -67,6 +67,9 @@ class TypeToLLVMIRTranslatorImpl {
             .Case([this](LLVM::LLVMMetadataType) {
               return llvm::Type::getMetadataTy(context);
             })
+            .Case([this](LLVM::LLVMX86AMXType) {
+              return llvm::Type::getX86_AMXTy(context);
+            })
             .Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType,
                   LLVM::LLVMPointerType, LLVM::LLVMStructType,
                   LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType,
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index 25d57353c905d06..46f3b7448148b36 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -4,21 +4,21 @@
 
 func.func @rowheight() {
   // expected-error at +1 {{'amx.tile_zero' op bad row height: 17}}
-  %0 = amx.tile_zero : vector<17x16xbf16>
+  %0 = amx.tile_zero : <17x16xbf16>
 }
 
 // -----
 
 func.func @colwidth() {
   // expected-error at +1 {{'amx.tile_zero' op bad column width: 65}}
-  %0 = amx.tile_zero : vector<16x65xi8>
+  %0 = amx.tile_zero : <16x65xi8>
 }
 
 // -----
 
 func.func @col4bytemultiple() {
   // expected-error at +1 {{'amx.tile_zero' op bad column width: 5}}
-  %0 = amx.tile_zero : vector<16x5xi8>
+  %0 = amx.tile_zero : <16x5xi8>
 }
 
 // -----
@@ -26,7 +26,7 @@ func.func @col4bytemultiple() {
 func.func @memtilesize(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error at +1 {{'amx.tile_load' op bad column width: 68}}
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into vector<16x17xf32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into <16x17xf32>
 }
 
 // -----
@@ -34,15 +34,15 @@ func.func @memtilesize(%arg0: memref<?x?xf32>) {
 func.func @memindexsize(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error at +1 {{'amx.tile_load' op requires 2 indices}}
-  %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into vector<16x16xf32>
+  %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into <16x16xf32>
 }
 
 // -----
 
 func.func @multsize() {
-  %0 = amx.tile_zero : vector<8x8xbf16>
-  %1 = amx.tile_zero : vector<8x8xbf16>
-  %2 = amx.tile_zero : vector<4x4xf32>
+  %0 = amx.tile_zero : <8x8xbf16>
+  %1 = amx.tile_zero : <8x8xbf16>
+  %2 = amx.tile_zero : <4x4xf32>
   // expected-error at +1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
-  %3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32>
+  %3 = amx.tile_mulf %0, %1, %2 : <8x8xbf16>, <8x8xbf16>, <4x4xf32>
 }
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 3cacbd0044f8251..82fb0a9c99f71de 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -14,33 +14,49 @@
 // CHECK: amx.tilestored64
 func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : vector<16x64xi8>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
-  %5 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, vector<16x16xi32>
-  %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, vector<16x16xi32>
-  %7 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, vector<16x16xi32>
+  %1 = amx.tile_zero : <16x64xi8>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into <16x16xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
+  %5 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, <16x16xi32>
+  %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, <16x16xi32>
+  %7 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, <16x16xi32>
   return
 }
 
-// CHECK-LABEL: mulf(
+// CHECK-LABEL: mulbf16(
 // CHECK: amx.tilezero
 // CHECK: amx.tileloadd64
 // CHECK: amx.tileloadd64
 // CHECK: amx.tdpbf16ps
 // CHECK: amx.tilestored64
-func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
+func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : vector<16x32xbf16>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
-  %4 = amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
+  %1 = amx.tile_zero : <16x32xbf16>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
+  %4 = amx.tile_mulf %1, %2, %3 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
+  return
+}
+
+// CHECK-LABEL: mulfp16(
+// CHECK: amx.tilezero
+// CHECK: amx.tileloadd64
+// CHECK: amx.tileloadd64
+// CHECK: amx.tdpfp16ps
+// CHECK: amx.tilestored64
+func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
+  %0 = arith.constant 0 : index
+  %1 = amx.tile_zero : <16x32xf16>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into <16x32xf16>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
+  %4 = amx.tile_mulf %1, %2, %3 : <16x32xf16>, <16x32xf16>, <16x16xf32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
   return
 }
 
diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir
index f2ac5e47f6c3576..62bc071cab1e029 100644
--- a/mlir/test/Dialect/AMX/roundtrip.mlir
+++ b/mlir/test/Dialect/AMX/roundtrip.mlir
@@ -1,49 +1,49 @@
 // RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: tzero
-// CHECK: amx.tile_zero : vector<16x16xbf16>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, vector<16x16xbf16>
+// CHECK: amx.tile_zero : <16x16xbf16>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, <16x16xbf16>
 func.func @tzero(%arg0: memref<?x?xbf16>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : vector<16x16xbf16>
-  amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, vector<16x16xbf16>
+  %1 = amx.tile_zero : <16x16xbf16>
+  amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, <16x16xbf16>
   return
 }
 
 // CHECK-LABEL: tmulf
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into vector<16x32xbf16>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into vector<16x16xf32>
-// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, vector<16x16xf32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into <16x32xbf16>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into <16x16xf32>
+// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, <16x16xf32>
 func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
-  %3 = amx.tile_mulf %1, %1, %2 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
-  amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, vector<16x16xf32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
+  %3 = amx.tile_mulf %1, %1, %2 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+  amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, <16x16xf32>
   return
 }
 
 // CHECK-LABEL: tmuli
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
-// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into vector<16x16xi32>
-// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, vector<16x16xi32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
+// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into <16x16xi32>
+// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : <16x64xi8>, <16x64xi8>, <16x16xi32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, <16x16xi32>
 // Verify the parsing/printing of the sign-extension annotation.
 // CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
 // CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
 // CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
 func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
-  %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into <16x64xi8>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into <16x16xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
   // Verify the various `zext` combinations.
-  %5 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  %7 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %5 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  %7 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
   return
 }

>From 3a07dfc460a533fbf86440655cef736d93357f88 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Mon, 7 Oct 2024 16:46:03 +0000
Subject: [PATCH 2/4] Address review comments: qualify !amx.tile types and add named tile constraints.

Signed-off-by: Ilya Enkovich <ilya.enkovich at intel.com>
---
 .../Conversion/LLVMCommon/TypeConverter.h     |  2 +-
 mlir/include/mlir/Dialect/AMX/AMX.td          | 67 +++++++++++--------
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  |  4 +-
 mlir/test/Dialect/AMX/invalid.mlir            | 18 ++---
 mlir/test/Dialect/AMX/legalize-for-llvm.mlir  | 42 ++++++------
 mlir/test/Dialect/AMX/roundtrip.mlir          | 50 +++++++-------
 6 files changed, 97 insertions(+), 86 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index bd4b3e73f074108..6e01f70ee351235 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -259,7 +259,7 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a 1D vector type into an LLVM vector type.
   FailureOr<Type> convertVectorType(VectorType type) const;
 
-  /// Convert AMX tile type x86_amx type.
+  /// Convert an AMX tile type to the x86_amx type.
   Type convertAMXTileType(amx::TileType type) const;
 
   /// Options for customizing the llvm lowering.
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 8ef5ac25fbbddff..8a51df1ea183fc7 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -73,7 +73,7 @@ def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
   let cppFunctionName = "isValidTileTypeElementType";
 }
 
-def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSemantics]> {
+def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
   let summary = "AMX 2D tile to be used by AMX opertaions.";
 
   let description = [{
@@ -111,15 +111,23 @@ def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSeman
   let skipDefaultBuilders = 1;
 }
 
-def IsAMXTilePred : CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">;
-
-def IsAMX2DTilePred : And<[IsAMXTilePred,
+def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
   CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
 
-class AMX2DTileOf<list<Type> allowedTypes> :
-  ShapedContainerType<allowedTypes, IsAMX2DTilePred, "tile",
+class AMXTileOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
                       "::mlir::amx::TileType">;
 
+def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
+
+def AMXTileF32 : AMXTileOf<[F32]>;
+
+def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
+
+def AMXTileI32 : AMXTileOf<[I32]>;
+
+def AMXTileI8 : AMXTileOf<[I8]>;
+
 //===----------------------------------------------------------------------===//
 // AMX Op and IntrOp definitions.
 //===----------------------------------------------------------------------===//
@@ -151,17 +159,16 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
     Example:
 
     ```mlir
-      %0 = amx.tile_zero : <16x16xbf16>
+      %0 = amx.tile_zero : !amx.tile<16x16xbf16>
     ```
   }];
-  let results = (outs
-    AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
+  let results = (outs AnyAMXTile:$res);
   let extraClassDeclaration = [{
     TileType getTileType() {
       return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
-  let assemblyFormat = "attr-dict `:` type($res)";
+  let assemblyFormat = "attr-dict `:` qualified(type($res))";
   let hasVerifier = 1;
 }
 
@@ -180,13 +187,12 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
     Example:
 
     ```mlir
-      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into <16x64xi8>
+      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
     ```
   }];
   let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
                    Variadic<Index>:$indices);
-  let results = (outs
-    AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
+  let results = (outs AnyAMXTile:$res);
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
@@ -196,7 +202,7 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
     }
   }];
   let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
-                       "type($base) `into` type($res)";
+                       "type($base) `into` qualified(type($res))";
   let hasVerifier = 1;
 }
 
@@ -211,12 +217,12 @@ def TileStoreOp : AMX_Op<"tile_store"> {
     Example:
 
     ```mlir
-      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, <16x64xi8>
+      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
     ```
   }];
   let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
                    Variadic<Index>:$indices,
-                   AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$val);
+                   AnyAMXTile:$val);
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
@@ -226,7 +232,7 @@ def TileStoreOp : AMX_Op<"tile_store"> {
     }
   }];
   let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
-                       "type($base) `,` type($val)";
+                       "type($base) `,` qualified(type($val))";
   let hasVerifier = 1;
 }
 
@@ -246,13 +252,14 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
     Example:
 
     ```mlir
-      %0 = amx.tile_mulf %a, %b, %c : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+      %0 = amx.tile_mulf %a, %b, %c
+        : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
     ```
   }];
-  let arguments = (ins AMX2DTileOf<[F16, BF16]>:$lhs,
-                       AMX2DTileOf<[F16, BF16]>:$rhs,
-                       AMX2DTileOf<[F32]>:$acc);
-  let results = (outs AMX2DTileOf<[F32]>:$res);
+  let arguments = (ins AMXTileF16OrBF16:$lhs,
+                       AMXTileF16OrBF16:$rhs,
+                       AMXTileF32:$acc);
+  let results = (outs AMXTileF32:$res);
   let extraClassDeclaration = [{
     TileType getLhsTileType() {
       return ::llvm::cast<TileType>(getLhs().getType());
@@ -265,7 +272,8 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
     }
   }];
   let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
-                       "type($lhs) `,` type($rhs) `,` type($acc) ";
+                       "qualified(type($lhs)) `,` qualified(type($rhs))"
+                       " `,` qualified(type($acc)) ";
   let hasVerifier = 1;
 }
 
@@ -284,16 +292,17 @@ def TileMulIOp : AMX_Op<"tile_muli", [
     Example:
 
     ```mlir
-      %0 = amx.tile_muli %a zext, %b zext, %c : <16x64xi8>, <16x64xi8>, <16x16xi32>
+      %0 = amx.tile_muli %a zext, %b zext, %c
+        : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
     ```
   }];
-  let arguments = (ins AMX2DTileOf<[I8]>:$lhs,
-                       AMX2DTileOf<[I8]>:$rhs,
-                       AMX2DTileOf<[I32]>:$acc,
+  let arguments = (ins AMXTileI8:$lhs,
+                       AMXTileI8:$rhs,
+                       AMXTileI32:$acc,
                        UnitAttr:$isZextLhs,
                        UnitAttr:$isZextRhs
                        );
-  let results = (outs AMX2DTileOf<[I32]>:$res);
+  let results = (outs AMXTileI32:$res);
   let extraClassDeclaration = [{
     TileType getLhsTileType() {
       return ::llvm::cast<TileType>(getLhs().getType());
@@ -306,7 +315,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [
     }
   }];
   let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
-                       "type($lhs) `,` type($rhs) `,` type($acc) ";
+                       "qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 3172d93ef65bc99..14fb8934040de9e 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -158,10 +158,12 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
       rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
           op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
           adaptor.getLhs(), adaptor.getRhs());
-    else
+    else if (aType.getElementType().isF16())
       rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
           op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
           adaptor.getLhs(), adaptor.getRhs());
+    else
+      llvm_unreachable("Unexpected element type for amx.tile_mulf");
     return success();
   }
 };
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index 46f3b7448148b36..8febe1605e33a4a 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -4,21 +4,21 @@
 
 func.func @rowheight() {
   // expected-error at +1 {{'amx.tile_zero' op bad row height: 17}}
-  %0 = amx.tile_zero : <17x16xbf16>
+  %0 = amx.tile_zero : !amx.tile<17x16xbf16>
 }
 
 // -----
 
 func.func @colwidth() {
   // expected-error at +1 {{'amx.tile_zero' op bad column width: 65}}
-  %0 = amx.tile_zero : <16x65xi8>
+  %0 = amx.tile_zero : !amx.tile<16x65xi8>
 }
 
 // -----
 
 func.func @col4bytemultiple() {
   // expected-error at +1 {{'amx.tile_zero' op bad column width: 5}}
-  %0 = amx.tile_zero : <16x5xi8>
+  %0 = amx.tile_zero : !amx.tile<16x5xi8>
 }
 
 // -----
@@ -26,7 +26,7 @@ func.func @col4bytemultiple() {
 func.func @memtilesize(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error at +1 {{'amx.tile_load' op bad column width: 68}}
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into <16x17xf32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
 }
 
 // -----
@@ -34,15 +34,15 @@ func.func @memtilesize(%arg0: memref<?x?xf32>) {
 func.func @memindexsize(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error at +1 {{'amx.tile_load' op requires 2 indices}}
-  %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into <16x16xf32>
+  %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
 }
 
 // -----
 
 func.func @multsize() {
-  %0 = amx.tile_zero : <8x8xbf16>
-  %1 = amx.tile_zero : <8x8xbf16>
-  %2 = amx.tile_zero : <4x4xf32>
+  %0 = amx.tile_zero : !amx.tile<8x8xbf16>
+  %1 = amx.tile_zero : !amx.tile<8x8xbf16>
+  %2 = amx.tile_zero : !amx.tile<4x4xf32>
   // expected-error at +1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
-  %3 = amx.tile_mulf %0, %1, %2 : <8x8xbf16>, <8x8xbf16>, <4x4xf32>
+  %3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<8x8xbf16>, !amx.tile<4x4xf32>
 }
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 82fb0a9c99f71de..cb827e6f8eca26a 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -14,17 +14,17 @@
 // CHECK: amx.tilestored64
 func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : <16x64xi8>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into <16x16xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
-  %5 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
-  amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, <16x16xi32>
-  %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
-  amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, <16x16xi32>
-  %7 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
-  amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, <16x16xi32>
+  %1 = amx.tile_zero : !amx.tile<16x64xi8>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
+  %5 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, !amx.tile<16x16xi32>
+  %6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, !amx.tile<16x16xi32>
+  %7 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, !amx.tile<16x16xi32>
   return
 }
 
@@ -36,11 +36,11 @@ func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
 // CHECK: amx.tilestored64
 func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : <16x32xbf16>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
-  %4 = amx.tile_mulf %1, %2, %3 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
+  %1 = amx.tile_zero : !amx.tile<16x32xbf16>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
+  %4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
   return
 }
 
@@ -52,11 +52,11 @@ func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
 // CHECK: amx.tilestored64
 func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : <16x32xf16>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into <16x32xf16>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
-  %4 = amx.tile_mulf %1, %2, %3 : <16x32xf16>, <16x32xf16>, <16x16xf32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
+  %1 = amx.tile_zero : !amx.tile<16x32xf16>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into !amx.tile<16x32xf16>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
+  %4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
   return
 }
 
diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir
index 62bc071cab1e029..1b7f781ae173d39 100644
--- a/mlir/test/Dialect/AMX/roundtrip.mlir
+++ b/mlir/test/Dialect/AMX/roundtrip.mlir
@@ -1,49 +1,49 @@
 // RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: tzero
-// CHECK: amx.tile_zero : <16x16xbf16>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, <16x16xbf16>
+// CHECK: amx.tile_zero : !amx.tile<16x16xbf16>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, !amx.tile<16x16xbf16>
 func.func @tzero(%arg0: memref<?x?xbf16>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : <16x16xbf16>
-  amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, <16x16xbf16>
+  %1 = amx.tile_zero : !amx.tile<16x16xbf16>
+  amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, !amx.tile<16x16xbf16>
   return
 }
 
 // CHECK-LABEL: tmulf
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into <16x32xbf16>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into <16x16xf32>
-// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, <16x16xf32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into !amx.tile<16x16xf32>
+// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, !amx.tile<16x16xf32>
 func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
-  %3 = amx.tile_mulf %1, %1, %2 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
-  amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, <16x16xf32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
+  %3 = amx.tile_mulf %1, %1, %2 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+  amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, !amx.tile<16x16xf32>
   return
 }
 
 // CHECK-LABEL: tmuli
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
-// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into <16x16xi32>
-// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : <16x64xi8>, <16x64xi8>, <16x16xi32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, <16x16xi32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
+// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into !amx.tile<16x16xi32>
+// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, !amx.tile<16x16xi32>
 // Verify the parsing/printing of the sign-extension annotation.
 // CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
 // CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
 // CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
 func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into <16x64xi8>
-  %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into <16x16xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
   // Verify the various `zext` combinations.
-  %5 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
-  %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
-  %7 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  %5 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  %6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  %7 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
   return
 }

>From b6802a396db25f12a88cb284c4322c7b42aa3744 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Fri, 1 Nov 2024 21:19:33 +0000
Subject: [PATCH 3/4] Update the new strides test to use !amx.tile types.

---
 mlir/test/Dialect/AMX/legalize-for-llvm.mlir | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index cb827e6f8eca26a..8085f5f59fcaf01 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -79,11 +79,11 @@ func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
 // CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
 func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into vector<16x32xbf16>
-  %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into vector<16x32xbf16>
-  amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, vector<16x32xbf16>
-  amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, vector<16x32xbf16>
-  amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, vector<16x32xbf16>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into !amx.tile<16x32xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into !amx.tile<16x32xbf16>
+  amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, !amx.tile<16x32xbf16>
+  amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
+  amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, !amx.tile<16x32xbf16>
   return
 }

>From 37a98e6b0a9e79f705c24fd75083227604a35ac8 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Fri, 1 Nov 2024 21:31:05 +0000
Subject: [PATCH 4/4] Remove amx::TileType conversion from LLVMTypeConverter.

Signed-off-by: Ilya Enkovich <ilya.enkovich at intel.com>
---
 .../Conversion/LLVMCommon/TypeConverter.h     |  4 ---
 mlir/include/mlir/Dialect/AMX/Transforms.h    |  8 ++++--
 mlir/include/mlir/InitAllExtensions.h         |  2 ++
 mlir/lib/Conversion/LLVMCommon/CMakeLists.txt |  1 -
 .../Conversion/LLVMCommon/TypeConverter.cpp   |  7 ------
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  | 25 ++++++++++++++++++-
 6 files changed, 32 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 6e01f70ee351235..d79b90f840ce836 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -15,7 +15,6 @@
 #define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
 
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
-#include "mlir/Dialect/AMX/AMXDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -259,9 +258,6 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a 1D vector type into an LLVM vector type.
   FailureOr<Type> convertVectorType(VectorType type) const;
 
-  /// Convert an AMX tile type to the x86_amx type.
-  Type convertAMXTileType(amx::TileType type) const;
-
   /// Options for customizing the llvm lowering.
   LowerToLLVMOptions options;
 
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index d00ac52e274f9f7..7391ec2ff6b14a4 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -14,16 +14,20 @@ namespace mlir {
 class LLVMConversionTarget;
 class LLVMTypeConverter;
 class RewritePatternSet;
+class DialectRegistry;
 
 /// Collect a set of patterns to lower AMX ops to ops that map to LLVM
 /// intrinsics.
-void populateAMXLegalizeForLLVMExportPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
+                                              RewritePatternSet &patterns);
 
 /// Configure the target to support lowering AMX ops to ops that map to LLVM
 /// intrinsics.
 void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
 
+/// Register LLVM conversion interface for AMX dialect.
+void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AMX_TRANSFORMS_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 2a241fa4b192fee..1f2ef26b450701a 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -24,6 +24,7 @@
 #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
 #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
@@ -70,6 +71,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   registerConvertNVVMToLLVMInterface(registry);
   registerConvertOpenMPToLLVMInterface(registry);
   ub::registerConvertUBToLLVMInterface(registry);
+  registerConvertAMXToLLVMInterface(registry);
 
   // Register all transform dialect extensions.
   affine::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 39199e4affccfac..568d9339aaabcb4 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -12,7 +12,6 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
   Core
 
   LINK_LIBS PUBLIC
-  MLIRAMXDialect
   MLIRIR
   MLIRLLVMDialect
   MLIRSupport
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 07e6dd55f855191..4e7758bf46d9cfc 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -67,7 +67,6 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
       return std::nullopt;
     return llvmType;
   });
-  addConversion([&](amx::TileType type) { return convertAMXTileType(type); });
 
   // LLVM-compatible types are legal, so add a pass-through conversion. Do this
   // before the conversions below since conversions are attempted in reverse
@@ -595,12 +594,6 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   return vectorType;
 }
 
-/// Convert an AMX tile type to LLVM x86_amx type.
-/// Shape and element type of the tile are ignored.
-Type LLVMTypeConverter::convertAMXTileType(amx::TileType type) const {
-  return LLVM::LLVMX86AMXType::get(&getContext());
-}
-
 /// Convert a type in the context of the default or bare pointer calling
 /// convention. Calling convention sensitive types, such as MemRefType and
 /// UnrankedMemRefType, are converted following the specific rules for the
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 14fb8934040de9e..4eac371d4c1ae4f 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/AMX/Transforms.h"
 
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/AMX/AMXDialect.h"
@@ -208,9 +209,12 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
 } // namespace
 
 void mlir::populateAMXLegalizeForLLVMExportPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
                TileMulFConversion, TileMulIConversion>(converter);
+  converter.addConversion([&](amx::TileType type) {
+    return LLVM::LLVMX86AMXType::get(&converter.getContext());
+  });
 }
 
 void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
@@ -220,3 +224,22 @@ void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
   target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
                       TileMulFOp>();
 }
+
+namespace {
+/// Implement the interface to convert AMX to LLVM.
+struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
+    dialect->addInterfaces<AMXToLLVMDialectInterface>();
+  });
+}



More information about the Mlir-commits mailing list