[Mlir-commits] [mlir] [MLIR] [AMX] Utilize x86_amx type for AMX dialect in MLIR. (PR #111197)
Ilya Enkovich
llvmlistbot at llvm.org
Mon Oct 7 11:33:01 PDT 2024
https://github.com/ienkovich updated https://github.com/llvm/llvm-project/pull/111197
>From 38c2803f072167a2c289c2c8dceff4c220ae3476 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Wed, 2 Oct 2024 22:21:32 +0000
Subject: [PATCH 1/2] Utilize x86_amx type for AMX dialect in MLIR.
Signed-off-by: Ilya Enkovich <ilya.enkovich at intel.com>
---
.../Conversion/LLVMCommon/TypeConverter.h | 4 +
mlir/include/mlir/Dialect/AMX/AMX.td | 139 +++++++++++++-----
mlir/include/mlir/Dialect/AMX/AMXDialect.h | 3 +
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td | 12 ++
mlir/lib/Conversion/LLVMCommon/CMakeLists.txt | 1 +
.../Conversion/LLVMCommon/TypeConverter.cpp | 7 +
mlir/lib/Dialect/AMX/IR/AMXDialect.cpp | 63 ++++++--
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 51 ++++---
mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp | 2 +
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 6 +-
mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp | 2 +
mlir/lib/Target/LLVMIR/TypeToLLVM.cpp | 3 +
mlir/test/Dialect/AMX/invalid.mlir | 18 +--
mlir/test/Dialect/AMX/legalize-for-llvm.mlir | 52 ++++---
mlir/test/Dialect/AMX/roundtrip.mlir | 50 +++----
15 files changed, 287 insertions(+), 126 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index d79b90f840ce83..bd4b3e73f07410 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -15,6 +15,7 @@
#define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -258,6 +259,9 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a 1D vector type into an LLVM vector type.
FailureOr<Type> convertVectorType(VectorType type) const;
+ /// Convert AMX tile type to LLVM x86_amx type.
+ Type convertAMXTileType(amx::TileType type) const;
+
/// Options for customizing the llvm lowering.
LowerToLLVMOptions options;
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index fcc8d169eab5ac..8ef5ac25fbbddf 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -30,6 +30,8 @@
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypes.td"
//===----------------------------------------------------------------------===//
// AMX dialect definition.
@@ -55,8 +57,69 @@ def AMX_Dialect : Dialect {
For details, see the Intel documentation:
https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
}];
+ let useDefaultTypePrinterParser = 1;
}
+//===----------------------------------------------------------------------===//
+// AMX Tile definition.
+//===----------------------------------------------------------------------===//
+
+class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
+ : TypeDef<AMX_Dialect, typeName, traits> {
+ let mnemonic = typeMnemonic;
+}
+
+def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
+ let cppFunctionName = "isValidTileTypeElementType";
+}
+
+def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSemantics]> {
+ let summary = "AMX 2D tile to be used by AMX operations.";
+
+ let description = [{
+ This type is used to represent values in AMX tile registers. All AMX operations
+ work on AMX tiles and these tiles cannot be used in other operations directly.
+ LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
+ element type for IR verification and lowering to LLVMIR dialect.
+ }];
+
+ let parameters = (ins
+ ArrayRefParameter<"int64_t">:$shape,
+ AMX_TileTypeElementType:$elementType
+ );
+
+ let builders = [
+ TypeBuilderWithInferredContext<(ins
+ "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
+ return $_get(elementType.getContext(), shape, elementType);
+ }]>
+ ];
+
+ let extraClassDeclaration = [{
+ /// Returns if this type is ranked (always true).
+ bool hasRank() const { return true; }
+
+ /// Clone this tile type with the given shape and element type. If the
+ /// provided shape is `std::nullopt`, the current shape of the type is used.
+ TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ return get(shape.value_or(getShape()), elementType);
+ }
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+ let skipDefaultBuilders = 1;
+}
+
+def IsAMXTilePred : CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">;
+
+def IsAMX2DTilePred : And<[IsAMXTilePred,
+ CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
+
+class AMX2DTileOf<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsAMX2DTilePred, "tile",
+ "::mlir::amx::TileType">;
+
//===----------------------------------------------------------------------===//
// AMX Op and IntrOp definitions.
//===----------------------------------------------------------------------===//
@@ -88,14 +151,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
Example:
```mlir
- %0 = amx.tile_zero : vector<16x16xbf16>
+ %0 = amx.tile_zero : <16x16xbf16>
```
}];
let results = (outs
- VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+ AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
let extraClassDeclaration = [{
- VectorType getVectorType() {
- return ::llvm::cast<VectorType>(getRes().getType());
+ TileType getTileType() {
+ return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "attr-dict `:` type($res)";
@@ -117,19 +180,19 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
Example:
```mlir
- %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
+ %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into <16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
Variadic<Index>:$indices);
let results = (outs
- VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+ AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
}
- VectorType getVectorType() {
- return ::llvm::cast<VectorType>(getRes().getType());
+ TileType getTileType() {
+ return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
@@ -148,18 +211,18 @@ def TileStoreOp : AMX_Op<"tile_store"> {
Example:
```mlir
- amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
+ amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, <16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
Variadic<Index>:$indices,
- VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
+ AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$val);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
}
- VectorType getVectorType() {
- return ::llvm::cast<VectorType>(getVal().getType());
+ TileType getTileType() {
+ return ::llvm::cast<TileType>(getVal().getType());
}
}];
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
@@ -183,23 +246,22 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
Example:
```mlir
- %0 = amx.tile_mulf %a, %b, %c
- : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+ %0 = amx.tile_mulf %a, %b, %c : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
```
}];
- let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
- VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
- VectorOfRankAndType<[2], [F32, BF16]>:$acc);
- let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
+ let arguments = (ins AMX2DTileOf<[F16, BF16]>:$lhs,
+ AMX2DTileOf<[F16, BF16]>:$rhs,
+ AMX2DTileOf<[F32]>:$acc);
+ let results = (outs AMX2DTileOf<[F32]>:$res);
let extraClassDeclaration = [{
- VectorType getLhsVectorType() {
- return ::llvm::cast<VectorType>(getLhs().getType());
+ TileType getLhsTileType() {
+ return ::llvm::cast<TileType>(getLhs().getType());
}
- VectorType getRhsVectorType() {
- return ::llvm::cast<VectorType>(getRhs().getType());
+ TileType getRhsTileType() {
+ return ::llvm::cast<TileType>(getRhs().getType());
}
- VectorType getVectorType() {
- return ::llvm::cast<VectorType>(getRes().getType());
+ TileType getTileType() {
+ return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
@@ -222,26 +284,25 @@ def TileMulIOp : AMX_Op<"tile_muli", [
Example:
```mlir
- %0 = amx.tile_muli %a zext, %b zext, %c
- : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ %0 = amx.tile_muli %a zext, %b zext, %c : <16x64xi8>, <16x64xi8>, <16x16xi32>
```
}];
- let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
- VectorOfRankAndType<[2], [I32, I8]>:$rhs,
- VectorOfRankAndType<[2], [I32, I8]>:$acc,
+ let arguments = (ins AMX2DTileOf<[I8]>:$lhs,
+ AMX2DTileOf<[I8]>:$rhs,
+ AMX2DTileOf<[I32]>:$acc,
UnitAttr:$isZextLhs,
UnitAttr:$isZextRhs
);
- let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
+ let results = (outs AMX2DTileOf<[I32]>:$res);
let extraClassDeclaration = [{
- VectorType getLhsVectorType() {
- return ::llvm::cast<VectorType>(getLhs().getType());
+ TileType getLhsTileType() {
+ return ::llvm::cast<TileType>(getLhs().getType());
}
- VectorType getRhsVectorType() {
- return ::llvm::cast<VectorType>(getRhs().getType());
+ TileType getRhsTileType() {
+ return ::llvm::cast<TileType>(getRhs().getType());
}
- VectorType getVectorType() {
- return ::llvm::cast<VectorType>(getRes().getType());
+ TileType getTileType() {
+ return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
@@ -286,6 +347,12 @@ def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+// Dot product of f16 tiles into f32 tile.
+def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
+ Arguments<(ins AnyInteger,
+ AnyInteger,
+ AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
// Dot product of i8 tiles into i32 tile (with sign/sign extension).
def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
Arguments<(ins AnyInteger,
diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
index 47c92479814dea..c0553ad8733fd4 100644
--- a/mlir/include/mlir/Dialect/AMX/AMXDialect.h
+++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
@@ -21,6 +21,9 @@
#include "mlir/Dialect/AMX/AMXDialect.h.inc"
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/AMX/AMXTypes.h.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/AMX/AMX.h.inc"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 8f9c2f2f8a0b44..09dd0919c318fb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -250,4 +250,16 @@ def LLVMTargetExtType : LLVMType<"LLVMTargetExt", "target"> {
}];
}
+//===----------------------------------------------------------------------===//
+// LLVMX86AMXType
+//===----------------------------------------------------------------------===//
+
+def LLVMX86AMXType : LLVMType<"LLVMX86AMX", "x86_amx"> {
+ let summary = "LLVM x86_amx type.";
+ let description = [{
+ The x86_amx type represents a value held in an AMX tile register on an x86
+ machine. Can only be used in AMX intrinsic calls.
+ }];
+}
+
#endif // LLVMTYPES_TD
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 568d9339aaabcb..39199e4affccfa 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
Core
LINK_LIBS PUBLIC
+ MLIRAMXDialect
MLIRIR
MLIRLLVMDialect
MLIRSupport
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index fd6369b5bb4ee5..a585a4f6ab76f6 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -67,6 +67,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
return std::nullopt;
return llvmType;
});
+ addConversion([&](amx::TileType type) { return convertAMXTileType(type); });
// LLVM-compatible types are legal, so add a pass-through conversion. Do this
// before the conversions below since conversions are attempted in reverse
@@ -596,6 +597,12 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
return vectorType;
}
+/// Convert an AMX tile type to LLVM x86_amx type.
+/// Shape and element type of the tile are ignored.
+Type LLVMTypeConverter::convertAMXTileType(amx::TileType type) const {
+ return LLVM::LLVMX86AMXType::get(&getContext());
+}
+
/// Convert a type in the context of the default or bare pointer calling
/// convention. Calling convention sensitive types, such as MemRefType and
/// UnrankedMemRefType, are converted following the specific rules for the
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index f0e434407c8a2d..829f48e223383e 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -13,14 +13,22 @@
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/TypeSwitch.h"
+
using namespace mlir;
#include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
void amx::AMXDialect::initialize() {
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
+ >();
+
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMX/AMX.cpp.inc"
@@ -28,7 +36,7 @@ void amx::AMXDialect::initialize() {
}
/// Verify that AMX supports the implied tile shape.
-static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
+static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
const unsigned kMaxRows = 16;
const unsigned kBitsPerRow = 64 * 8;
unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
@@ -40,8 +48,8 @@ static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
}
/// Verify that AMX supports the multiplication.
-static LogicalResult verifyMultShape(Operation *op, VectorType atp,
- VectorType btp, VectorType ctp,
+static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
+ amx::TileType btp, amx::TileType ctp,
unsigned scale) {
unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
@@ -53,27 +61,27 @@ static LogicalResult verifyMultShape(Operation *op, VectorType atp,
}
LogicalResult amx::TileZeroOp::verify() {
- return verifyTileSize(*this, getVectorType());
+ return verifyTileSize(*this, getTileType());
}
LogicalResult amx::TileLoadOp::verify() {
unsigned rank = getMemRefType().getRank();
if (getIndices().size() != rank)
return emitOpError("requires ") << rank << " indices";
- return verifyTileSize(*this, getVectorType());
+ return verifyTileSize(*this, getTileType());
}
LogicalResult amx::TileStoreOp::verify() {
unsigned rank = getMemRefType().getRank();
if (getIndices().size() != rank)
return emitOpError("requires ") << rank << " indices";
- return verifyTileSize(*this, getVectorType());
+ return verifyTileSize(*this, getTileType());
}
LogicalResult amx::TileMulFOp::verify() {
- VectorType aType = getLhsVectorType();
- VectorType bType = getRhsVectorType();
- VectorType cType = getVectorType();
+ amx::TileType aType = getLhsTileType();
+ amx::TileType bType = getRhsTileType();
+ amx::TileType cType = getTileType();
if (failed(verifyTileSize(*this, aType)) ||
failed(verifyTileSize(*this, bType)) ||
failed(verifyTileSize(*this, cType)) ||
@@ -82,15 +90,15 @@ LogicalResult amx::TileMulFOp::verify() {
Type ta = aType.getElementType();
Type tb = bType.getElementType();
Type tc = cType.getElementType();
- if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
+ if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
return emitOpError("unsupported type combination");
return success();
}
LogicalResult amx::TileMulIOp::verify() {
- VectorType aType = getLhsVectorType();
- VectorType bType = getRhsVectorType();
- VectorType cType = getVectorType();
+ amx::TileType aType = getLhsTileType();
+ amx::TileType bType = getRhsTileType();
+ amx::TileType cType = getTileType();
if (failed(verifyTileSize(*this, aType)) ||
failed(verifyTileSize(*this, bType)) ||
failed(verifyTileSize(*this, cType)) ||
@@ -104,5 +112,34 @@ LogicalResult amx::TileMulIOp::verify() {
return success();
}
+Type amx::TileType::parse(AsmParser &parser) {
+ if (parser.parseLess())
+ return nullptr;
+
+ SmallVector<int64_t, 2> shape;
+ if (parser.parseDimensionList(shape, false, true))
+ return nullptr;
+
+ Type elementType;
+ if (parser.parseType(elementType))
+ return nullptr;
+
+ if (parser.parseGreater())
+ return nullptr;
+
+ return TileType::get(shape, elementType);
+}
+
+void amx::TileType::print(AsmPrinter &os) const {
+ os << "<";
+ os.printDimensionList(getShape());
+ os << 'x';
+ os.printType(getElementType());
+ os << '>';
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/AMX/AMX.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index a8b10f63315d41..415a0998f684f9 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -25,13 +25,13 @@ namespace {
/// The second dimensions needs to be scaled by the number of bytes.
std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter,
- VectorType vType, Location loc) {
+ amx::TileType tType, Location loc) {
Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
- unsigned width = vType.getElementType().getIntOrFloatBitWidth();
+ unsigned width = tType.getElementType().getIntOrFloatBitWidth();
assert(llvm::isPowerOf2_64(width) && width >= 8);
unsigned bytes = width >> 3;
- auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0));
- auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes);
+ auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
+ auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
return std::make_pair(
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
@@ -78,12 +78,12 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
LogicalResult
matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType vType = op.getVectorType();
+ amx::TileType tType = op.getTileType();
// Determine m x n tile sizes.
std::pair<Value, Value> tsz =
- getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+ getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
// Replace operation with intrinsic.
- Type resType = typeConverter->convertType(vType);
+ Type resType = typeConverter->convertType(tType);
rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
tsz.second);
return success();
@@ -97,10 +97,10 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType mType = op.getMemRefType();
- VectorType vType = op.getVectorType();
+ amx::TileType tType = op.getTileType();
// Determine m x n tile sizes.
std::pair<Value, Value> tsz =
- getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+ getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
// Determine stride.
if (failed(verifyStride(mType)))
return failure();
@@ -109,7 +109,7 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
// Replace operation with intrinsic.
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
- Type resType = typeConverter->convertType(vType);
+ Type resType = typeConverter->convertType(tType);
rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
op, resType, tsz.first, tsz.second, ptr, stride);
return success();
@@ -123,10 +123,10 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType mType = op.getMemRefType();
- VectorType vType = op.getVectorType();
+ amx::TileType tType = op.getTileType();
// Determine m x n tile sizes.
std::pair<Value, Value> tsz =
- getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+ getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
// Determine stride.
if (failed(verifyStride(mType)))
return failure();
@@ -146,9 +146,9 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
LogicalResult
matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType aType = op.getLhsVectorType();
- VectorType bType = op.getRhsVectorType();
- VectorType cType = op.getVectorType();
+ amx::TileType aType = op.getLhsTileType();
+ amx::TileType bType = op.getRhsTileType();
+ amx::TileType cType = op.getTileType();
// Determine m x n x k tile sizes.
std::pair<Value, Value> tsza =
getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
@@ -156,9 +156,14 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
// Replace operation with intrinsic.
Type resType = typeConverter->convertType(cType);
- rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
- op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
- adaptor.getLhs(), adaptor.getRhs());
+ if (aType.getElementType().isBF16())
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
+ op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
+ adaptor.getLhs(), adaptor.getRhs());
+ else
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
+ op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
+ adaptor.getLhs(), adaptor.getRhs());
return success();
}
};
@@ -168,9 +173,9 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
LogicalResult
matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType aType = op.getLhsVectorType();
- VectorType bType = op.getRhsVectorType();
- VectorType cType = op.getVectorType();
+ amx::TileType aType = op.getLhsTileType();
+ amx::TileType bType = op.getRhsTileType();
+ amx::TileType cType = op.getTileType();
// Determine m x n x k tile sizes.
std::pair<Value, Value> tsza =
getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
@@ -210,8 +215,8 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns(
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
- x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
- x86_amx_tdpbusd, x86_amx_tdpbuud>();
+ x86_amx_tdpbf16ps, x86_amx_tdpfp16ps, x86_amx_tdpbssd,
+ x86_amx_tdpbsud, x86_amx_tdpbusd, x86_amx_tdpbuud>();
target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
TileMulFOp>();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index c4708d826f2b38..9537f7c40dd4be 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -45,6 +45,7 @@ static StringRef getTypeKeyword(Type type) {
.Case<LLVMArrayType>([&](Type) { return "array"; })
.Case<LLVMStructType>([&](Type) { return "struct"; })
.Case<LLVMTargetExtType>([&](Type) { return "target"; })
+ .Case<LLVMX86AMXType>([&](Type) { return "x86_amx"; })
.Default([](Type) -> StringRef {
llvm_unreachable("unexpected 'llvm' type kind");
});
@@ -317,6 +318,7 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
.Case("array", [&] { return LLVMArrayType::parse(parser); })
.Case("struct", [&] { return parseStructType(parser); })
.Case("target", [&] { return LLVMTargetExtType::parse(parser); })
+ .Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); })
.Default([&] {
parser.emitError(keyLoc) << "unknown LLVM type: " << key;
return Type();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 7f10a15ff31ff9..1bed3fa48b30d7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -780,7 +780,8 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
LLVMFixedVectorType,
LLVMScalableVectorType,
LLVMTargetExtType,
- LLVMVoidType
+ LLVMVoidType,
+ LLVMX86AMXType
>(type)) {
// clang-format on
return true;
@@ -842,7 +843,8 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
LLVMMetadataType,
LLVMPPCFP128Type,
LLVMTokenType,
- LLVMVoidType
+ LLVMVoidType,
+ LLVMX86AMXType
>([](Type) { return true; })
// clang-format on
.Default([](Type) { return false; });
diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
index db184ae8e6e833..ea990ca7aefbe0 100644
--- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
@@ -63,6 +63,8 @@ class TypeFromLLVMIRTranslatorImpl {
return Float128Type::get(&context);
if (type->isX86_FP80Ty())
return Float80Type::get(&context);
+ if (type->isX86_AMXTy())
+ return LLVM::LLVMX86AMXType::get(&context);
if (type->isPPC_FP128Ty())
return LLVM::LLVMPPCFP128Type::get(&context);
if (type->isLabelTy())
diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
index 65915027238801..c7a533eddce84b 100644
--- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
@@ -67,6 +67,9 @@ class TypeToLLVMIRTranslatorImpl {
.Case([this](LLVM::LLVMMetadataType) {
return llvm::Type::getMetadataTy(context);
})
+ .Case([this](LLVM::LLVMX86AMXType) {
+ return llvm::Type::getX86_AMXTy(context);
+ })
.Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType,
LLVM::LLVMPointerType, LLVM::LLVMStructType,
LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType,
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index 25d57353c905d0..46f3b7448148b3 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -4,21 +4,21 @@
func.func @rowheight() {
// expected-error at +1 {{'amx.tile_zero' op bad row height: 17}}
- %0 = amx.tile_zero : vector<17x16xbf16>
+ %0 = amx.tile_zero : <17x16xbf16>
}
// -----
func.func @colwidth() {
// expected-error at +1 {{'amx.tile_zero' op bad column width: 65}}
- %0 = amx.tile_zero : vector<16x65xi8>
+ %0 = amx.tile_zero : <16x65xi8>
}
// -----
func.func @col4bytemultiple() {
// expected-error at +1 {{'amx.tile_zero' op bad column width: 5}}
- %0 = amx.tile_zero : vector<16x5xi8>
+ %0 = amx.tile_zero : <16x5xi8>
}
// -----
@@ -26,7 +26,7 @@ func.func @col4bytemultiple() {
func.func @memtilesize(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error at +1 {{'amx.tile_load' op bad column width: 68}}
- %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into vector<16x17xf32>
+ %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into <16x17xf32>
}
// -----
@@ -34,15 +34,15 @@ func.func @memtilesize(%arg0: memref<?x?xf32>) {
func.func @memindexsize(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error at +1 {{'amx.tile_load' op requires 2 indices}}
- %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into vector<16x16xf32>
+ %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into <16x16xf32>
}
// -----
func.func @multsize() {
- %0 = amx.tile_zero : vector<8x8xbf16>
- %1 = amx.tile_zero : vector<8x8xbf16>
- %2 = amx.tile_zero : vector<4x4xf32>
+ %0 = amx.tile_zero : <8x8xbf16>
+ %1 = amx.tile_zero : <8x8xbf16>
+ %2 = amx.tile_zero : <4x4xf32>
// expected-error at +1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
- %3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32>
+ %3 = amx.tile_mulf %0, %1, %2 : <8x8xbf16>, <8x8xbf16>, <4x4xf32>
}
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 992203153939fe..685abbfcb830de 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -14,32 +14,48 @@
// CHECK: amx.tilestored64
func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_zero : vector<16x64xi8>
- %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
- %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
- %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
- amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
- %5 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
- amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, vector<16x16xi32>
- %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
- amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, vector<16x16xi32>
- %7 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
- amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, vector<16x16xi32>
+ %1 = amx.tile_zero : <16x64xi8>
+ %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
+ %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into <16x16xi32>
+ %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
+ %5 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, <16x16xi32>
+ %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, <16x16xi32>
+ %7 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, <16x16xi32>
return
}
-// CHECK-LABEL: mulf(
+// CHECK-LABEL: mulbf16(
// CHECK: amx.tilezero
// CHECK: amx.tileloadd64
// CHECK: amx.tileloadd64
// CHECK: amx.tdpbf16ps
// CHECK: amx.tilestored64
-func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
+func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_zero : vector<16x32xbf16>
- %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
- %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
- %4 = amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
- amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
+ %1 = amx.tile_zero : <16x32xbf16>
+ %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
+ %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
+ %4 = amx.tile_mulf %1, %2, %3 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+ amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
+ return
+}
+
+// CHECK-LABEL: mulfp16(
+// CHECK: amx.tilezero
+// CHECK: amx.tileloadd64
+// CHECK: amx.tileloadd64
+// CHECK: amx.tdpfp16ps
+// CHECK: amx.tilestored64
+func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
+ %0 = arith.constant 0 : index
+ %1 = amx.tile_zero : <16x32xf16>
+ %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into <16x32xf16>
+ %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
+ %4 = amx.tile_mulf %1, %2, %3 : <16x32xf16>, <16x32xf16>, <16x16xf32>
+ amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
return
}
diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir
index f2ac5e47f6c357..62bc071cab1e02 100644
--- a/mlir/test/Dialect/AMX/roundtrip.mlir
+++ b/mlir/test/Dialect/AMX/roundtrip.mlir
@@ -1,49 +1,49 @@
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
// CHECK-LABEL: tzero
-// CHECK: amx.tile_zero : vector<16x16xbf16>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, vector<16x16xbf16>
+// CHECK: amx.tile_zero : <16x16xbf16>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, <16x16xbf16>
func.func @tzero(%arg0: memref<?x?xbf16>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_zero : vector<16x16xbf16>
- amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, vector<16x16xbf16>
+ %1 = amx.tile_zero : <16x16xbf16>
+ amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, <16x16xbf16>
return
}
// CHECK-LABEL: tmulf
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into vector<16x32xbf16>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into vector<16x16xf32>
-// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, vector<16x16xf32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into <16x32xbf16>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into <16x16xf32>
+// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, <16x16xf32>
func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
- %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
- %3 = amx.tile_mulf %1, %1, %2 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
- amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, vector<16x16xf32>
+ %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
+ %3 = amx.tile_mulf %1, %1, %2 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+ amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, <16x16xf32>
return
}
// CHECK-LABEL: tmuli
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
-// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into vector<16x16xi32>
-// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, vector<16x16xi32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
+// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into <16x16xi32>
+// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : <16x64xi8>, <16x64xi8>, <16x16xi32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, <16x16xi32>
// Verify the parsing/printing of the sign-extension annotation.
// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
- %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
- %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
- %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
- amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
+ %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into <16x64xi8>
+ %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into <16x16xi32>
+ %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
// Verify the various `zext` combinations.
- %5 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
- %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
- %7 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ %5 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ %7 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
return
}
>From bbaeadd5ca48f2da714f4af2adc84ba1ba553df6 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Mon, 7 Oct 2024 16:46:03 +0000
Subject: [PATCH 2/2] Fix review comments.
Signed-off-by: Ilya Enkovich <ilya.enkovich at intel.com>
---
.../Conversion/LLVMCommon/TypeConverter.h | 2 +-
mlir/include/mlir/Dialect/AMX/AMX.td | 67 +++++++++++--------
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 4 +-
mlir/test/Dialect/AMX/invalid.mlir | 18 ++---
mlir/test/Dialect/AMX/legalize-for-llvm.mlir | 42 ++++++------
mlir/test/Dialect/AMX/roundtrip.mlir | 50 +++++++-------
6 files changed, 97 insertions(+), 86 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index bd4b3e73f07410..6e01f70ee35123 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -259,7 +259,7 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a 1D vector type into an LLVM vector type.
FailureOr<Type> convertVectorType(VectorType type) const;
- /// Convert AMX tile type x86_amx type.
+ /// Convert an AMX tile type to the x86_amx type.
Type convertAMXTileType(amx::TileType type) const;
/// Options for customizing the llvm lowering.
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 8ef5ac25fbbddf..8a51df1ea183fc 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -73,7 +73,7 @@ def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
let cppFunctionName = "isValidTileTypeElementType";
}
-def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSemantics]> {
+def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
let summary = "AMX 2D tile to be used by AMX opertaions.";
let description = [{
@@ -111,15 +111,23 @@ def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSeman
let skipDefaultBuilders = 1;
}
-def IsAMXTilePred : CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">;
-
-def IsAMX2DTilePred : And<[IsAMXTilePred,
+def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
-class AMX2DTileOf<list<Type> allowedTypes> :
- ShapedContainerType<allowedTypes, IsAMX2DTilePred, "tile",
+class AMXTileOf<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
"::mlir::amx::TileType">;
+def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
+
+def AMXTileF32 : AMXTileOf<[F32]>;
+
+def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
+
+def AMXTileI32 : AMXTileOf<[I32]>;
+
+def AMXTileI8 : AMXTileOf<[I8]>;
+
//===----------------------------------------------------------------------===//
// AMX Op and IntrOp definitions.
//===----------------------------------------------------------------------===//
@@ -151,17 +159,16 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
Example:
```mlir
- %0 = amx.tile_zero : <16x16xbf16>
+ %0 = amx.tile_zero : !amx.tile<16x16xbf16>
```
}];
- let results = (outs
- AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
+ let results = (outs AnyAMXTile:$res);
let extraClassDeclaration = [{
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}
}];
- let assemblyFormat = "attr-dict `:` type($res)";
+ let assemblyFormat = "attr-dict `:` qualified(type($res))";
let hasVerifier = 1;
}
@@ -180,13 +187,12 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
Example:
```mlir
- %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into <16x64xi8>
+ %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
Variadic<Index>:$indices);
- let results = (outs
- AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
+ let results = (outs AnyAMXTile:$res);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -196,7 +202,7 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
}
}];
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
- "type($base) `into` type($res)";
+ "type($base) `into` qualified(type($res))";
let hasVerifier = 1;
}
@@ -211,12 +217,12 @@ def TileStoreOp : AMX_Op<"tile_store"> {
Example:
```mlir
- amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, <16x64xi8>
+ amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
Variadic<Index>:$indices,
- AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$val);
+ AnyAMXTile:$val);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -226,7 +232,7 @@ def TileStoreOp : AMX_Op<"tile_store"> {
}
}];
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
- "type($base) `,` type($val)";
+ "type($base) `,` qualified(type($val))";
let hasVerifier = 1;
}
@@ -246,13 +252,14 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
Example:
```mlir
- %0 = amx.tile_mulf %a, %b, %c : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+ %0 = amx.tile_mulf %a, %b, %c
+ : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
```
}];
- let arguments = (ins AMX2DTileOf<[F16, BF16]>:$lhs,
- AMX2DTileOf<[F16, BF16]>:$rhs,
- AMX2DTileOf<[F32]>:$acc);
- let results = (outs AMX2DTileOf<[F32]>:$res);
+ let arguments = (ins AMXTileF16OrBF16:$lhs,
+ AMXTileF16OrBF16:$rhs,
+ AMXTileF32:$acc);
+ let results = (outs AMXTileF32:$res);
let extraClassDeclaration = [{
TileType getLhsTileType() {
return ::llvm::cast<TileType>(getLhs().getType());
@@ -265,7 +272,8 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
}
}];
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
- "type($lhs) `,` type($rhs) `,` type($acc) ";
+ "qualified(type($lhs)) `,` qualified(type($rhs))"
+ " `,` qualified(type($acc)) ";
let hasVerifier = 1;
}
@@ -284,16 +292,17 @@ def TileMulIOp : AMX_Op<"tile_muli", [
Example:
```mlir
- %0 = amx.tile_muli %a zext, %b zext, %c : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ %0 = amx.tile_muli %a zext, %b zext, %c
+ : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
```
}];
- let arguments = (ins AMX2DTileOf<[I8]>:$lhs,
- AMX2DTileOf<[I8]>:$rhs,
- AMX2DTileOf<[I32]>:$acc,
+ let arguments = (ins AMXTileI8:$lhs,
+ AMXTileI8:$rhs,
+ AMXTileI32:$acc,
UnitAttr:$isZextLhs,
UnitAttr:$isZextRhs
);
- let results = (outs AMX2DTileOf<[I32]>:$res);
+ let results = (outs AMXTileI32:$res);
let extraClassDeclaration = [{
TileType getLhsTileType() {
return ::llvm::cast<TileType>(getLhs().getType());
@@ -306,7 +315,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [
}
}];
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
- "type($lhs) `,` type($rhs) `,` type($acc) ";
+ "qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 415a0998f684f9..cbca72d5f26fb5 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -160,10 +160,12 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs());
- else
+ else if (aType.getElementType().isF16())
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs());
+ else
+ llvm_unreachable("Unexpected element type for amx.mulf");
return success();
}
};
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index 46f3b7448148b3..8febe1605e33a4 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -4,21 +4,21 @@
func.func @rowheight() {
// expected-error at +1 {{'amx.tile_zero' op bad row height: 17}}
- %0 = amx.tile_zero : <17x16xbf16>
+ %0 = amx.tile_zero : !amx.tile<17x16xbf16>
}
// -----
func.func @colwidth() {
// expected-error at +1 {{'amx.tile_zero' op bad column width: 65}}
- %0 = amx.tile_zero : <16x65xi8>
+ %0 = amx.tile_zero : !amx.tile<16x65xi8>
}
// -----
func.func @col4bytemultiple() {
// expected-error at +1 {{'amx.tile_zero' op bad column width: 5}}
- %0 = amx.tile_zero : <16x5xi8>
+ %0 = amx.tile_zero : !amx.tile<16x5xi8>
}
// -----
@@ -26,7 +26,7 @@ func.func @col4bytemultiple() {
func.func @memtilesize(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error at +1 {{'amx.tile_load' op bad column width: 68}}
- %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into <16x17xf32>
+ %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
}
// -----
@@ -34,15 +34,15 @@ func.func @memtilesize(%arg0: memref<?x?xf32>) {
func.func @memindexsize(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error at +1 {{'amx.tile_load' op requires 2 indices}}
- %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into <16x16xf32>
+ %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
}
// -----
func.func @multsize() {
- %0 = amx.tile_zero : <8x8xbf16>
- %1 = amx.tile_zero : <8x8xbf16>
- %2 = amx.tile_zero : <4x4xf32>
+ %0 = amx.tile_zero : !amx.tile<8x8xbf16>
+ %1 = amx.tile_zero : !amx.tile<8x8xbf16>
+ %2 = amx.tile_zero : !amx.tile<4x4xf32>
// expected-error at +1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
- %3 = amx.tile_mulf %0, %1, %2 : <8x8xbf16>, <8x8xbf16>, <4x4xf32>
+ %3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<8x8xbf16>, !amx.tile<4x4xf32>
}
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 685abbfcb830de..fe0f4fec5f313e 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -14,17 +14,17 @@
// CHECK: amx.tilestored64
func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_zero : <16x64xi8>
- %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
- %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into <16x16xi32>
- %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
- amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
- %5 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
- amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, <16x16xi32>
- %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
- amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, <16x16xi32>
- %7 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
- amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, <16x16xi32>
+ %1 = amx.tile_zero : !amx.tile<16x64xi8>
+ %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+ %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
+ %4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+ amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
+ %5 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+ amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, !amx.tile<16x16xi32>
+ %6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+ amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, !amx.tile<16x16xi32>
+ %7 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+ amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, !amx.tile<16x16xi32>
return
}
@@ -36,11 +36,11 @@ func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
// CHECK: amx.tilestored64
func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_zero : <16x32xbf16>
- %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
- %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
- %4 = amx.tile_mulf %1, %2, %3 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
- amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
+ %1 = amx.tile_zero : !amx.tile<16x32xbf16>
+ %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+ %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
+ %4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+ amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
return
}
@@ -52,10 +52,10 @@ func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
// CHECK: amx.tilestored64
func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_zero : <16x32xf16>
- %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into <16x32xf16>
- %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
- %4 = amx.tile_mulf %1, %2, %3 : <16x32xf16>, <16x32xf16>, <16x16xf32>
- amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
+ %1 = amx.tile_zero : !amx.tile<16x32xf16>
+ %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into !amx.tile<16x32xf16>
+ %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
+ %4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32>
+ amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
return
}
diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir
index 62bc071cab1e02..1b7f781ae173d3 100644
--- a/mlir/test/Dialect/AMX/roundtrip.mlir
+++ b/mlir/test/Dialect/AMX/roundtrip.mlir
@@ -1,49 +1,49 @@
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
// CHECK-LABEL: tzero
-// CHECK: amx.tile_zero : <16x16xbf16>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, <16x16xbf16>
+// CHECK: amx.tile_zero : !amx.tile<16x16xbf16>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, !amx.tile<16x16xbf16>
func.func @tzero(%arg0: memref<?x?xbf16>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_zero : <16x16xbf16>
- amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, <16x16xbf16>
+ %1 = amx.tile_zero : !amx.tile<16x16xbf16>
+ amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, !amx.tile<16x16xbf16>
return
}
// CHECK-LABEL: tmulf
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into <16x32xbf16>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into <16x16xf32>
-// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, <16x16xf32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into !amx.tile<16x16xf32>
+// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, !amx.tile<16x16xf32>
func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
- %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
- %3 = amx.tile_mulf %1, %1, %2 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
- amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, <16x16xf32>
+ %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
+ %3 = amx.tile_mulf %1, %1, %2 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+ amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, !amx.tile<16x16xf32>
return
}
// CHECK-LABEL: tmuli
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
-// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into <16x16xi32>
-// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : <16x64xi8>, <16x64xi8>, <16x16xi32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, <16x16xi32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
+// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into !amx.tile<16x16xi32>
+// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, !amx.tile<16x16xi32>
// Verify the parsing/printing of the sign-extension annotation.
// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
- %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into <16x64xi8>
- %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into <16x16xi32>
- %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
- amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
+ %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+ %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
+ %4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+ amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
// Verify the various `zext` combinations.
- %5 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
- %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
- %7 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ %5 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+ %6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+ %7 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
return
}
More information about the Mlir-commits
mailing list