[Mlir-commits] [mlir] 2f743ac - [MLIR] [AMX] Utilize x86_amx type for AMX dialect in MLIR. (#111197)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 6 06:30:58 PST 2024


Author: Ilya Enkovich
Date: 2024-11-06T14:30:53Z
New Revision: 2f743ac52e945e155ff3cb1f8ca5287b306b831e

URL: https://github.com/llvm/llvm-project/commit/2f743ac52e945e155ff3cb1f8ca5287b306b831e
DIFF: https://github.com/llvm/llvm-project/commit/2f743ac52e945e155ff3cb1f8ca5287b306b831e.diff

LOG: [MLIR] [AMX] Utilize x86_amx type for AMX dialect in MLIR. (#111197)

This patch is intended to resolve #109481 and improve the usability of
the AMX dialect.

In LLVM IR, AMX intrinsics use `x86_amx`, which is one of the primitive
types. This type is supposed to be used for AMX intrinsic calls only and
not in any other operations. The AMX dialect of MLIR uses regular 2D
vector types, which are then lowered to arrays of vectors in the LLVMIR
dialect. This creates an inconsistency between the types used in the
LLVMIR dialect and those used in LLVM IR. Translation of AMX intrinsic
calls to LLVM IR doesn't require result types to match, and that is where
tile load and mul operation results get the `x86_amx` type. This works in
very simple cases, when mul and tile store operations directly consume
the result of another AMX intrinsic call, but it doesn't work when an
argument is a block argument (phi node).

In addition to translation problems, this inconsistency between the types
used in MLIR and LLVM IR makes MLIR verification and transformation
quite problematic. Both `amx.tile_load` and `vector.transfer_read` can
load values of the same type, but only one of them can be used in AMX
operations. In general, by looking at the type of a value, we cannot
determine whether it can be used only in AMX operations or, conversely,
only in operations other than AMX ones.

To remove this inconsistency and make AMX operations more explicit about
their limitations, I propose adding an `LLVMX86AMXType` type to the LLVMIR
dialect to match the `x86_amx` type in LLVM IR, and introducing an
`amx::TileType` to be used by AMX operations in MLIR. This resolves
translation problems for AMX usage with phi nodes and provides proper
type verification in MLIR for AMX operations.

P.S. This patch also adds the missing FP16 support (`amx.tile_mulf` on f16
tiles, lowered to `tdpfp16ps`). It's trivial but unrelated to the type
system changes, so let me know if I should submit it separately.

---------

Signed-off-by: Ilya Enkovich <ilya.enkovich at intel.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/AMX/AMX.td
    mlir/include/mlir/Dialect/AMX/AMXDialect.h
    mlir/include/mlir/Dialect/AMX/Transforms.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
    mlir/include/mlir/InitAllExtensions.h
    mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
    mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
    mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
    mlir/test/Dialect/AMX/invalid.mlir
    mlir/test/Dialect/AMX/legalize-for-llvm.mlir
    mlir/test/Dialect/AMX/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index fcc8d169eab5ac..8a51df1ea183fc 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -30,6 +30,8 @@
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypes.td"
 
 //===----------------------------------------------------------------------===//
 // AMX dialect definition.
@@ -55,8 +57,77 @@ def AMX_Dialect : Dialect {
     For details, see the Intel documentation:
     https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
   }];
+  let useDefaultTypePrinterParser = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// AMX Tile definition.
+//===----------------------------------------------------------------------===//
+
+class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<AMX_Dialect, typeName, traits> {
+  let mnemonic = typeMnemonic;
+}
+
+def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
+  let cppFunctionName = "isValidTileTypeElementType";
+}
+
+def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
+  let summary = "AMX 2D tile to be used by AMX opertaions.";
+
+  let description = [{
+    This type is used to represent values in AMX tile registers. All AMX operations
+    work on AMX tiles and these tiles cannot be used in other operations directly.
+    LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
+    element type for IR verification and lowering to LLVMIR dialect.
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$shape,
+    AMX_TileTypeElementType:$elementType
+  );
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
+      return $_get(elementType.getContext(), shape, elementType);
+    }]>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns if this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Clone this tile type with the given shape and element type. If the
+    /// provided shape is `std::nullopt`, the current shape of the type is used.
+    TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                       Type elementType) const {
+      return get(shape.value_or(getShape()), elementType);
+    }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let skipDefaultBuilders = 1;
+}
+
+def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
+  CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
+
+class AMXTileOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
+                      "::mlir::amx::TileType">;
+
+def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
+
+def AMXTileF32 : AMXTileOf<[F32]>;
+
+def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
+
+def AMXTileI32 : AMXTileOf<[I32]>;
+
+def AMXTileI8 : AMXTileOf<[I8]>;
+
 //===----------------------------------------------------------------------===//
 // AMX Op and IntrOp definitions.
 //===----------------------------------------------------------------------===//
@@ -88,17 +159,16 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
     Example:
 
     ```mlir
-      %0 = amx.tile_zero : vector<16x16xbf16>
+      %0 = amx.tile_zero : !amx.tile<16x16xbf16>
     ```
   }];
-  let results = (outs
-    VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+  let results = (outs AnyAMXTile:$res);
   let extraClassDeclaration = [{
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getRes().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
-  let assemblyFormat = "attr-dict `:` type($res)";
+  let assemblyFormat = "attr-dict `:` qualified(type($res))";
   let hasVerifier = 1;
 }
 
@@ -117,23 +187,22 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
     Example:
 
     ```mlir
-      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
+      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
     ```
   }];
   let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
                    Variadic<Index>:$indices);
-  let results = (outs
-    VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+  let results = (outs AnyAMXTile:$res);
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
     }
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getRes().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
   let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
-                       "type($base) `into` type($res)";
+                       "type($base) `into` qualified(type($res))";
   let hasVerifier = 1;
 }
 
@@ -148,22 +217,22 @@ def TileStoreOp : AMX_Op<"tile_store"> {
     Example:
 
     ```mlir
-      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
+      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
     ```
   }];
   let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
                    Variadic<Index>:$indices,
-                   VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
+                   AnyAMXTile:$val);
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
     }
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getVal().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getVal().getType());
     }
   }];
   let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
-                       "type($base) `,` type($val)";
+                       "type($base) `,` qualified(type($val))";
   let hasVerifier = 1;
 }
 
@@ -184,26 +253,27 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
 
     ```mlir
       %0 = amx.tile_mulf %a, %b, %c
-        : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+        : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
     ```
   }];
-  let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
-                       VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
-                       VectorOfRankAndType<[2], [F32, BF16]>:$acc);
-  let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
+  let arguments = (ins AMXTileF16OrBF16:$lhs,
+                       AMXTileF16OrBF16:$rhs,
+                       AMXTileF32:$acc);
+  let results = (outs AMXTileF32:$res);
   let extraClassDeclaration = [{
-    VectorType getLhsVectorType() {
-      return ::llvm::cast<VectorType>(getLhs().getType());
+    TileType getLhsTileType() {
+      return ::llvm::cast<TileType>(getLhs().getType());
     }
-    VectorType getRhsVectorType() {
-      return ::llvm::cast<VectorType>(getRhs().getType());
+    TileType getRhsTileType() {
+      return ::llvm::cast<TileType>(getRhs().getType());
     }
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getRes().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
   let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
-                       "type($lhs) `,` type($rhs) `,` type($acc) ";
+                       "qualified(type($lhs)) `,` qualified(type($rhs))"
+                       " `,` qualified(type($acc)) ";
   let hasVerifier = 1;
 }
 
@@ -223,29 +293,29 @@ def TileMulIOp : AMX_Op<"tile_muli", [
 
     ```mlir
       %0 = amx.tile_muli %a zext, %b zext, %c
-        : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+        : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
     ```
   }];
-  let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
-                       VectorOfRankAndType<[2], [I32, I8]>:$rhs,
-                       VectorOfRankAndType<[2], [I32, I8]>:$acc,
+  let arguments = (ins AMXTileI8:$lhs,
+                       AMXTileI8:$rhs,
+                       AMXTileI32:$acc,
                        UnitAttr:$isZextLhs,
                        UnitAttr:$isZextRhs
                        );
-  let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
+  let results = (outs AMXTileI32:$res);
   let extraClassDeclaration = [{
-    VectorType getLhsVectorType() {
-      return ::llvm::cast<VectorType>(getLhs().getType());
+    TileType getLhsTileType() {
+      return ::llvm::cast<TileType>(getLhs().getType());
     }
-    VectorType getRhsVectorType() {
-      return ::llvm::cast<VectorType>(getRhs().getType());
+    TileType getRhsTileType() {
+      return ::llvm::cast<TileType>(getRhs().getType());
     }
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getRes().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
   let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
-                       "type($lhs) `,` type($rhs) `,` type($acc) ";
+                       "qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
   let hasVerifier = 1;
 }
 
@@ -286,6 +356,12 @@ def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
                  AnyInteger,
 		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
 
+// Dot product of f16 tiles into f32 tile.
+def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
+  Arguments<(ins AnyInteger,
+                 AnyInteger,
+		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
 // Dot product of i8 tiles into i32 tile (with sign/sign extension).
 def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
   Arguments<(ins AnyInteger,

diff  --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
index 47c92479814dea..c0553ad8733fd4 100644
--- a/mlir/include/mlir/Dialect/AMX/AMXDialect.h
+++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
@@ -21,6 +21,9 @@
 
 #include "mlir/Dialect/AMX/AMXDialect.h.inc"
 
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/AMX/AMXTypes.h.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/AMX/AMX.h.inc"
 

diff  --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index d00ac52e274f9f..7391ec2ff6b14a 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -14,16 +14,20 @@ namespace mlir {
 class LLVMConversionTarget;
 class LLVMTypeConverter;
 class RewritePatternSet;
+class DialectRegistry;
 
 /// Collect a set of patterns to lower AMX ops to ops that map to LLVM
 /// intrinsics.
-void populateAMXLegalizeForLLVMExportPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
+                                              RewritePatternSet &patterns);
 
 /// Configure the target to support lowering AMX ops to ops that map to LLVM
 /// intrinsics.
 void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
 
+/// Register LLVM conversion interface for AMX dialect.
+void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AMX_TRANSFORMS_H

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 8f9c2f2f8a0b44..09dd0919c318fb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -250,4 +250,16 @@ def LLVMTargetExtType : LLVMType<"LLVMTargetExt", "target"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// LLVMX86AMXType
+//===----------------------------------------------------------------------===//
+
+def LLVMX86AMXType : LLVMType<"LLVMX86AMX", "x86_amx"> {
+  let summary = "LLVM x86_amx type.";
+  let description = [{
+    The x86_amx type represents a value held in an AMX tile register on an x86
+    machine. Can only be used in AMX intrinsics calls.
+  }];
+}
+
 #endif // LLVMTYPES_TD

diff  --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 2a241fa4b192fe..1f2ef26b450701 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -24,6 +24,7 @@
 #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
 #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
@@ -70,6 +71,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   registerConvertNVVMToLLVMInterface(registry);
   registerConvertOpenMPToLLVMInterface(registry);
   ub::registerConvertUBToLLVMInterface(registry);
+  registerConvertAMXToLLVMInterface(registry);
 
   // Register all transform dialect extensions.
   affine::registerTransformDialectExtension(registry);

diff  --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index f0e434407c8a2d..829f48e223383e 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -13,14 +13,22 @@
 #include "mlir/Dialect/AMX/AMXDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TypeUtilities.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+
 using namespace mlir;
 
 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
 
 void amx::AMXDialect::initialize() {
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
+      >();
+
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/AMX/AMX.cpp.inc"
@@ -28,7 +36,7 @@ void amx::AMXDialect::initialize() {
 }
 
 /// Verify that AMX supports the implied tile shape.
-static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
+static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
   const unsigned kMaxRows = 16;
   const unsigned kBitsPerRow = 64 * 8;
   unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
@@ -40,8 +48,8 @@ static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
 }
 
 /// Verify that AMX supports the multiplication.
-static LogicalResult verifyMultShape(Operation *op, VectorType atp,
-                                     VectorType btp, VectorType ctp,
+static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
+                                     amx::TileType btp, amx::TileType ctp,
                                      unsigned scale) {
   unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
   unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
@@ -53,27 +61,27 @@ static LogicalResult verifyMultShape(Operation *op, VectorType atp,
 }
 
 LogicalResult amx::TileZeroOp::verify() {
-  return verifyTileSize(*this, getVectorType());
+  return verifyTileSize(*this, getTileType());
 }
 
 LogicalResult amx::TileLoadOp::verify() {
   unsigned rank = getMemRefType().getRank();
   if (getIndices().size() != rank)
     return emitOpError("requires ") << rank << " indices";
-  return verifyTileSize(*this, getVectorType());
+  return verifyTileSize(*this, getTileType());
 }
 
 LogicalResult amx::TileStoreOp::verify() {
   unsigned rank = getMemRefType().getRank();
   if (getIndices().size() != rank)
     return emitOpError("requires ") << rank << " indices";
-  return verifyTileSize(*this, getVectorType());
+  return verifyTileSize(*this, getTileType());
 }
 
 LogicalResult amx::TileMulFOp::verify() {
-  VectorType aType = getLhsVectorType();
-  VectorType bType = getRhsVectorType();
-  VectorType cType = getVectorType();
+  amx::TileType aType = getLhsTileType();
+  amx::TileType bType = getRhsTileType();
+  amx::TileType cType = getTileType();
   if (failed(verifyTileSize(*this, aType)) ||
       failed(verifyTileSize(*this, bType)) ||
       failed(verifyTileSize(*this, cType)) ||
@@ -82,15 +90,15 @@ LogicalResult amx::TileMulFOp::verify() {
   Type ta = aType.getElementType();
   Type tb = bType.getElementType();
   Type tc = cType.getElementType();
-  if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
+  if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
     return emitOpError("unsupported type combination");
   return success();
 }
 
 LogicalResult amx::TileMulIOp::verify() {
-  VectorType aType = getLhsVectorType();
-  VectorType bType = getRhsVectorType();
-  VectorType cType = getVectorType();
+  amx::TileType aType = getLhsTileType();
+  amx::TileType bType = getRhsTileType();
+  amx::TileType cType = getTileType();
   if (failed(verifyTileSize(*this, aType)) ||
       failed(verifyTileSize(*this, bType)) ||
       failed(verifyTileSize(*this, cType)) ||
@@ -104,5 +112,34 @@ LogicalResult amx::TileMulIOp::verify() {
   return success();
 }
 
+Type amx::TileType::parse(AsmParser &parser) {
+  if (parser.parseLess())
+    return nullptr;
+
+  SmallVector<int64_t, 2> shape;
+  if (parser.parseDimensionList(shape, false, true))
+    return nullptr;
+
+  Type elementType;
+  if (parser.parseType(elementType))
+    return nullptr;
+
+  if (parser.parseGreater())
+    return nullptr;
+
+  return TileType::get(shape, elementType);
+}
+
+void amx::TileType::print(AsmPrinter &os) const {
+  os << "<";
+  os.printDimensionList(getShape());
+  os << 'x';
+  os.printType(getElementType());
+  os << '>';
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/AMX/AMX.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"

diff  --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 46c7bfbf3ffcc2..4eac371d4c1ae4 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/AMX/Transforms.h"
 
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/AMX/AMXDialect.h"
@@ -25,13 +26,13 @@ namespace {
 /// The second dimensions needs to be scaled by the number of bytes.
 std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
                                      const LLVMTypeConverter &typeConverter,
-                                     VectorType vType, Location loc) {
+                                     amx::TileType tType, Location loc) {
   Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
-  unsigned width = vType.getElementType().getIntOrFloatBitWidth();
+  unsigned width = tType.getElementType().getIntOrFloatBitWidth();
   assert(llvm::isPowerOf2_64(width) && width >= 8);
   unsigned bytes = width >> 3;
-  auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0));
-  auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes);
+  auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
+  auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
   return std::make_pair(
       rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
       rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
@@ -76,12 +77,12 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
   LogicalResult
   matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType vType = op.getVectorType();
+    amx::TileType tType = op.getTileType();
     // Determine m x n tile sizes.
     std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
     // Replace operation with intrinsic.
-    Type resType = typeConverter->convertType(vType);
+    Type resType = typeConverter->convertType(tType);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
                                                        tsz.second);
     return success();
@@ -95,10 +96,10 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
   matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MemRefType mType = op.getMemRefType();
-    VectorType vType = op.getVectorType();
+    amx::TileType tType = op.getTileType();
     // Determine m x n tile sizes.
     std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
     // Determine stride.
     auto stride = getStride(rewriter, *getTypeConverter(), mType,
                             adaptor.getBase(), op.getLoc());
@@ -107,7 +108,7 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
     // Replace operation with intrinsic.
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    Type resType = typeConverter->convertType(vType);
+    Type resType = typeConverter->convertType(tType);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
         op, resType, tsz.first, tsz.second, ptr, stride.value());
     return success();
@@ -121,10 +122,10 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
   matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MemRefType mType = op.getMemRefType();
-    VectorType vType = op.getVectorType();
+    amx::TileType tType = op.getTileType();
     // Determine m x n tile sizes.
     std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
     // Determine stride.
     auto stride = getStride(rewriter, *getTypeConverter(), mType,
                             adaptor.getBase(), op.getLoc());
@@ -144,9 +145,9 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
   LogicalResult
   matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType aType = op.getLhsVectorType();
-    VectorType bType = op.getRhsVectorType();
-    VectorType cType = op.getVectorType();
+    amx::TileType aType = op.getLhsTileType();
+    amx::TileType bType = op.getRhsTileType();
+    amx::TileType cType = op.getTileType();
     // Determine m x n x k tile sizes.
     std::pair<Value, Value> tsza =
         getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
@@ -154,9 +155,16 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
         getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
     // Replace operation with intrinsic.
     Type resType = typeConverter->convertType(cType);
-    rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
-        op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
-        adaptor.getLhs(), adaptor.getRhs());
+    if (aType.getElementType().isBF16())
+      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
+          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
+          adaptor.getLhs(), adaptor.getRhs());
+    else if (aType.getElementType().isF16())
+      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
+          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
+          adaptor.getLhs(), adaptor.getRhs());
+    else
+      llvm_unreachable("Unexpected element type for amx.mulf");
     return success();
   }
 };
@@ -166,9 +174,9 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
   LogicalResult
   matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType aType = op.getLhsVectorType();
-    VectorType bType = op.getRhsVectorType();
-    VectorType cType = op.getVectorType();
+    amx::TileType aType = op.getLhsTileType();
+    amx::TileType bType = op.getRhsTileType();
+    amx::TileType cType = op.getTileType();
     // Determine m x n x k tile sizes.
     std::pair<Value, Value> tsza =
         getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
@@ -201,15 +209,37 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
 } // namespace
 
 void mlir::populateAMXLegalizeForLLVMExportPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
                TileMulFConversion, TileMulIConversion>(converter);
+  converter.addConversion([&](amx::TileType type) {
+    return LLVM::LLVMX86AMXType::get(&converter.getContext());
+  });
 }
 
 void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
   target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
-                    x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
-                    x86_amx_tdpbusd, x86_amx_tdpbuud>();
+                    x86_amx_tdpbf16ps, x86_amx_tdpfp16ps, x86_amx_tdpbssd,
+                    x86_amx_tdpbsud, x86_amx_tdpbusd, x86_amx_tdpbuud>();
   target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
                       TileMulFOp>();
 }
+
+namespace {
+/// Implement the interface to convert AMX to LLVM.
+struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
+    dialect->addInterfaces<AMXToLLVMDialectInterface>();
+  });
+}

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index c4708d826f2b38..9537f7c40dd4be 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -45,6 +45,7 @@ static StringRef getTypeKeyword(Type type) {
       .Case<LLVMArrayType>([&](Type) { return "array"; })
       .Case<LLVMStructType>([&](Type) { return "struct"; })
       .Case<LLVMTargetExtType>([&](Type) { return "target"; })
+      .Case<LLVMX86AMXType>([&](Type) { return "x86_amx"; })
       .Default([](Type) -> StringRef {
         llvm_unreachable("unexpected 'llvm' type kind");
       });
@@ -317,6 +318,7 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
       .Case("array", [&] { return LLVMArrayType::parse(parser); })
       .Case("struct", [&] { return parseStructType(parser); })
       .Case("target", [&] { return LLVMTargetExtType::parse(parser); })
+      .Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); })
       .Default([&] {
         parser.emitError(keyLoc) << "unknown LLVM type: " << key;
         return Type();

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 7f10a15ff31ff9..1bed3fa48b30d7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -780,7 +780,8 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
       LLVMFixedVectorType,
       LLVMScalableVectorType,
       LLVMTargetExtType,
-      LLVMVoidType
+      LLVMVoidType,
+      LLVMX86AMXType
     >(type)) {
     // clang-format on
     return true;
@@ -842,7 +843,8 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
             LLVMMetadataType,
             LLVMPPCFP128Type,
             LLVMTokenType,
-            LLVMVoidType
+            LLVMVoidType,
+            LLVMX86AMXType
           >([](Type) { return true; })
           // clang-format on
           .Default([](Type) { return false; });

diff  --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
index db184ae8e6e833..ea990ca7aefbe0 100644
--- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
@@ -63,6 +63,8 @@ class TypeFromLLVMIRTranslatorImpl {
       return Float128Type::get(&context);
     if (type->isX86_FP80Ty())
       return Float80Type::get(&context);
+    if (type->isX86_AMXTy())
+      return LLVM::LLVMX86AMXType::get(&context);
     if (type->isPPC_FP128Ty())
       return LLVM::LLVMPPCFP128Type::get(&context);
     if (type->isLabelTy())

diff  --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
index 65915027238801..c7a533eddce84b 100644
--- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
@@ -67,6 +67,9 @@ class TypeToLLVMIRTranslatorImpl {
             .Case([this](LLVM::LLVMMetadataType) {
               return llvm::Type::getMetadataTy(context);
             })
+            .Case([this](LLVM::LLVMX86AMXType) {
+              return llvm::Type::getX86_AMXTy(context);
+            })
             .Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType,
                   LLVM::LLVMPointerType, LLVM::LLVMStructType,
                   LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType,

diff  --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index 25d57353c905d0..8febe1605e33a4 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -4,21 +4,21 @@
 
 func.func @rowheight() {
   // expected-error@+1 {{'amx.tile_zero' op bad row height: 17}}
-  %0 = amx.tile_zero : vector<17x16xbf16>
+  %0 = amx.tile_zero : !amx.tile<17x16xbf16>
 }
 
 // -----
 
 func.func @colwidth() {
   // expected-error@+1 {{'amx.tile_zero' op bad column width: 65}}
-  %0 = amx.tile_zero : vector<16x65xi8>
+  %0 = amx.tile_zero : !amx.tile<16x65xi8>
 }
 
 // -----
 
 func.func @col4bytemultiple() {
   // expected-error@+1 {{'amx.tile_zero' op bad column width: 5}}
-  %0 = amx.tile_zero : vector<16x5xi8>
+  %0 = amx.tile_zero : !amx.tile<16x5xi8>
 }
 
 // -----
@@ -26,7 +26,7 @@ func.func @col4bytemultiple() {
 func.func @memtilesize(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error@+1 {{'amx.tile_load' op bad column width: 68}}
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into vector<16x17xf32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
 }
 
 // -----
@@ -34,15 +34,15 @@ func.func @memtilesize(%arg0: memref<?x?xf32>) {
 func.func @memindexsize(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error@+1 {{'amx.tile_load' op requires 2 indices}}
-  %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into vector<16x16xf32>
+  %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
 }
 
 // -----
 
 func.func @multsize() {
-  %0 = amx.tile_zero : vector<8x8xbf16>
-  %1 = amx.tile_zero : vector<8x8xbf16>
-  %2 = amx.tile_zero : vector<4x4xf32>
+  %0 = amx.tile_zero : !amx.tile<8x8xbf16>
+  %1 = amx.tile_zero : !amx.tile<8x8xbf16>
+  %2 = amx.tile_zero : !amx.tile<4x4xf32>
   // expected-error@+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
-  %3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32>
+  %3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<8x8xbf16>, !amx.tile<4x4xf32>
 }

diff  --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 3cacbd0044f825..8085f5f59fcaf0 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -14,33 +14,49 @@
 // CHECK: amx.tilestored64
 func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : vector<16x64xi8>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
-  %5 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, vector<16x16xi32>
-  %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, vector<16x16xi32>
-  %7 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, vector<16x16xi32>
+  %1 = amx.tile_zero : !amx.tile<16x64xi8>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
+  %5 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, !amx.tile<16x16xi32>
+  %6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, !amx.tile<16x16xi32>
+  %7 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, !amx.tile<16x16xi32>
   return
 }
 
-// CHECK-LABEL: mulf(
+// CHECK-LABEL: mulbf16(
 // CHECK: amx.tilezero
 // CHECK: amx.tileloadd64
 // CHECK: amx.tileloadd64
 // CHECK: amx.tdpbf16ps
 // CHECK: amx.tilestored64
-func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
+func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : vector<16x32xbf16>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
-  %4 = amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
+  %1 = amx.tile_zero : !amx.tile<16x32xbf16>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
+  %4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
+  return
+}
+
+// CHECK-LABEL: mulfp16(
+// CHECK: amx.tilezero
+// CHECK: amx.tileloadd64
+// CHECK: amx.tileloadd64
+// CHECK: amx.tdpfp16ps
+// CHECK: amx.tilestored64
+func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
+  %0 = arith.constant 0 : index
+  %1 = amx.tile_zero : !amx.tile<16x32xf16>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into !amx.tile<16x32xf16>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
+  %4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
   return
 }
 
@@ -63,11 +79,11 @@ func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
 // CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
 func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into vector<16x32xbf16>
-  %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into vector<16x32xbf16>
-  amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, vector<16x32xbf16>
-  amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, vector<16x32xbf16>
-  amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, vector<16x32xbf16>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into !amx.tile<16x32xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into !amx.tile<16x32xbf16>
+  amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, !amx.tile<16x32xbf16>
+  amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
+  amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, !amx.tile<16x32xbf16>
   return
 }

diff  --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir
index f2ac5e47f6c357..1b7f781ae173d3 100644
--- a/mlir/test/Dialect/AMX/roundtrip.mlir
+++ b/mlir/test/Dialect/AMX/roundtrip.mlir
@@ -1,49 +1,49 @@
 // RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: tzero
-// CHECK: amx.tile_zero : vector<16x16xbf16>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, vector<16x16xbf16>
+// CHECK: amx.tile_zero : !amx.tile<16x16xbf16>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, !amx.tile<16x16xbf16>
 func.func @tzero(%arg0: memref<?x?xbf16>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : vector<16x16xbf16>
-  amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, vector<16x16xbf16>
+  %1 = amx.tile_zero : !amx.tile<16x16xbf16>
+  amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, !amx.tile<16x16xbf16>
   return
 }
 
 // CHECK-LABEL: tmulf
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into vector<16x32xbf16>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into vector<16x16xf32>
-// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, vector<16x16xf32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into !amx.tile<16x16xf32>
+// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, !amx.tile<16x16xf32>
 func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
-  %3 = amx.tile_mulf %1, %1, %2 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
-  amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, vector<16x16xf32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
+  %3 = amx.tile_mulf %1, %1, %2 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+  amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, !amx.tile<16x16xf32>
   return
 }
 
 // CHECK-LABEL: tmuli
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
-// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into vector<16x16xi32>
-// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, vector<16x16xi32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
+// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into !amx.tile<16x16xi32>
+// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, !amx.tile<16x16xi32>
 // Verify the parsing/printing of the sign-extension annotation.
 // CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
 // CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
 // CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
 func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
-  %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
   // Verify the various `zext` combinations.
-  %5 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  %7 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %5 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  %6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  %7 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
   return
 }


        


More information about the Mlir-commits mailing list