[Mlir-commits] [mlir] [MLIR] [AMX] Utilize x86_amx type for AMX dialect in MLIR. (PR #111197)

Ilya Enkovich llvmlistbot at llvm.org
Fri Oct 4 11:52:27 PDT 2024


https://github.com/ienkovich created https://github.com/llvm/llvm-project/pull/111197

This patch is intended to resolve #109481 and improve the usability of the AMX dialect.

In LLVM IR, AMX intrinsics use `x86_amx`, one of the primitive types, which is meant to be used only by AMX intrinsic calls and by no other operations. The AMX dialect of MLIR, however, uses regular 2D vector types, which are lowered to arrays of vectors in the LLVMIR dialect. This creates an inconsistency between the types used in the LLVMIR dialect and in LLVM IR. Translation of AMX intrinsic calls to LLVM IR doesn't require result types to match, and that is where tile loads and mul operations end up producing `x86_amx` values. This works in very simple cases, when mul and tile store operations directly consume the result of another AMX intrinsic call, but it breaks when an argument is a block argument (phi node).
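To make the failure mode concrete, here is a minimal sketch in the pre-patch vector-based syntax (hypothetical IR: the function, memrefs `%A`/`%B`/`%C`, and loop bounds are illustrative, not taken from the patch). The accumulator tile is loop-carried, so after lowering its block argument would require a phi node of `x86_amx` type:

```mlir
func.func @acc_loop(%A: memref<?x?xbf16>, %B: memref<?x?xbf16>, %C: memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %c16 = arith.constant 16 : index
  %c128 = arith.constant 128 : index
  %init = amx.tile_zero : vector<16x16xf32>
  // The accumulator is loop-carried: in LLVM IR it becomes a phi node,
  // which would have to produce an x86_amx value rather than a vector.
  %acc = scf.for %k = %c0 to %c128 step %c16 iter_args(%it = %init) -> (vector<16x16xf32>) {
    %a = amx.tile_load %A[%c0, %k] : memref<?x?xbf16> into vector<16x32xbf16>
    %b = amx.tile_load %B[%k, %c0] : memref<?x?xbf16> into vector<16x32xbf16>
    %d = amx.tile_mulf %a, %b, %it : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
    scf.yield %d : vector<16x16xf32>
  }
  amx.tile_store %C[%c0, %c0], %acc : memref<?x?xf32>, vector<16x16xf32>
  return
}
```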

Beyond the translation problems, this inconsistency between the types used in MLIR and LLVM IR makes MLIR verification and transformation quite problematic. Both `amx.tile_load` and `vector.transfer_read` can load values of the same type, but only one of those values can be used in AMX operations. In general, by looking at the type of a value, we cannot tell whether it can be used only in AMX operations or, on the contrary, in any operations except AMX ones.
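For illustration, a short sketch in the pre-patch syntax (hypothetical IR; `%m`, `%c0`, and `%pad` are assumed to be defined). Both results have type `vector<16x16xf32>`, yet only `%t` is valid as an AMX operand, and nothing in the type distinguishes them:

```mlir
// Loaded into a tile register: usable only by amx.* operations.
%t = amx.tile_load %m[%c0, %c0] : memref<?x?xf32> into vector<16x16xf32>
// An ordinary vector read: usable by vector/arith operations, but not by AMX.
%v = vector.transfer_read %m[%c0, %c0], %pad : memref<?x?xf32>, vector<16x16xf32>
```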

To remove this inconsistency and make the limitations of AMX operations explicit, I propose to add an `LLVMX86AMXType` type to the LLVMIR dialect, matching the `x86_amx` type in LLVM IR, and to introduce `amx::TileType` for use by AMX operations in MLIR. This resolves the translation problems for AMX usage with phi nodes and provides proper type verification for AMX operations in MLIR.
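With the new types, the distinction is visible in the IR itself. A brief sketch of the resulting syntax (shapes and `%m` are illustrative): AMX values carry `amx::TileType`, printed in the ops' assembly as `<16x16xf32>`, and convert to `!llvm.x86_amx` instead of an array of vectors during lowering:

```mlir
%t = amx.tile_load %m[%c0, %c0] : memref<?x?xf32> into <16x16xf32>
amx.tile_store %m[%c0, %c0], %t : memref<?x?xf32>, <16x16xf32>
```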

P.S. This patch also adds missing FP16 support. It's trivial but unrelated to type system changes, so let me know if I should submit it separately.

>From 38c2803f072167a2c289c2c8dceff4c220ae3476 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Wed, 2 Oct 2024 22:21:32 +0000
Subject: [PATCH] Utilize x86_amx type for AMX dialect in MLIR.

Signed-off-by: Ilya Enkovich <ilya.enkovich at intel.com>
---
 .../Conversion/LLVMCommon/TypeConverter.h     |   4 +
 mlir/include/mlir/Dialect/AMX/AMX.td          | 139 +++++++++++++-----
 mlir/include/mlir/Dialect/AMX/AMXDialect.h    |   3 +
 mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td |  12 ++
 mlir/lib/Conversion/LLVMCommon/CMakeLists.txt |   1 +
 .../Conversion/LLVMCommon/TypeConverter.cpp   |   7 +
 mlir/lib/Dialect/AMX/IR/AMXDialect.cpp        |  63 ++++++--
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  |  51 ++++---
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp |   2 +
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp      |   6 +-
 mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp       |   2 +
 mlir/lib/Target/LLVMIR/TypeToLLVM.cpp         |   3 +
 mlir/test/Dialect/AMX/invalid.mlir            |  18 +--
 mlir/test/Dialect/AMX/legalize-for-llvm.mlir  |  52 ++++---
 mlir/test/Dialect/AMX/roundtrip.mlir          |  50 +++----
 15 files changed, 287 insertions(+), 126 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index d79b90f840ce83..bd4b3e73f07410 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -15,6 +15,7 @@
 #define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
 
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -258,6 +259,9 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a 1D vector type into an LLVM vector type.
   FailureOr<Type> convertVectorType(VectorType type) const;
 
+  /// Convert an AMX tile type to the LLVM x86_amx type.
+  Type convertAMXTileType(amx::TileType type) const;
+
   /// Options for customizing the llvm lowering.
   LowerToLLVMOptions options;
 
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index fcc8d169eab5ac..8ef5ac25fbbddf 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -30,6 +30,8 @@
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypes.td"
 
 //===----------------------------------------------------------------------===//
 // AMX dialect definition.
@@ -55,8 +57,69 @@ def AMX_Dialect : Dialect {
     For details, see the Intel documentation:
     https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
   }];
+  let useDefaultTypePrinterParser = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// AMX Tile definition.
+//===----------------------------------------------------------------------===//
+
+class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<AMX_Dialect, typeName, traits> {
+  let mnemonic = typeMnemonic;
+}
+
+def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
+  let cppFunctionName = "isValidTileTypeElementType";
+}
+
+def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSemantics]> {
+  let summary = "AMX 2D tile to be used by AMX operations.";
+
+  let description = [{
+    This type represents values held in AMX tile registers. All AMX operations
+    work on AMX tiles, and these tiles cannot be used directly in other
+    operations. The LLVM IR type for an AMX tile is a primitive type; in MLIR
+    we add a shape and element type for verification and LLVMIR lowering.
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$shape,
+    AMX_TileTypeElementType:$elementType
+  );
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
+      return $_get(elementType.getContext(), shape, elementType);
+    }]>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns whether this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Clone this tile type with the given shape and element type. If the
+    /// provided shape is `std::nullopt`, the current shape of the type is used.
+    TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                       Type elementType) const {
+      return get(shape.value_or(getShape()), elementType);
+    }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let skipDefaultBuilders = 1;
+}
+
+def IsAMXTilePred : CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">;
+
+def IsAMX2DTilePred : And<[IsAMXTilePred,
+  CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
+
+class AMX2DTileOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsAMX2DTilePred, "tile",
+                      "::mlir::amx::TileType">;
+
 //===----------------------------------------------------------------------===//
 // AMX Op and IntrOp definitions.
 //===----------------------------------------------------------------------===//
@@ -88,14 +151,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
     Example:
 
     ```mlir
-      %0 = amx.tile_zero : vector<16x16xbf16>
+      %0 = amx.tile_zero : <16x16xbf16>
     ```
   }];
   let results = (outs
-    VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+    AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
   let extraClassDeclaration = [{
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getRes().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
   let assemblyFormat = "attr-dict `:` type($res)";
@@ -117,19 +180,19 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
     Example:
 
     ```mlir
-      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
+      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into <16x64xi8>
     ```
   }];
   let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
                    Variadic<Index>:$indices);
   let results = (outs
-    VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+    AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
     }
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getRes().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
   let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
@@ -148,18 +211,18 @@ def TileStoreOp : AMX_Op<"tile_store"> {
     Example:
 
     ```mlir
-      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
+      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, <16x64xi8>
     ```
   }];
   let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
                    Variadic<Index>:$indices,
-                   VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
+                   AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$val);
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
     }
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getVal().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getVal().getType());
     }
   }];
   let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
@@ -183,23 +246,22 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
     Example:
 
     ```mlir
-      %0 = amx.tile_mulf %a, %b, %c
-        : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+      %0 = amx.tile_mulf %a, %b, %c : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
     ```
   }];
-  let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
-                       VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
-                       VectorOfRankAndType<[2], [F32, BF16]>:$acc);
-  let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
+  let arguments = (ins AMX2DTileOf<[F16, BF16]>:$lhs,
+                       AMX2DTileOf<[F16, BF16]>:$rhs,
+                       AMX2DTileOf<[F32]>:$acc);
+  let results = (outs AMX2DTileOf<[F32]>:$res);
   let extraClassDeclaration = [{
-    VectorType getLhsVectorType() {
-      return ::llvm::cast<VectorType>(getLhs().getType());
+    TileType getLhsTileType() {
+      return ::llvm::cast<TileType>(getLhs().getType());
     }
-    VectorType getRhsVectorType() {
-      return ::llvm::cast<VectorType>(getRhs().getType());
+    TileType getRhsTileType() {
+      return ::llvm::cast<TileType>(getRhs().getType());
     }
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getRes().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
   let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
@@ -222,26 +284,25 @@ def TileMulIOp : AMX_Op<"tile_muli", [
     Example:
 
     ```mlir
-      %0 = amx.tile_muli %a zext, %b zext, %c
-        : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+      %0 = amx.tile_muli %a zext, %b zext, %c : <16x64xi8>, <16x64xi8>, <16x16xi32>
     ```
   }];
-  let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
-                       VectorOfRankAndType<[2], [I32, I8]>:$rhs,
-                       VectorOfRankAndType<[2], [I32, I8]>:$acc,
+  let arguments = (ins AMX2DTileOf<[I8]>:$lhs,
+                       AMX2DTileOf<[I8]>:$rhs,
+                       AMX2DTileOf<[I32]>:$acc,
                        UnitAttr:$isZextLhs,
                        UnitAttr:$isZextRhs
                        );
-  let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
+  let results = (outs AMX2DTileOf<[I32]>:$res);
   let extraClassDeclaration = [{
-    VectorType getLhsVectorType() {
-      return ::llvm::cast<VectorType>(getLhs().getType());
+    TileType getLhsTileType() {
+      return ::llvm::cast<TileType>(getLhs().getType());
     }
-    VectorType getRhsVectorType() {
-      return ::llvm::cast<VectorType>(getRhs().getType());
+    TileType getRhsTileType() {
+      return ::llvm::cast<TileType>(getRhs().getType());
     }
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getRes().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
   let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
@@ -286,6 +347,12 @@ def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
                  AnyInteger,
 		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
 
+// Dot product of f16 tiles into f32 tile.
+def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
+  Arguments<(ins AnyInteger,
+                 AnyInteger,
+		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
 // Dot product of i8 tiles into i32 tile (with sign/sign extension).
 def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
   Arguments<(ins AnyInteger,
diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
index 47c92479814dea..c0553ad8733fd4 100644
--- a/mlir/include/mlir/Dialect/AMX/AMXDialect.h
+++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
@@ -21,6 +21,9 @@
 
 #include "mlir/Dialect/AMX/AMXDialect.h.inc"
 
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/AMX/AMXTypes.h.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/AMX/AMX.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 8f9c2f2f8a0b44..09dd0919c318fb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -250,4 +250,16 @@ def LLVMTargetExtType : LLVMType<"LLVMTargetExt", "target"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// LLVMX86AMXType
+//===----------------------------------------------------------------------===//
+
+def LLVMX86AMXType : LLVMType<"LLVMX86AMX", "x86_amx"> {
+  let summary = "LLVM x86_amx type.";
+  let description = [{
+    The x86_amx type represents a value held in an AMX tile register on an x86
+    machine. It can only be used in AMX intrinsic calls.
+  }];
+}
+
 #endif // LLVMTYPES_TD
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 568d9339aaabcb..39199e4affccfa 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
   Core
 
   LINK_LIBS PUBLIC
+  MLIRAMXDialect
   MLIRIR
   MLIRLLVMDialect
   MLIRSupport
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index fd6369b5bb4ee5..a585a4f6ab76f6 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -67,6 +67,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
       return std::nullopt;
     return llvmType;
   });
+  addConversion([&](amx::TileType type) { return convertAMXTileType(type); });
 
   // LLVM-compatible types are legal, so add a pass-through conversion. Do this
   // before the conversions below since conversions are attempted in reverse
@@ -596,6 +597,12 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   return vectorType;
 }
 
+/// Convert an AMX tile type to LLVM x86_amx type.
+/// Shape and element type of the tile are ignored.
+Type LLVMTypeConverter::convertAMXTileType(amx::TileType type) const {
+  return LLVM::LLVMX86AMXType::get(&getContext());
+}
+
 /// Convert a type in the context of the default or bare pointer calling
 /// convention. Calling convention sensitive types, such as MemRefType and
 /// UnrankedMemRefType, are converted following the specific rules for the
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index f0e434407c8a2d..829f48e223383e 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -13,14 +13,22 @@
 #include "mlir/Dialect/AMX/AMXDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TypeUtilities.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+
 using namespace mlir;
 
 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
 
 void amx::AMXDialect::initialize() {
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
+      >();
+
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/AMX/AMX.cpp.inc"
@@ -28,7 +36,7 @@ void amx::AMXDialect::initialize() {
 }
 
 /// Verify that AMX supports the implied tile shape.
-static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
+static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
   const unsigned kMaxRows = 16;
   const unsigned kBitsPerRow = 64 * 8;
   unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
@@ -40,8 +48,8 @@ static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
 }
 
 /// Verify that AMX supports the multiplication.
-static LogicalResult verifyMultShape(Operation *op, VectorType atp,
-                                     VectorType btp, VectorType ctp,
+static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
+                                     amx::TileType btp, amx::TileType ctp,
                                      unsigned scale) {
   unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
   unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
@@ -53,27 +61,27 @@ static LogicalResult verifyMultShape(Operation *op, VectorType atp,
 }
 
 LogicalResult amx::TileZeroOp::verify() {
-  return verifyTileSize(*this, getVectorType());
+  return verifyTileSize(*this, getTileType());
 }
 
 LogicalResult amx::TileLoadOp::verify() {
   unsigned rank = getMemRefType().getRank();
   if (getIndices().size() != rank)
     return emitOpError("requires ") << rank << " indices";
-  return verifyTileSize(*this, getVectorType());
+  return verifyTileSize(*this, getTileType());
 }
 
 LogicalResult amx::TileStoreOp::verify() {
   unsigned rank = getMemRefType().getRank();
   if (getIndices().size() != rank)
     return emitOpError("requires ") << rank << " indices";
-  return verifyTileSize(*this, getVectorType());
+  return verifyTileSize(*this, getTileType());
 }
 
 LogicalResult amx::TileMulFOp::verify() {
-  VectorType aType = getLhsVectorType();
-  VectorType bType = getRhsVectorType();
-  VectorType cType = getVectorType();
+  amx::TileType aType = getLhsTileType();
+  amx::TileType bType = getRhsTileType();
+  amx::TileType cType = getTileType();
   if (failed(verifyTileSize(*this, aType)) ||
       failed(verifyTileSize(*this, bType)) ||
       failed(verifyTileSize(*this, cType)) ||
@@ -82,15 +90,15 @@ LogicalResult amx::TileMulFOp::verify() {
   Type ta = aType.getElementType();
   Type tb = bType.getElementType();
   Type tc = cType.getElementType();
-  if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
+  if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
     return emitOpError("unsupported type combination");
   return success();
 }
 
 LogicalResult amx::TileMulIOp::verify() {
-  VectorType aType = getLhsVectorType();
-  VectorType bType = getRhsVectorType();
-  VectorType cType = getVectorType();
+  amx::TileType aType = getLhsTileType();
+  amx::TileType bType = getRhsTileType();
+  amx::TileType cType = getTileType();
   if (failed(verifyTileSize(*this, aType)) ||
       failed(verifyTileSize(*this, bType)) ||
       failed(verifyTileSize(*this, cType)) ||
@@ -104,5 +112,34 @@ LogicalResult amx::TileMulIOp::verify() {
   return success();
 }
 
+Type amx::TileType::parse(AsmParser &parser) {
+  if (parser.parseLess())
+    return nullptr;
+
+  SmallVector<int64_t, 2> shape;
+  if (parser.parseDimensionList(shape, false, true))
+    return nullptr;
+
+  Type elementType;
+  if (parser.parseType(elementType))
+    return nullptr;
+
+  if (parser.parseGreater())
+    return nullptr;
+
+  return TileType::get(shape, elementType);
+}
+
+void amx::TileType::print(AsmPrinter &os) const {
+  os << "<";
+  os.printDimensionList(getShape());
+  os << 'x';
+  os.printType(getElementType());
+  os << '>';
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/AMX/AMX.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index a8b10f63315d41..415a0998f684f9 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -25,13 +25,13 @@ namespace {
 /// The second dimensions needs to be scaled by the number of bytes.
 std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
                                      const LLVMTypeConverter &typeConverter,
-                                     VectorType vType, Location loc) {
+                                     amx::TileType tType, Location loc) {
   Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
-  unsigned width = vType.getElementType().getIntOrFloatBitWidth();
+  unsigned width = tType.getElementType().getIntOrFloatBitWidth();
   assert(llvm::isPowerOf2_64(width) && width >= 8);
   unsigned bytes = width >> 3;
-  auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0));
-  auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes);
+  auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
+  auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
   return std::make_pair(
       rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
       rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
@@ -78,12 +78,12 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
   LogicalResult
   matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType vType = op.getVectorType();
+    amx::TileType tType = op.getTileType();
     // Determine m x n tile sizes.
     std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
     // Replace operation with intrinsic.
-    Type resType = typeConverter->convertType(vType);
+    Type resType = typeConverter->convertType(tType);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
                                                        tsz.second);
     return success();
@@ -97,10 +97,10 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
   matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MemRefType mType = op.getMemRefType();
-    VectorType vType = op.getVectorType();
+    amx::TileType tType = op.getTileType();
     // Determine m x n tile sizes.
     std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
     // Determine stride.
     if (failed(verifyStride(mType)))
       return failure();
@@ -109,7 +109,7 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
     // Replace operation with intrinsic.
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    Type resType = typeConverter->convertType(vType);
+    Type resType = typeConverter->convertType(tType);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
         op, resType, tsz.first, tsz.second, ptr, stride);
     return success();
@@ -123,10 +123,10 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
   matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MemRefType mType = op.getMemRefType();
-    VectorType vType = op.getVectorType();
+    amx::TileType tType = op.getTileType();
     // Determine m x n tile sizes.
     std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
     // Determine stride.
     if (failed(verifyStride(mType)))
       return failure();
@@ -146,9 +146,9 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
   LogicalResult
   matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType aType = op.getLhsVectorType();
-    VectorType bType = op.getRhsVectorType();
-    VectorType cType = op.getVectorType();
+    amx::TileType aType = op.getLhsTileType();
+    amx::TileType bType = op.getRhsTileType();
+    amx::TileType cType = op.getTileType();
     // Determine m x n x k tile sizes.
     std::pair<Value, Value> tsza =
         getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
@@ -156,9 +156,14 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
         getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
     // Replace operation with intrinsic.
     Type resType = typeConverter->convertType(cType);
-    rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
-        op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
-        adaptor.getLhs(), adaptor.getRhs());
+    if (aType.getElementType().isBF16())
+      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
+          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
+          adaptor.getLhs(), adaptor.getRhs());
+    else
+      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
+          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
+          adaptor.getLhs(), adaptor.getRhs());
     return success();
   }
 };
@@ -168,9 +173,9 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
   LogicalResult
   matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType aType = op.getLhsVectorType();
-    VectorType bType = op.getRhsVectorType();
-    VectorType cType = op.getVectorType();
+    amx::TileType aType = op.getLhsTileType();
+    amx::TileType bType = op.getRhsTileType();
+    amx::TileType cType = op.getTileType();
     // Determine m x n x k tile sizes.
     std::pair<Value, Value> tsza =
         getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
@@ -210,8 +215,8 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns(
 
 void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
   target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
-                    x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
-                    x86_amx_tdpbusd, x86_amx_tdpbuud>();
+                    x86_amx_tdpbf16ps, x86_amx_tdpfp16ps, x86_amx_tdpbssd,
+                    x86_amx_tdpbsud, x86_amx_tdpbusd, x86_amx_tdpbuud>();
   target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
                       TileMulFOp>();
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index c4708d826f2b38..9537f7c40dd4be 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -45,6 +45,7 @@ static StringRef getTypeKeyword(Type type) {
       .Case<LLVMArrayType>([&](Type) { return "array"; })
       .Case<LLVMStructType>([&](Type) { return "struct"; })
       .Case<LLVMTargetExtType>([&](Type) { return "target"; })
+      .Case<LLVMX86AMXType>([&](Type) { return "x86_amx"; })
       .Default([](Type) -> StringRef {
         llvm_unreachable("unexpected 'llvm' type kind");
       });
@@ -317,6 +318,7 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
       .Case("array", [&] { return LLVMArrayType::parse(parser); })
       .Case("struct", [&] { return parseStructType(parser); })
       .Case("target", [&] { return LLVMTargetExtType::parse(parser); })
+      .Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); })
       .Default([&] {
         parser.emitError(keyLoc) << "unknown LLVM type: " << key;
         return Type();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 7f10a15ff31ff9..1bed3fa48b30d7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -780,7 +780,8 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
       LLVMFixedVectorType,
       LLVMScalableVectorType,
       LLVMTargetExtType,
-      LLVMVoidType
+      LLVMVoidType,
+      LLVMX86AMXType
     >(type)) {
     // clang-format on
     return true;
@@ -842,7 +843,8 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
             LLVMMetadataType,
             LLVMPPCFP128Type,
             LLVMTokenType,
-            LLVMVoidType
+            LLVMVoidType,
+            LLVMX86AMXType
           >([](Type) { return true; })
           // clang-format on
           .Default([](Type) { return false; });
diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
index db184ae8e6e833..ea990ca7aefbe0 100644
--- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
@@ -63,6 +63,8 @@ class TypeFromLLVMIRTranslatorImpl {
       return Float128Type::get(&context);
     if (type->isX86_FP80Ty())
       return Float80Type::get(&context);
+    if (type->isX86_AMXTy())
+      return LLVM::LLVMX86AMXType::get(&context);
     if (type->isPPC_FP128Ty())
       return LLVM::LLVMPPCFP128Type::get(&context);
     if (type->isLabelTy())
diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
index 65915027238801..c7a533eddce84b 100644
--- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
@@ -67,6 +67,9 @@ class TypeToLLVMIRTranslatorImpl {
             .Case([this](LLVM::LLVMMetadataType) {
               return llvm::Type::getMetadataTy(context);
             })
+            .Case([this](LLVM::LLVMX86AMXType) {
+              return llvm::Type::getX86_AMXTy(context);
+            })
             .Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType,
                   LLVM::LLVMPointerType, LLVM::LLVMStructType,
                   LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType,
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index 25d57353c905d0..46f3b7448148b3 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -4,21 +4,21 @@
 
 func.func @rowheight() {
   // expected-error at +1 {{'amx.tile_zero' op bad row height: 17}}
-  %0 = amx.tile_zero : vector<17x16xbf16>
+  %0 = amx.tile_zero : <17x16xbf16>
 }
 
 // -----
 
 func.func @colwidth() {
   // expected-error at +1 {{'amx.tile_zero' op bad column width: 65}}
-  %0 = amx.tile_zero : vector<16x65xi8>
+  %0 = amx.tile_zero : <16x65xi8>
 }
 
 // -----
 
 func.func @col4bytemultiple() {
   // expected-error at +1 {{'amx.tile_zero' op bad column width: 5}}
-  %0 = amx.tile_zero : vector<16x5xi8>
+  %0 = amx.tile_zero : <16x5xi8>
 }
 
 // -----
@@ -26,7 +26,7 @@ func.func @col4bytemultiple() {
 func.func @memtilesize(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error at +1 {{'amx.tile_load' op bad column width: 68}}
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into vector<16x17xf32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into <16x17xf32>
 }
 
 // -----
@@ -34,15 +34,15 @@ func.func @memtilesize(%arg0: memref<?x?xf32>) {
 func.func @memindexsize(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error at +1 {{'amx.tile_load' op requires 2 indices}}
-  %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into vector<16x16xf32>
+  %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into <16x16xf32>
 }
 
 // -----
 
 func.func @multsize() {
-  %0 = amx.tile_zero : vector<8x8xbf16>
-  %1 = amx.tile_zero : vector<8x8xbf16>
-  %2 = amx.tile_zero : vector<4x4xf32>
+  %0 = amx.tile_zero : <8x8xbf16>
+  %1 = amx.tile_zero : <8x8xbf16>
+  %2 = amx.tile_zero : <4x4xf32>
   // expected-error at +1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
-  %3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32>
+  %3 = amx.tile_mulf %0, %1, %2 : <8x8xbf16>, <8x8xbf16>, <4x4xf32>
 }
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 992203153939fe..685abbfcb830de 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -14,32 +14,48 @@
 // CHECK: amx.tilestored64
 func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : vector<16x64xi8>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
-  %5 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, vector<16x16xi32>
-  %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, vector<16x16xi32>
-  %7 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, vector<16x16xi32>
+  %1 = amx.tile_zero : <16x64xi8>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into <16x16xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
+  %5 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, <16x16xi32>
+  %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, <16x16xi32>
+  %7 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, <16x16xi32>
   return
 }
 
-// CHECK-LABEL: mulf(
+// CHECK-LABEL: mulbf16(
 // CHECK: amx.tilezero
 // CHECK: amx.tileloadd64
 // CHECK: amx.tileloadd64
 // CHECK: amx.tdpbf16ps
 // CHECK: amx.tilestored64
-func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
+func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : vector<16x32xbf16>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
-  %4 = amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
+  %1 = amx.tile_zero : <16x32xbf16>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
+  %4 = amx.tile_mulf %1, %2, %3 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
+  return
+}
+
+// CHECK-LABEL: mulfp16(
+// CHECK: amx.tilezero
+// CHECK: amx.tileloadd64
+// CHECK: amx.tileloadd64
+// CHECK: amx.tdpfp16ps
+// CHECK: amx.tilestored64
+func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
+  %0 = arith.constant 0 : index
+  %1 = amx.tile_zero : <16x32xf16>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into <16x32xf16>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
+  %4 = amx.tile_mulf %1, %2, %3 : <16x32xf16>, <16x32xf16>, <16x16xf32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
   return
 }
diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir
index f2ac5e47f6c357..62bc071cab1e02 100644
--- a/mlir/test/Dialect/AMX/roundtrip.mlir
+++ b/mlir/test/Dialect/AMX/roundtrip.mlir
@@ -1,49 +1,49 @@
 // RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: tzero
-// CHECK: amx.tile_zero : vector<16x16xbf16>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, vector<16x16xbf16>
+// CHECK: amx.tile_zero : <16x16xbf16>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, <16x16xbf16>
 func.func @tzero(%arg0: memref<?x?xbf16>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : vector<16x16xbf16>
-  amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, vector<16x16xbf16>
+  %1 = amx.tile_zero : <16x16xbf16>
+  amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, <16x16xbf16>
   return
 }
 
 // CHECK-LABEL: tmulf
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into vector<16x32xbf16>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into vector<16x16xf32>
-// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, vector<16x16xf32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into <16x32xbf16>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into <16x16xf32>
+// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, <16x16xf32>
 func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
-  %3 = amx.tile_mulf %1, %1, %2 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
-  amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, vector<16x16xf32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
+  %3 = amx.tile_mulf %1, %1, %2 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+  amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, <16x16xf32>
   return
 }
 
 // CHECK-LABEL: tmuli
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
-// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into vector<16x16xi32>
-// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, vector<16x16xi32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
+// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into <16x16xi32>
+// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : <16x64xi8>, <16x64xi8>, <16x16xi32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, <16x16xi32>
 // Verify the parsing/printing of the sign-extension annotation.
 // CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
 // CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
 // CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
 func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
-  %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into <16x64xi8>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into <16x16xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
   // Verify the various `zext` combinations.
-  %5 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  %7 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %5 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  %7 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
   return
 }


