[Mlir-commits] [mlir] [MLIR] [AMX] Utilize x86_amx type for AMX dialect in MLIR. (PR #111197)
    Ilya Enkovich 
    llvmlistbot at llvm.org
       
    Mon Oct  7 11:33:01 PDT 2024
    
    
  
https://github.com/ienkovich updated https://github.com/llvm/llvm-project/pull/111197
>From 38c2803f072167a2c289c2c8dceff4c220ae3476 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Wed, 2 Oct 2024 22:21:32 +0000
Subject: [PATCH 1/2] Utilize x86_amx type for AMX dialect in MLIR.
Signed-off-by: Ilya Enkovich <ilya.enkovich at intel.com>
---
 .../Conversion/LLVMCommon/TypeConverter.h     |   4 +
 mlir/include/mlir/Dialect/AMX/AMX.td          | 139 +++++++++++++-----
 mlir/include/mlir/Dialect/AMX/AMXDialect.h    |   3 +
 mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td |  12 ++
 mlir/lib/Conversion/LLVMCommon/CMakeLists.txt |   1 +
 .../Conversion/LLVMCommon/TypeConverter.cpp   |   7 +
 mlir/lib/Dialect/AMX/IR/AMXDialect.cpp        |  63 ++++++--
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  |  51 ++++---
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp |   2 +
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp      |   6 +-
 mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp       |   2 +
 mlir/lib/Target/LLVMIR/TypeToLLVM.cpp         |   3 +
 mlir/test/Dialect/AMX/invalid.mlir            |  18 +--
 mlir/test/Dialect/AMX/legalize-for-llvm.mlir  |  52 ++++---
 mlir/test/Dialect/AMX/roundtrip.mlir          |  50 +++----
 15 files changed, 287 insertions(+), 126 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index d79b90f840ce83..bd4b3e73f07410 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -15,6 +15,7 @@
 #define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
 
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -258,6 +259,9 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a 1D vector type into an LLVM vector type.
   FailureOr<Type> convertVectorType(VectorType type) const;
 
+  /// Convert AMX tile type to x86_amx type.
+  Type convertAMXTileType(amx::TileType type) const;
+
   /// Options for customizing the llvm lowering.
   LowerToLLVMOptions options;
 
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index fcc8d169eab5ac..8ef5ac25fbbddf 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -30,6 +30,8 @@
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypes.td"
 
 //===----------------------------------------------------------------------===//
 // AMX dialect definition.
@@ -55,8 +57,69 @@ def AMX_Dialect : Dialect {
     For details, see the Intel documentation:
     https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
   }];
+  let useDefaultTypePrinterParser = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// AMX Tile definition.
+//===----------------------------------------------------------------------===//
+
+class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<AMX_Dialect, typeName, traits> {
+  let mnemonic = typeMnemonic;
+}
+
+def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
+  let cppFunctionName = "isValidTileTypeElementType";
+}
+
+def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
+  let summary = "AMX 2D tile to be used by AMX operations.";
+
+  let description = [{
+    This type is used to represent values in AMX tile registers. All AMX operations
+    work on AMX tiles and these tiles cannot be used in other operations directly.
+    LLVM IR type for AMX tile is a primitive type, but in MLIR (`!amx.tile<16x16xbf16>`)
+    we provide shape and element type for IR verification and lowering to LLVM.
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$shape,
+    AMX_TileTypeElementType:$elementType
+  );
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
+      return $_get(elementType.getContext(), shape, elementType);
+    }]>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns if this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Clone this tile type with the given shape and element type. If the
+    /// provided shape is `std::nullopt`, the current shape of the type is used.
+    TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                       Type elementType) const {
+      return get(shape.value_or(getShape()), elementType);
+    }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let skipDefaultBuilders = 1;
+}
+
+def IsAMXTilePred : CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">;
+
+def IsAMX2DTilePred : And<[IsAMXTilePred,
+  CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
+
+class AMX2DTileOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsAMX2DTilePred, "tile",
+                      "::mlir::amx::TileType">;
+
 //===----------------------------------------------------------------------===//
 // AMX Op and IntrOp definitions.
 //===----------------------------------------------------------------------===//
@@ -88,14 +151,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
     Example:
 
     ```mlir
-      %0 = amx.tile_zero : vector<16x16xbf16>
+      %0 = amx.tile_zero : <16x16xbf16>
     ```
   }];
   let results = (outs
-    VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+    AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
   let extraClassDeclaration = [{
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getRes().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
   let assemblyFormat = "attr-dict `:` type($res)";
@@ -117,19 +180,19 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
     Example:
 
     ```mlir
-      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
+      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into <16x64xi8>
     ```
   }];
   let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
                    Variadic<Index>:$indices);
   let results = (outs
-    VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+    AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
     }
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getRes().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
   let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
@@ -148,18 +211,18 @@ def TileStoreOp : AMX_Op<"tile_store"> {
     Example:
 
     ```mlir
-      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
+      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, <16x64xi8>
     ```
   }];
   let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
                    Variadic<Index>:$indices,
-                   VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
+                   AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$val);
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
     }
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getVal().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getVal().getType());
     }
   }];
   let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
@@ -183,23 +246,22 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
     Example:
 
     ```mlir
-      %0 = amx.tile_mulf %a, %b, %c
-        : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+      %0 = amx.tile_mulf %a, %b, %c : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
     ```
   }];
-  let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
-                       VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
-                       VectorOfRankAndType<[2], [F32, BF16]>:$acc);
-  let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
+  let arguments = (ins AMX2DTileOf<[F16, BF16]>:$lhs,
+                       AMX2DTileOf<[F16, BF16]>:$rhs,
+                       AMX2DTileOf<[F32]>:$acc);
+  let results = (outs AMX2DTileOf<[F32]>:$res);
   let extraClassDeclaration = [{
-    VectorType getLhsVectorType() {
-      return ::llvm::cast<VectorType>(getLhs().getType());
+    TileType getLhsTileType() {
+      return ::llvm::cast<TileType>(getLhs().getType());
     }
-    VectorType getRhsVectorType() {
-      return ::llvm::cast<VectorType>(getRhs().getType());
+    TileType getRhsTileType() {
+      return ::llvm::cast<TileType>(getRhs().getType());
     }
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getRes().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
   let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
@@ -222,26 +284,25 @@ def TileMulIOp : AMX_Op<"tile_muli", [
     Example:
 
     ```mlir
-      %0 = amx.tile_muli %a zext, %b zext, %c
-        : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+      %0 = amx.tile_muli %a zext, %b zext, %c : <16x64xi8>, <16x64xi8>, <16x16xi32>
     ```
   }];
-  let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
-                       VectorOfRankAndType<[2], [I32, I8]>:$rhs,
-                       VectorOfRankAndType<[2], [I32, I8]>:$acc,
+  let arguments = (ins AMX2DTileOf<[I8]>:$lhs,
+                       AMX2DTileOf<[I8]>:$rhs,
+                       AMX2DTileOf<[I32]>:$acc,
                        UnitAttr:$isZextLhs,
                        UnitAttr:$isZextRhs
                        );
-  let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
+  let results = (outs AMX2DTileOf<[I32]>:$res);
   let extraClassDeclaration = [{
-    VectorType getLhsVectorType() {
-      return ::llvm::cast<VectorType>(getLhs().getType());
+    TileType getLhsTileType() {
+      return ::llvm::cast<TileType>(getLhs().getType());
     }
-    VectorType getRhsVectorType() {
-      return ::llvm::cast<VectorType>(getRhs().getType());
+    TileType getRhsTileType() {
+      return ::llvm::cast<TileType>(getRhs().getType());
     }
-    VectorType getVectorType() {
-      return ::llvm::cast<VectorType>(getRes().getType());
+    TileType getTileType() {
+      return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
   let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
@@ -286,6 +347,12 @@ def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
                  AnyInteger,
 		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
 
+// Dot product of f16 tiles into f32 tile.
+def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
+  Arguments<(ins AnyInteger,
+                 AnyInteger,
+		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
 // Dot product of i8 tiles into i32 tile (with sign/sign extension).
 def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
   Arguments<(ins AnyInteger,
diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
index 47c92479814dea..c0553ad8733fd4 100644
--- a/mlir/include/mlir/Dialect/AMX/AMXDialect.h
+++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
@@ -21,6 +21,9 @@
 
 #include "mlir/Dialect/AMX/AMXDialect.h.inc"
 
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/AMX/AMXTypes.h.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/AMX/AMX.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 8f9c2f2f8a0b44..09dd0919c318fb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -250,4 +250,16 @@ def LLVMTargetExtType : LLVMType<"LLVMTargetExt", "target"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// LLVMX86AMXType
+//===----------------------------------------------------------------------===//
+
+def LLVMX86AMXType : LLVMType<"LLVMX86AMX", "x86_amx"> {
+  let summary = "LLVM x86_amx type.";
+  let description = [{
+    The x86_amx type represents a value held in an AMX tile register on an x86
+    machine. Can only be used in AMX intrinsics calls.
+  }];
+}
+
 #endif // LLVMTYPES_TD
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 568d9339aaabcb..39199e4affccfa 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
   Core
 
   LINK_LIBS PUBLIC
+  MLIRAMXDialect
   MLIRIR
   MLIRLLVMDialect
   MLIRSupport
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index fd6369b5bb4ee5..a585a4f6ab76f6 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -67,6 +67,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
       return std::nullopt;
     return llvmType;
   });
+  addConversion([&](amx::TileType type) { return convertAMXTileType(type); });
 
   // LLVM-compatible types are legal, so add a pass-through conversion. Do this
   // before the conversions below since conversions are attempted in reverse
@@ -596,6 +597,12 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   return vectorType;
 }
 
+/// Convert an AMX tile type to LLVM x86_amx type.
+/// Shape and element type of the tile are ignored.
+Type LLVMTypeConverter::convertAMXTileType(amx::TileType type) const {
+  return LLVM::LLVMX86AMXType::get(&getContext());
+}
+
 /// Convert a type in the context of the default or bare pointer calling
 /// convention. Calling convention sensitive types, such as MemRefType and
 /// UnrankedMemRefType, are converted following the specific rules for the
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index f0e434407c8a2d..829f48e223383e 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -13,14 +13,22 @@
 #include "mlir/Dialect/AMX/AMXDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TypeUtilities.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+
 using namespace mlir;
 
 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
 
 void amx::AMXDialect::initialize() {
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
+      >();
+
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/AMX/AMX.cpp.inc"
@@ -28,7 +36,7 @@ void amx::AMXDialect::initialize() {
 }
 
 /// Verify that AMX supports the implied tile shape.
-static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
+static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
   const unsigned kMaxRows = 16;
   const unsigned kBitsPerRow = 64 * 8;
   unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
@@ -40,8 +48,8 @@ static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
 }
 
 /// Verify that AMX supports the multiplication.
-static LogicalResult verifyMultShape(Operation *op, VectorType atp,
-                                     VectorType btp, VectorType ctp,
+static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
+                                     amx::TileType btp, amx::TileType ctp,
                                      unsigned scale) {
   unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
   unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
@@ -53,27 +61,27 @@ static LogicalResult verifyMultShape(Operation *op, VectorType atp,
 }
 
 LogicalResult amx::TileZeroOp::verify() {
-  return verifyTileSize(*this, getVectorType());
+  return verifyTileSize(*this, getTileType());
 }
 
 LogicalResult amx::TileLoadOp::verify() {
   unsigned rank = getMemRefType().getRank();
   if (getIndices().size() != rank)
     return emitOpError("requires ") << rank << " indices";
-  return verifyTileSize(*this, getVectorType());
+  return verifyTileSize(*this, getTileType());
 }
 
 LogicalResult amx::TileStoreOp::verify() {
   unsigned rank = getMemRefType().getRank();
   if (getIndices().size() != rank)
     return emitOpError("requires ") << rank << " indices";
-  return verifyTileSize(*this, getVectorType());
+  return verifyTileSize(*this, getTileType());
 }
 
 LogicalResult amx::TileMulFOp::verify() {
-  VectorType aType = getLhsVectorType();
-  VectorType bType = getRhsVectorType();
-  VectorType cType = getVectorType();
+  amx::TileType aType = getLhsTileType();
+  amx::TileType bType = getRhsTileType();
+  amx::TileType cType = getTileType();
   if (failed(verifyTileSize(*this, aType)) ||
       failed(verifyTileSize(*this, bType)) ||
       failed(verifyTileSize(*this, cType)) ||
@@ -82,15 +90,15 @@ LogicalResult amx::TileMulFOp::verify() {
   Type ta = aType.getElementType();
   Type tb = bType.getElementType();
   Type tc = cType.getElementType();
-  if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
+  if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
     return emitOpError("unsupported type combination");
   return success();
 }
 
 LogicalResult amx::TileMulIOp::verify() {
-  VectorType aType = getLhsVectorType();
-  VectorType bType = getRhsVectorType();
-  VectorType cType = getVectorType();
+  amx::TileType aType = getLhsTileType();
+  amx::TileType bType = getRhsTileType();
+  amx::TileType cType = getTileType();
   if (failed(verifyTileSize(*this, aType)) ||
       failed(verifyTileSize(*this, bType)) ||
       failed(verifyTileSize(*this, cType)) ||
@@ -104,5 +112,34 @@ LogicalResult amx::TileMulIOp::verify() {
   return success();
 }
 
+Type amx::TileType::parse(AsmParser &parser) {
+  if (parser.parseLess())
+    return nullptr;
+
+  SmallVector<int64_t, 2> shape;
+  if (parser.parseDimensionList(shape, false, true))
+    return nullptr;
+
+  Type elementType;
+  if (parser.parseType(elementType))
+    return nullptr;
+
+  if (parser.parseGreater())
+    return nullptr;
+
+  return TileType::get(shape, elementType);
+}
+
+void amx::TileType::print(AsmPrinter &os) const {
+  os << "<";
+  os.printDimensionList(getShape());
+  os << 'x';
+  os.printType(getElementType());
+  os << '>';
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/AMX/AMX.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index a8b10f63315d41..415a0998f684f9 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -25,13 +25,13 @@ namespace {
 /// The second dimensions needs to be scaled by the number of bytes.
 std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
                                      const LLVMTypeConverter &typeConverter,
-                                     VectorType vType, Location loc) {
+                                     amx::TileType tType, Location loc) {
   Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
-  unsigned width = vType.getElementType().getIntOrFloatBitWidth();
+  unsigned width = tType.getElementType().getIntOrFloatBitWidth();
   assert(llvm::isPowerOf2_64(width) && width >= 8);
   unsigned bytes = width >> 3;
-  auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0));
-  auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes);
+  auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
+  auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
   return std::make_pair(
       rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
       rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
@@ -78,12 +78,12 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
   LogicalResult
   matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType vType = op.getVectorType();
+    amx::TileType tType = op.getTileType();
     // Determine m x n tile sizes.
     std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
     // Replace operation with intrinsic.
-    Type resType = typeConverter->convertType(vType);
+    Type resType = typeConverter->convertType(tType);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
                                                        tsz.second);
     return success();
@@ -97,10 +97,10 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
   matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MemRefType mType = op.getMemRefType();
-    VectorType vType = op.getVectorType();
+    amx::TileType tType = op.getTileType();
     // Determine m x n tile sizes.
     std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
     // Determine stride.
     if (failed(verifyStride(mType)))
       return failure();
@@ -109,7 +109,7 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
     // Replace operation with intrinsic.
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    Type resType = typeConverter->convertType(vType);
+    Type resType = typeConverter->convertType(tType);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
         op, resType, tsz.first, tsz.second, ptr, stride);
     return success();
@@ -123,10 +123,10 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
   matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MemRefType mType = op.getMemRefType();
-    VectorType vType = op.getVectorType();
+    amx::TileType tType = op.getTileType();
     // Determine m x n tile sizes.
     std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+        getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
     // Determine stride.
     if (failed(verifyStride(mType)))
       return failure();
@@ -146,9 +146,9 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
   LogicalResult
   matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType aType = op.getLhsVectorType();
-    VectorType bType = op.getRhsVectorType();
-    VectorType cType = op.getVectorType();
+    amx::TileType aType = op.getLhsTileType();
+    amx::TileType bType = op.getRhsTileType();
+    amx::TileType cType = op.getTileType();
     // Determine m x n x k tile sizes.
     std::pair<Value, Value> tsza =
         getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
@@ -156,9 +156,14 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
         getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
     // Replace operation with intrinsic.
     Type resType = typeConverter->convertType(cType);
-    rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
-        op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
-        adaptor.getLhs(), adaptor.getRhs());
+    if (aType.getElementType().isBF16())
+      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
+          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
+          adaptor.getLhs(), adaptor.getRhs());
+    else
+      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
+          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
+          adaptor.getLhs(), adaptor.getRhs());
     return success();
   }
 };
@@ -168,9 +173,9 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
   LogicalResult
   matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType aType = op.getLhsVectorType();
-    VectorType bType = op.getRhsVectorType();
-    VectorType cType = op.getVectorType();
+    amx::TileType aType = op.getLhsTileType();
+    amx::TileType bType = op.getRhsTileType();
+    amx::TileType cType = op.getTileType();
     // Determine m x n x k tile sizes.
     std::pair<Value, Value> tsza =
         getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
@@ -210,8 +215,8 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns(
 
 void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
   target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
-                    x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
-                    x86_amx_tdpbusd, x86_amx_tdpbuud>();
+                    x86_amx_tdpbf16ps, x86_amx_tdpfp16ps, x86_amx_tdpbssd,
+                    x86_amx_tdpbsud, x86_amx_tdpbusd, x86_amx_tdpbuud>();
   target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
                       TileMulFOp>();
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index c4708d826f2b38..9537f7c40dd4be 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -45,6 +45,7 @@ static StringRef getTypeKeyword(Type type) {
       .Case<LLVMArrayType>([&](Type) { return "array"; })
       .Case<LLVMStructType>([&](Type) { return "struct"; })
       .Case<LLVMTargetExtType>([&](Type) { return "target"; })
+      .Case<LLVMX86AMXType>([&](Type) { return "x86_amx"; })
       .Default([](Type) -> StringRef {
         llvm_unreachable("unexpected 'llvm' type kind");
       });
@@ -317,6 +318,7 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
       .Case("array", [&] { return LLVMArrayType::parse(parser); })
       .Case("struct", [&] { return parseStructType(parser); })
       .Case("target", [&] { return LLVMTargetExtType::parse(parser); })
+      .Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); })
       .Default([&] {
         parser.emitError(keyLoc) << "unknown LLVM type: " << key;
         return Type();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 7f10a15ff31ff9..1bed3fa48b30d7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -780,7 +780,8 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
       LLVMFixedVectorType,
       LLVMScalableVectorType,
       LLVMTargetExtType,
-      LLVMVoidType
+      LLVMVoidType,
+      LLVMX86AMXType
     >(type)) {
     // clang-format on
     return true;
@@ -842,7 +843,8 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
             LLVMMetadataType,
             LLVMPPCFP128Type,
             LLVMTokenType,
-            LLVMVoidType
+            LLVMVoidType,
+            LLVMX86AMXType
           >([](Type) { return true; })
           // clang-format on
           .Default([](Type) { return false; });
diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
index db184ae8e6e833..ea990ca7aefbe0 100644
--- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
@@ -63,6 +63,8 @@ class TypeFromLLVMIRTranslatorImpl {
       return Float128Type::get(&context);
     if (type->isX86_FP80Ty())
       return Float80Type::get(&context);
+    if (type->isX86_AMXTy())
+      return LLVM::LLVMX86AMXType::get(&context);
     if (type->isPPC_FP128Ty())
       return LLVM::LLVMPPCFP128Type::get(&context);
     if (type->isLabelTy())
diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
index 65915027238801..c7a533eddce84b 100644
--- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
@@ -67,6 +67,9 @@ class TypeToLLVMIRTranslatorImpl {
             .Case([this](LLVM::LLVMMetadataType) {
               return llvm::Type::getMetadataTy(context);
             })
+            .Case([this](LLVM::LLVMX86AMXType) {
+              return llvm::Type::getX86_AMXTy(context);
+            })
             .Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType,
                   LLVM::LLVMPointerType, LLVM::LLVMStructType,
                   LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType,
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index 25d57353c905d0..46f3b7448148b3 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -4,21 +4,21 @@
 
 func.func @rowheight() {
   // expected-error at +1 {{'amx.tile_zero' op bad row height: 17}}
-  %0 = amx.tile_zero : vector<17x16xbf16>
+  %0 = amx.tile_zero : <17x16xbf16>
 }
 
 // -----
 
 func.func @colwidth() {
   // expected-error at +1 {{'amx.tile_zero' op bad column width: 65}}
-  %0 = amx.tile_zero : vector<16x65xi8>
+  %0 = amx.tile_zero : <16x65xi8>
 }
 
 // -----
 
 func.func @col4bytemultiple() {
   // expected-error at +1 {{'amx.tile_zero' op bad column width: 5}}
-  %0 = amx.tile_zero : vector<16x5xi8>
+  %0 = amx.tile_zero : <16x5xi8>
 }
 
 // -----
@@ -26,7 +26,7 @@ func.func @col4bytemultiple() {
 func.func @memtilesize(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error at +1 {{'amx.tile_load' op bad column width: 68}}
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into vector<16x17xf32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into <16x17xf32>
 }
 
 // -----
@@ -34,15 +34,15 @@ func.func @memtilesize(%arg0: memref<?x?xf32>) {
 func.func @memindexsize(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error at +1 {{'amx.tile_load' op requires 2 indices}}
-  %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into vector<16x16xf32>
+  %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into <16x16xf32>
 }
 
 // -----
 
 func.func @multsize() {
-  %0 = amx.tile_zero : vector<8x8xbf16>
-  %1 = amx.tile_zero : vector<8x8xbf16>
-  %2 = amx.tile_zero : vector<4x4xf32>
+  %0 = amx.tile_zero : <8x8xbf16>
+  %1 = amx.tile_zero : <8x8xbf16>
+  %2 = amx.tile_zero : <4x4xf32>
   // expected-error at +1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
-  %3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32>
+  %3 = amx.tile_mulf %0, %1, %2 : <8x8xbf16>, <8x8xbf16>, <4x4xf32>
 }
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 992203153939fe..685abbfcb830de 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -14,32 +14,48 @@
 // CHECK: amx.tilestored64
 func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : vector<16x64xi8>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
-  %5 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, vector<16x16xi32>
-  %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, vector<16x16xi32>
-  %7 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, vector<16x16xi32>
+  %1 = amx.tile_zero : <16x64xi8>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into <16x16xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
+  %5 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, <16x16xi32>
+  %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, <16x16xi32>
+  %7 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, <16x16xi32>
   return
 }
 
-// CHECK-LABEL: mulf(
+// CHECK-LABEL: mulbf16(
 // CHECK: amx.tilezero
 // CHECK: amx.tileloadd64
 // CHECK: amx.tileloadd64
 // CHECK: amx.tdpbf16ps
 // CHECK: amx.tilestored64
-func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
+func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : vector<16x32xbf16>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
-  %4 = amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
+  %1 = amx.tile_zero : <16x32xbf16>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
+  %4 = amx.tile_mulf %1, %2, %3 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
+  return
+}
+
+// CHECK-LABEL: mulfp16(
+// CHECK: amx.tilezero
+// CHECK: amx.tileloadd64
+// CHECK: amx.tileloadd64
+// CHECK: amx.tdpfp16ps
+// CHECK: amx.tilestored64
+func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
+  %0 = arith.constant 0 : index
+  %1 = amx.tile_zero : <16x32xf16>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into <16x32xf16>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
+  %4 = amx.tile_mulf %1, %2, %3 : <16x32xf16>, <16x32xf16>, <16x16xf32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
   return
 }
diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir
index f2ac5e47f6c357..62bc071cab1e02 100644
--- a/mlir/test/Dialect/AMX/roundtrip.mlir
+++ b/mlir/test/Dialect/AMX/roundtrip.mlir
@@ -1,49 +1,49 @@
 // RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: tzero
-// CHECK: amx.tile_zero : vector<16x16xbf16>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, vector<16x16xbf16>
+// CHECK: amx.tile_zero : <16x16xbf16>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, <16x16xbf16>
 func.func @tzero(%arg0: memref<?x?xbf16>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : vector<16x16xbf16>
-  amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, vector<16x16xbf16>
+  %1 = amx.tile_zero : <16x16xbf16>
+  amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, <16x16xbf16>
   return
 }
 
 // CHECK-LABEL: tmulf
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into vector<16x32xbf16>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into vector<16x16xf32>
-// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, vector<16x16xf32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into <16x32xbf16>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into <16x16xf32>
+// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, <16x16xf32>
 func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
-  %3 = amx.tile_mulf %1, %1, %2 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
-  amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, vector<16x16xf32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
+  %3 = amx.tile_mulf %1, %1, %2 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+  amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, <16x16xf32>
   return
 }
 
 // CHECK-LABEL: tmuli
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
-// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into vector<16x16xi32>
-// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, vector<16x16xi32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
+// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into <16x16xi32>
+// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : <16x64xi8>, <16x64xi8>, <16x16xi32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, <16x16xi32>
 // Verify the parsing/printing of the sign-extension annotation.
 // CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
 // CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
 // CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
 func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
-  %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into <16x64xi8>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into <16x16xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
   // Verify the various `zext` combinations.
-  %5 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-  %7 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %5 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  %7 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
   return
 }
>From bbaeadd5ca48f2da714f4af2adc84ba1ba553df6 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Mon, 7 Oct 2024 16:46:03 +0000
Subject: [PATCH 2/2] Address review comments.
Signed-off-by: Ilya Enkovich <ilya.enkovich at intel.com>
---
 .../Conversion/LLVMCommon/TypeConverter.h     |  2 +-
 mlir/include/mlir/Dialect/AMX/AMX.td          | 67 +++++++++++--------
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  |  4 +-
 mlir/test/Dialect/AMX/invalid.mlir            | 18 ++---
 mlir/test/Dialect/AMX/legalize-for-llvm.mlir  | 42 ++++++------
 mlir/test/Dialect/AMX/roundtrip.mlir          | 50 +++++++-------
 6 files changed, 97 insertions(+), 86 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index bd4b3e73f07410..6e01f70ee35123 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -259,7 +259,7 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a 1D vector type into an LLVM vector type.
   FailureOr<Type> convertVectorType(VectorType type) const;
 
-  /// Convert AMX tile type x86_amx type.
+  /// Convert an AMX tile type to the x86_amx type.
   Type convertAMXTileType(amx::TileType type) const;
 
   /// Options for customizing the llvm lowering.
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 8ef5ac25fbbddf..8a51df1ea183fc 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -73,7 +73,7 @@ def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
   let cppFunctionName = "isValidTileTypeElementType";
 }
 
-def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSemantics]> {
+def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
   let summary = "AMX 2D tile to be used by AMX opertaions.";
 
   let description = [{
@@ -111,15 +111,23 @@ def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSeman
   let skipDefaultBuilders = 1;
 }
 
-def IsAMXTilePred : CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">;
-
-def IsAMX2DTilePred : And<[IsAMXTilePred,
+def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
   CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
 
-class AMX2DTileOf<list<Type> allowedTypes> :
-  ShapedContainerType<allowedTypes, IsAMX2DTilePred, "tile",
+class AMXTileOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
                       "::mlir::amx::TileType">;
 
+def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
+
+def AMXTileF32 : AMXTileOf<[F32]>;
+
+def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
+
+def AMXTileI32 : AMXTileOf<[I32]>;
+
+def AMXTileI8 : AMXTileOf<[I8]>;
+
 //===----------------------------------------------------------------------===//
 // AMX Op and IntrOp definitions.
 //===----------------------------------------------------------------------===//
@@ -151,17 +159,16 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
     Example:
 
     ```mlir
-      %0 = amx.tile_zero : <16x16xbf16>
+      %0 = amx.tile_zero : !amx.tile<16x16xbf16>
     ```
   }];
-  let results = (outs
-    AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
+  let results = (outs AnyAMXTile:$res);
   let extraClassDeclaration = [{
     TileType getTileType() {
       return ::llvm::cast<TileType>(getRes().getType());
     }
   }];
-  let assemblyFormat = "attr-dict `:` type($res)";
+  let assemblyFormat = "attr-dict `:` qualified(type($res))";
   let hasVerifier = 1;
 }
 
@@ -180,13 +187,12 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
     Example:
 
     ```mlir
-      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into <16x64xi8>
+      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
     ```
   }];
   let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
                    Variadic<Index>:$indices);
-  let results = (outs
-    AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
+  let results = (outs AnyAMXTile:$res);
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
@@ -196,7 +202,7 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
     }
   }];
   let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
-                       "type($base) `into` type($res)";
+                       "type($base) `into` qualified(type($res))";
   let hasVerifier = 1;
 }
 
@@ -211,12 +217,12 @@ def TileStoreOp : AMX_Op<"tile_store"> {
     Example:
 
     ```mlir
-      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, <16x64xi8>
+      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
     ```
   }];
   let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
                    Variadic<Index>:$indices,
-                   AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$val);
+                   AnyAMXTile:$val);
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
@@ -226,7 +232,7 @@ def TileStoreOp : AMX_Op<"tile_store"> {
     }
   }];
   let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
-                       "type($base) `,` type($val)";
+                       "type($base) `,` qualified(type($val))";
   let hasVerifier = 1;
 }
 
@@ -246,13 +252,14 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
     Example:
 
     ```mlir
-      %0 = amx.tile_mulf %a, %b, %c : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+      %0 = amx.tile_mulf %a, %b, %c
+        : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
     ```
   }];
-  let arguments = (ins AMX2DTileOf<[F16, BF16]>:$lhs,
-                       AMX2DTileOf<[F16, BF16]>:$rhs,
-                       AMX2DTileOf<[F32]>:$acc);
-  let results = (outs AMX2DTileOf<[F32]>:$res);
+  let arguments = (ins AMXTileF16OrBF16:$lhs,
+                       AMXTileF16OrBF16:$rhs,
+                       AMXTileF32:$acc);
+  let results = (outs AMXTileF32:$res);
   let extraClassDeclaration = [{
     TileType getLhsTileType() {
       return ::llvm::cast<TileType>(getLhs().getType());
@@ -265,7 +272,8 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
     }
   }];
   let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
-                       "type($lhs) `,` type($rhs) `,` type($acc) ";
+                       "qualified(type($lhs)) `,` qualified(type($rhs))"
+                       " `,` qualified(type($acc)) ";
   let hasVerifier = 1;
 }
 
@@ -284,16 +292,17 @@ def TileMulIOp : AMX_Op<"tile_muli", [
     Example:
 
     ```mlir
-      %0 = amx.tile_muli %a zext, %b zext, %c : <16x64xi8>, <16x64xi8>, <16x16xi32>
+      %0 = amx.tile_muli %a zext, %b zext, %c
+        : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
     ```
   }];
-  let arguments = (ins AMX2DTileOf<[I8]>:$lhs,
-                       AMX2DTileOf<[I8]>:$rhs,
-                       AMX2DTileOf<[I32]>:$acc,
+  let arguments = (ins AMXTileI8:$lhs,
+                       AMXTileI8:$rhs,
+                       AMXTileI32:$acc,
                        UnitAttr:$isZextLhs,
                        UnitAttr:$isZextRhs
                        );
-  let results = (outs AMX2DTileOf<[I32]>:$res);
+  let results = (outs AMXTileI32:$res);
   let extraClassDeclaration = [{
     TileType getLhsTileType() {
       return ::llvm::cast<TileType>(getLhs().getType());
@@ -306,7 +315,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [
     }
   }];
   let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
-                       "type($lhs) `,` type($rhs) `,` type($acc) ";
+                       "qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 415a0998f684f9..cbca72d5f26fb5 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -160,10 +160,12 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
       rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
           op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
           adaptor.getLhs(), adaptor.getRhs());
-    else
+    else if (aType.getElementType().isF16())
       rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
           op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
           adaptor.getLhs(), adaptor.getRhs());
+    else
+      llvm_unreachable("Unexpected element type for amx.mulf");
     return success();
   }
 };
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index 46f3b7448148b3..8febe1605e33a4 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -4,21 +4,21 @@
 
 func.func @rowheight() {
   // expected-error at +1 {{'amx.tile_zero' op bad row height: 17}}
-  %0 = amx.tile_zero : <17x16xbf16>
+  %0 = amx.tile_zero : !amx.tile<17x16xbf16>
 }
 
 // -----
 
 func.func @colwidth() {
   // expected-error at +1 {{'amx.tile_zero' op bad column width: 65}}
-  %0 = amx.tile_zero : <16x65xi8>
+  %0 = amx.tile_zero : !amx.tile<16x65xi8>
 }
 
 // -----
 
 func.func @col4bytemultiple() {
   // expected-error at +1 {{'amx.tile_zero' op bad column width: 5}}
-  %0 = amx.tile_zero : <16x5xi8>
+  %0 = amx.tile_zero : !amx.tile<16x5xi8>
 }
 
 // -----
@@ -26,7 +26,7 @@ func.func @col4bytemultiple() {
 func.func @memtilesize(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error at +1 {{'amx.tile_load' op bad column width: 68}}
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into <16x17xf32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
 }
 
 // -----
@@ -34,15 +34,15 @@ func.func @memtilesize(%arg0: memref<?x?xf32>) {
 func.func @memindexsize(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error at +1 {{'amx.tile_load' op requires 2 indices}}
-  %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into <16x16xf32>
+  %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
 }
 
 // -----
 
 func.func @multsize() {
-  %0 = amx.tile_zero : <8x8xbf16>
-  %1 = amx.tile_zero : <8x8xbf16>
-  %2 = amx.tile_zero : <4x4xf32>
+  %0 = amx.tile_zero : !amx.tile<8x8xbf16>
+  %1 = amx.tile_zero : !amx.tile<8x8xbf16>
+  %2 = amx.tile_zero : !amx.tile<4x4xf32>
   // expected-error at +1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
-  %3 = amx.tile_mulf %0, %1, %2 : <8x8xbf16>, <8x8xbf16>, <4x4xf32>
+  %3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<8x8xbf16>, !amx.tile<4x4xf32>
 }
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 685abbfcb830de..fe0f4fec5f313e 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -14,17 +14,17 @@
 // CHECK: amx.tilestored64
 func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : <16x64xi8>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into <16x16xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
-  %5 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
-  amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, <16x16xi32>
-  %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
-  amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, <16x16xi32>
-  %7 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
-  amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, <16x16xi32>
+  %1 = amx.tile_zero : !amx.tile<16x64xi8>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
+  %5 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, !amx.tile<16x16xi32>
+  %6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, !amx.tile<16x16xi32>
+  %7 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, !amx.tile<16x16xi32>
   return
 }
 
@@ -36,11 +36,11 @@ func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
 // CHECK: amx.tilestored64
 func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : <16x32xbf16>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
-  %4 = amx.tile_mulf %1, %2, %3 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
+  %1 = amx.tile_zero : !amx.tile<16x32xbf16>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
+  %4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
   return
 }
 
@@ -52,10 +52,10 @@ func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
 // CHECK: amx.tilestored64
 func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : <16x32xf16>
-  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into <16x32xf16>
-  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
-  %4 = amx.tile_mulf %1, %2, %3 : <16x32xf16>, <16x32xf16>, <16x16xf32>
-  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
+  %1 = amx.tile_zero : !amx.tile<16x32xf16>
+  %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into !amx.tile<16x32xf16>
+  %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
+  %4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32>
+  amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
   return
 }
diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir
index 62bc071cab1e02..1b7f781ae173d3 100644
--- a/mlir/test/Dialect/AMX/roundtrip.mlir
+++ b/mlir/test/Dialect/AMX/roundtrip.mlir
@@ -1,49 +1,49 @@
 // RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: tzero
-// CHECK: amx.tile_zero : <16x16xbf16>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, <16x16xbf16>
+// CHECK: amx.tile_zero : !amx.tile<16x16xbf16>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, !amx.tile<16x16xbf16>
 func.func @tzero(%arg0: memref<?x?xbf16>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_zero : <16x16xbf16>
-  amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, <16x16xbf16>
+  %1 = amx.tile_zero : !amx.tile<16x16xbf16>
+  amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, !amx.tile<16x16xbf16>
   return
 }
 
 // CHECK-LABEL: tmulf
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into <16x32xbf16>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into <16x16xf32>
-// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, <16x16xf32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into !amx.tile<16x16xf32>
+// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, !amx.tile<16x16xf32>
 func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
-  %3 = amx.tile_mulf %1, %1, %2 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
-  amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, <16x16xf32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
+  %3 = amx.tile_mulf %1, %1, %2 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+  amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, !amx.tile<16x16xf32>
   return
 }
 
 // CHECK-LABEL: tmuli
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
-// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into <16x16xi32>
-// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : <16x64xi8>, <16x64xi8>, <16x16xi32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, <16x16xi32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
+// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into !amx.tile<16x16xi32>
+// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, !amx.tile<16x16xi32>
 // Verify the parsing/printing of the sign-extension annotation.
 // CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
 // CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
 // CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
 func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into <16x64xi8>
-  %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into <16x16xi32>
-  %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
-  amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
   // Verify the various `zext` combinations.
-  %5 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
-  %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
-  %7 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+  %5 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  %6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+  %7 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
   return
 }
    
    
More information about the Mlir-commits
mailing list