[Mlir-commits] [mlir] [MLIR] [AMX] Utilize x86_amx type for AMX dialect in MLIR. (PR #111197)
Ilya Enkovich
llvmlistbot at llvm.org
Mon Nov 4 12:13:47 PST 2024
https://github.com/ienkovich updated https://github.com/llvm/llvm-project/pull/111197
From 2a76b2d63ae39b5714ba927c3ae3a71790f47af9 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Wed, 2 Oct 2024 22:21:32 +0000
Subject: [PATCH 1/4] Utilize x86_amx type for AMX dialect in MLIR.
Signed-off-by: Ilya Enkovich <ilya.enkovich at intel.com>
---
.../Conversion/LLVMCommon/TypeConverter.h | 4 +
mlir/include/mlir/Dialect/AMX/AMX.td | 139 +++++++++++++-----
mlir/include/mlir/Dialect/AMX/AMXDialect.h | 3 +
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td | 12 ++
mlir/lib/Conversion/LLVMCommon/CMakeLists.txt | 1 +
.../Conversion/LLVMCommon/TypeConverter.cpp | 7 +
mlir/lib/Dialect/AMX/IR/AMXDialect.cpp | 63 ++++++--
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 51 ++++---
mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp | 2 +
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 6 +-
mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp | 2 +
mlir/lib/Target/LLVMIR/TypeToLLVM.cpp | 3 +
mlir/test/Dialect/AMX/invalid.mlir | 18 +--
mlir/test/Dialect/AMX/legalize-for-llvm.mlir | 52 ++++---
mlir/test/Dialect/AMX/roundtrip.mlir | 50 +++----
15 files changed, 287 insertions(+), 126 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index d79b90f840ce836..bd4b3e73f074108 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -15,6 +15,7 @@
#define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -258,6 +259,9 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a 1D vector type into an LLVM vector type.
FailureOr<Type> convertVectorType(VectorType type) const;
+ /// Convert AMX tile type x86_amx type.
+ Type convertAMXTileType(amx::TileType type) const;
+
/// Options for customizing the llvm lowering.
LowerToLLVMOptions options;
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index fcc8d169eab5ac0..8ef5ac25fbbddff 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -30,6 +30,8 @@
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypes.td"
//===----------------------------------------------------------------------===//
// AMX dialect definition.
@@ -55,8 +57,69 @@ def AMX_Dialect : Dialect {
For details, see the Intel documentation:
https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
}];
+ let useDefaultTypePrinterParser = 1;
}
+//===----------------------------------------------------------------------===//
+// AMX Tile definition.
+//===----------------------------------------------------------------------===//
+
+class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
+ : TypeDef<AMX_Dialect, typeName, traits> {
+ let mnemonic = typeMnemonic;
+}
+
+def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
+ let cppFunctionName = "isValidTileTypeElementType";
+}
+
+def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSemantics]> {
+ let summary = "AMX 2D tile to be used by AMX operations.";
+
+ let description = [{
+ This type is used to represent values in AMX tile registers. All AMX operations
+ work on AMX tiles and these tiles cannot be used in other operations directly.
+ LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
+ element type for IR verification and lowering to LLVMIR dialect.
+ }];
+
+ let parameters = (ins
+ ArrayRefParameter<"int64_t">:$shape,
+ AMX_TileTypeElementType:$elementType
+ );
+
+ let builders = [
+ TypeBuilderWithInferredContext<(ins
+ "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
+ return $_get(elementType.getContext(), shape, elementType);
+ }]>
+ ];
+
+ let extraClassDeclaration = [{
+ /// Returns if this type is ranked (always true).
+ bool hasRank() const { return true; }
+
+ /// Clone this tile type with the given shape and element type. If the
+ /// provided shape is `std::nullopt`, the current shape of the type is used.
+ TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ return get(shape.value_or(getShape()), elementType);
+ }
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+ let skipDefaultBuilders = 1;
+}
+
+def IsAMXTilePred : CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">;
+
+def IsAMX2DTilePred : And<[IsAMXTilePred,
+ CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
+
+class AMX2DTileOf<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsAMX2DTilePred, "tile",
+ "::mlir::amx::TileType">;
+
//===----------------------------------------------------------------------===//
// AMX Op and IntrOp definitions.
//===----------------------------------------------------------------------===//
@@ -88,14 +151,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
Example:
```mlir
- %0 = amx.tile_zero : vector<16x16xbf16>
+ %0 = amx.tile_zero : <16x16xbf16>
```
}];
let results = (outs
- VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+ AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
let extraClassDeclaration = [{
- VectorType getVectorType() {
- return ::llvm::cast<VectorType>(getRes().getType());
+ TileType getTileType() {
+ return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "attr-dict `:` type($res)";
@@ -117,19 +180,19 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
Example:
```mlir
- %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
+ %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into <16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
Variadic<Index>:$indices);
let results = (outs
- VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+ AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
}
- VectorType getVectorType() {
- return ::llvm::cast<VectorType>(getRes().getType());
+ TileType getTileType() {
+ return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
@@ -148,18 +211,18 @@ def TileStoreOp : AMX_Op<"tile_store"> {
Example:
```mlir
- amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
+ amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, <16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
Variadic<Index>:$indices,
- VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
+ AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$val);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
}
- VectorType getVectorType() {
- return ::llvm::cast<VectorType>(getVal().getType());
+ TileType getTileType() {
+ return ::llvm::cast<TileType>(getVal().getType());
}
}];
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
@@ -183,23 +246,22 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
Example:
```mlir
- %0 = amx.tile_mulf %a, %b, %c
- : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+ %0 = amx.tile_mulf %a, %b, %c : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
```
}];
- let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
- VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
- VectorOfRankAndType<[2], [F32, BF16]>:$acc);
- let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
+ let arguments = (ins AMX2DTileOf<[F16, BF16]>:$lhs,
+ AMX2DTileOf<[F16, BF16]>:$rhs,
+ AMX2DTileOf<[F32]>:$acc);
+ let results = (outs AMX2DTileOf<[F32]>:$res);
let extraClassDeclaration = [{
- VectorType getLhsVectorType() {
- return ::llvm::cast<VectorType>(getLhs().getType());
+ TileType getLhsTileType() {
+ return ::llvm::cast<TileType>(getLhs().getType());
}
- VectorType getRhsVectorType() {
- return ::llvm::cast<VectorType>(getRhs().getType());
+ TileType getRhsTileType() {
+ return ::llvm::cast<TileType>(getRhs().getType());
}
- VectorType getVectorType() {
- return ::llvm::cast<VectorType>(getRes().getType());
+ TileType getTileType() {
+ return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
@@ -222,26 +284,25 @@ def TileMulIOp : AMX_Op<"tile_muli", [
Example:
```mlir
- %0 = amx.tile_muli %a zext, %b zext, %c
- : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ %0 = amx.tile_muli %a zext, %b zext, %c : <16x64xi8>, <16x64xi8>, <16x16xi32>
```
}];
- let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
- VectorOfRankAndType<[2], [I32, I8]>:$rhs,
- VectorOfRankAndType<[2], [I32, I8]>:$acc,
+ let arguments = (ins AMX2DTileOf<[I8]>:$lhs,
+ AMX2DTileOf<[I8]>:$rhs,
+ AMX2DTileOf<[I32]>:$acc,
UnitAttr:$isZextLhs,
UnitAttr:$isZextRhs
);
- let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
+ let results = (outs AMX2DTileOf<[I32]>:$res);
let extraClassDeclaration = [{
- VectorType getLhsVectorType() {
- return ::llvm::cast<VectorType>(getLhs().getType());
+ TileType getLhsTileType() {
+ return ::llvm::cast<TileType>(getLhs().getType());
}
- VectorType getRhsVectorType() {
- return ::llvm::cast<VectorType>(getRhs().getType());
+ TileType getRhsTileType() {
+ return ::llvm::cast<TileType>(getRhs().getType());
}
- VectorType getVectorType() {
- return ::llvm::cast<VectorType>(getRes().getType());
+ TileType getTileType() {
+ return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
@@ -286,6 +347,12 @@ def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+// Dot product of f16 tiles into f32 tile.
+def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
+ Arguments<(ins AnyInteger,
+ AnyInteger,
+ AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
// Dot product of i8 tiles into i32 tile (with sign/sign extension).
def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
Arguments<(ins AnyInteger,
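
To summarize the TableGen changes above: AMX ops now produce and consume a dedicated tile type instead of 2-D `vector` values, so tile values can no longer flow directly into unrelated operations. Below is a minimal sketch of the resulting IR, using the unqualified type syntax of this first revision; the shapes and ops mirror the updated roundtrip tests, while the function wrapper and SSA names are illustrative:

```mlir
func.func @sketch(%src: memref<?x?xbf16>, %dst: memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
  // Tiles exist only as operands/results of amx ops.
  %a = amx.tile_load %src[%c0, %c0] : memref<?x?xbf16> into <16x32xbf16>
  %b = amx.tile_zero : <16x32xbf16>
  %acc = amx.tile_zero : <16x16xf32>
  %r = amx.tile_mulf %a, %b, %acc : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
  amx.tile_store %dst[%c0, %c0], %r : memref<?x?xf32>, <16x16xf32>
  return
}
```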
diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
index 47c92479814dea8..c0553ad8733fd45 100644
--- a/mlir/include/mlir/Dialect/AMX/AMXDialect.h
+++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
@@ -21,6 +21,9 @@
#include "mlir/Dialect/AMX/AMXDialect.h.inc"
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/AMX/AMXTypes.h.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/AMX/AMX.h.inc"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 8f9c2f2f8a0b441..09dd0919c318fb1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -250,4 +250,16 @@ def LLVMTargetExtType : LLVMType<"LLVMTargetExt", "target"> {
}];
}
+//===----------------------------------------------------------------------===//
+// LLVMX86AMXType
+//===----------------------------------------------------------------------===//
+
+def LLVMX86AMXType : LLVMType<"LLVMX86AMX", "x86_amx"> {
+ let summary = "LLVM x86_amx type.";
+ let description = [{
+ The x86_amx type represents a value held in an AMX tile register on an x86
+ machine. Can only be used in AMX intrinsics calls.
+ }];
+}
+
#endif // LLVMTYPES_TD
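
The new `!llvm.x86_amx` type is deliberately opaque: it carries no shape or element type and is only meaningful as an operand or result of the AMX intrinsic ops. A hedged sketch of how it can appear after legalization (the generic-form printing and the constant setup are illustrative; the two i16 size operands follow the lowering code later in this patch):

```mlir
// Tile sizes are passed explicitly as i16 operands; the tile value itself is opaque.
%m = llvm.mlir.constant(16 : i16) : i16
%n = llvm.mlir.constant(64 : i16) : i16
%t = "amx.tilezero"(%m, %n) : (i16, i16) -> !llvm.x86_amx
```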
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 568d9339aaabcb4..39199e4affccfac 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
Core
LINK_LIBS PUBLIC
+ MLIRAMXDialect
MLIRIR
MLIRLLVMDialect
MLIRSupport
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 4e7758bf46d9cfc..07e6dd55f855191 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -67,6 +67,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
return std::nullopt;
return llvmType;
});
+ addConversion([&](amx::TileType type) { return convertAMXTileType(type); });
// LLVM-compatible types are legal, so add a pass-through conversion. Do this
// before the conversions below since conversions are attempted in reverse
@@ -594,6 +595,12 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
return vectorType;
}
+/// Convert an AMX tile type to LLVM x86_amx type.
+/// Shape and element type of the tile are ignored.
+Type LLVMTypeConverter::convertAMXTileType(amx::TileType type) const {
+ return LLVM::LLVMX86AMXType::get(&getContext());
+}
+
/// Convert a type in the context of the default or bare pointer calling
/// convention. Calling convention sensitive types, such as MemRefType and
/// UnrankedMemRefType, are converted following the specific rules for the
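
In other words, the conversion is many-to-one: every tile type maps to the single opaque LLVM type, and the shape is re-materialized as the explicit size operands of the intrinsics (see `getTileSizes` in the legalization pass below). A sketch of the mapping, written with the qualified `!amx.tile` spelling adopted later in this PR:

```mlir
// convertAMXTileType drops shape and element type:
//   !amx.tile<16x32xbf16> -> !llvm.x86_amx
//   !amx.tile<16x64xi8>   -> !llvm.x86_amx
//   !amx.tile<16x16xf32>  -> !llvm.x86_amx
```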
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index f0e434407c8a2d0..829f48e223383e9 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -13,14 +13,22 @@
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/TypeSwitch.h"
+
using namespace mlir;
#include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
void amx::AMXDialect::initialize() {
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
+ >();
+
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMX/AMX.cpp.inc"
@@ -28,7 +36,7 @@ void amx::AMXDialect::initialize() {
}
/// Verify that AMX supports the implied tile shape.
-static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
+static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
const unsigned kMaxRows = 16;
const unsigned kBitsPerRow = 64 * 8;
unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
@@ -40,8 +48,8 @@ static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
}
/// Verify that AMX supports the multiplication.
-static LogicalResult verifyMultShape(Operation *op, VectorType atp,
- VectorType btp, VectorType ctp,
+static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
+ amx::TileType btp, amx::TileType ctp,
unsigned scale) {
unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
@@ -53,27 +61,27 @@ static LogicalResult verifyMultShape(Operation *op, VectorType atp,
}
LogicalResult amx::TileZeroOp::verify() {
- return verifyTileSize(*this, getVectorType());
+ return verifyTileSize(*this, getTileType());
}
LogicalResult amx::TileLoadOp::verify() {
unsigned rank = getMemRefType().getRank();
if (getIndices().size() != rank)
return emitOpError("requires ") << rank << " indices";
- return verifyTileSize(*this, getVectorType());
+ return verifyTileSize(*this, getTileType());
}
LogicalResult amx::TileStoreOp::verify() {
unsigned rank = getMemRefType().getRank();
if (getIndices().size() != rank)
return emitOpError("requires ") << rank << " indices";
- return verifyTileSize(*this, getVectorType());
+ return verifyTileSize(*this, getTileType());
}
LogicalResult amx::TileMulFOp::verify() {
- VectorType aType = getLhsVectorType();
- VectorType bType = getRhsVectorType();
- VectorType cType = getVectorType();
+ amx::TileType aType = getLhsTileType();
+ amx::TileType bType = getRhsTileType();
+ amx::TileType cType = getTileType();
if (failed(verifyTileSize(*this, aType)) ||
failed(verifyTileSize(*this, bType)) ||
failed(verifyTileSize(*this, cType)) ||
@@ -82,15 +90,15 @@ LogicalResult amx::TileMulFOp::verify() {
Type ta = aType.getElementType();
Type tb = bType.getElementType();
Type tc = cType.getElementType();
- if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
+ if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
return emitOpError("unsupported type combination");
return success();
}
LogicalResult amx::TileMulIOp::verify() {
- VectorType aType = getLhsVectorType();
- VectorType bType = getRhsVectorType();
- VectorType cType = getVectorType();
+ amx::TileType aType = getLhsTileType();
+ amx::TileType bType = getRhsTileType();
+ amx::TileType cType = getTileType();
if (failed(verifyTileSize(*this, aType)) ||
failed(verifyTileSize(*this, bType)) ||
failed(verifyTileSize(*this, cType)) ||
@@ -104,5 +112,34 @@ LogicalResult amx::TileMulIOp::verify() {
return success();
}
+Type amx::TileType::parse(AsmParser &parser) {
+ if (parser.parseLess())
+ return nullptr;
+
+ SmallVector<int64_t, 2> shape;
+ if (parser.parseDimensionList(shape, false, true))
+ return nullptr;
+
+ Type elementType;
+ if (parser.parseType(elementType))
+ return nullptr;
+
+ if (parser.parseGreater())
+ return nullptr;
+
+ return TileType::get(shape, elementType);
+}
+
+void amx::TileType::print(AsmPrinter &os) const {
+ os << "<";
+ os.printDimensionList(getShape());
+ os << 'x';
+ os.printType(getElementType());
+ os << '>';
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/AMX/AMX.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
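
The custom parser/printer above gives the tile type the usual shaped-type spelling: a dimension list followed by the element type, wrapped in angle brackets. Inside `amx` op assembly it prints unqualified in this revision; the fully qualified spelling settled on in the follow-up commit of this PR is `!amx.tile<...>`. A small round-trip sketch:

```mlir
// Parsed and re-printed identically by TileType::parse/print:
%0 = amx.tile_zero : <16x32xbf16>
// Qualified form, e.g. when the type appears outside amx op assembly: !amx.tile<16x32xbf16>
```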
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 46c7bfbf3ffcc2e..3172d93ef65bc99 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -25,13 +25,13 @@ namespace {
/// The second dimensions needs to be scaled by the number of bytes.
std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter,
- VectorType vType, Location loc) {
+ amx::TileType tType, Location loc) {
Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
- unsigned width = vType.getElementType().getIntOrFloatBitWidth();
+ unsigned width = tType.getElementType().getIntOrFloatBitWidth();
assert(llvm::isPowerOf2_64(width) && width >= 8);
unsigned bytes = width >> 3;
- auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0));
- auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes);
+ auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
+ auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
return std::make_pair(
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
@@ -76,12 +76,12 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
LogicalResult
matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType vType = op.getVectorType();
+ amx::TileType tType = op.getTileType();
// Determine m x n tile sizes.
std::pair<Value, Value> tsz =
- getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+ getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
// Replace operation with intrinsic.
- Type resType = typeConverter->convertType(vType);
+ Type resType = typeConverter->convertType(tType);
rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
tsz.second);
return success();
@@ -95,10 +95,10 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType mType = op.getMemRefType();
- VectorType vType = op.getVectorType();
+ amx::TileType tType = op.getTileType();
// Determine m x n tile sizes.
std::pair<Value, Value> tsz =
- getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+ getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
// Determine stride.
auto stride = getStride(rewriter, *getTypeConverter(), mType,
adaptor.getBase(), op.getLoc());
@@ -107,7 +107,7 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
// Replace operation with intrinsic.
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
- Type resType = typeConverter->convertType(vType);
+ Type resType = typeConverter->convertType(tType);
rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
op, resType, tsz.first, tsz.second, ptr, stride.value());
return success();
@@ -121,10 +121,10 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType mType = op.getMemRefType();
- VectorType vType = op.getVectorType();
+ amx::TileType tType = op.getTileType();
// Determine m x n tile sizes.
std::pair<Value, Value> tsz =
- getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+ getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
// Determine stride.
auto stride = getStride(rewriter, *getTypeConverter(), mType,
adaptor.getBase(), op.getLoc());
@@ -144,9 +144,9 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
LogicalResult
matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType aType = op.getLhsVectorType();
- VectorType bType = op.getRhsVectorType();
- VectorType cType = op.getVectorType();
+ amx::TileType aType = op.getLhsTileType();
+ amx::TileType bType = op.getRhsTileType();
+ amx::TileType cType = op.getTileType();
// Determine m x n x k tile sizes.
std::pair<Value, Value> tsza =
getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
@@ -154,9 +154,14 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
// Replace operation with intrinsic.
Type resType = typeConverter->convertType(cType);
- rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
- op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
- adaptor.getLhs(), adaptor.getRhs());
+ if (aType.getElementType().isBF16())
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
+ op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
+ adaptor.getLhs(), adaptor.getRhs());
+ else
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
+ op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
+ adaptor.getLhs(), adaptor.getRhs());
return success();
}
};
@@ -166,9 +171,9 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
LogicalResult
matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType aType = op.getLhsVectorType();
- VectorType bType = op.getRhsVectorType();
- VectorType cType = op.getVectorType();
+ amx::TileType aType = op.getLhsTileType();
+ amx::TileType bType = op.getRhsTileType();
+ amx::TileType cType = op.getTileType();
// Determine m x n x k tile sizes.
std::pair<Value, Value> tsza =
getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
@@ -208,8 +213,8 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns(
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
- x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
- x86_amx_tdpbusd, x86_amx_tdpbuud>();
+ x86_amx_tdpbf16ps, x86_amx_tdpfp16ps, x86_amx_tdpbssd,
+ x86_amx_tdpbsud, x86_amx_tdpbusd, x86_amx_tdpbuud>();
target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
TileMulFOp>();
}
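
The element-type dispatch in `TileMulFConversion` above means a single `amx.tile_mulf` lowers to one of two intrinsics: `amx.tdpbf16ps` for bf16 tiles and the newly added `amx.tdpfp16ps` for f16 tiles. A hedged before/after sketch follows; the generic-form printing and SSA names are illustrative, the surrounding loads and size constants are omitted, and the operand order m, n, k, acc, lhs, rhs follows the conversion code:

```mlir
// Before legalization:
%p = amx.tile_mulf %a, %b, %c : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
%q = amx.tile_mulf %x, %y, %z : <16x32xf16>, <16x32xf16>, <16x16xf32>

// After legalization (sketch): bf16 picks tdpbf16ps, f16 picks tdpfp16ps.
%p2 = "amx.tdpbf16ps"(%m, %n, %k, %acc0, %lhs0, %rhs0)
  : (i16, i16, i16, !llvm.x86_amx, !llvm.x86_amx, !llvm.x86_amx) -> !llvm.x86_amx
%q2 = "amx.tdpfp16ps"(%m, %n, %k, %acc1, %lhs1, %rhs1)
  : (i16, i16, i16, !llvm.x86_amx, !llvm.x86_amx, !llvm.x86_amx) -> !llvm.x86_amx
```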
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index c4708d826f2b38b..9537f7c40dd4bef 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -45,6 +45,7 @@ static StringRef getTypeKeyword(Type type) {
.Case<LLVMArrayType>([&](Type) { return "array"; })
.Case<LLVMStructType>([&](Type) { return "struct"; })
.Case<LLVMTargetExtType>([&](Type) { return "target"; })
+ .Case<LLVMX86AMXType>([&](Type) { return "x86_amx"; })
.Default([](Type) -> StringRef {
llvm_unreachable("unexpected 'llvm' type kind");
});
@@ -317,6 +318,7 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
.Case("array", [&] { return LLVMArrayType::parse(parser); })
.Case("struct", [&] { return parseStructType(parser); })
.Case("target", [&] { return LLVMTargetExtType::parse(parser); })
+ .Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); })
.Default([&] {
parser.emitError(keyLoc) << "unknown LLVM type: " << key;
return Type();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 7f10a15ff31ff94..1bed3fa48b30d7a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -780,7 +780,8 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
LLVMFixedVectorType,
LLVMScalableVectorType,
LLVMTargetExtType,
- LLVMVoidType
+ LLVMVoidType,
+ LLVMX86AMXType
>(type)) {
// clang-format on
return true;
@@ -842,7 +843,8 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
LLVMMetadataType,
LLVMPPCFP128Type,
LLVMTokenType,
- LLVMVoidType
+ LLVMVoidType,
+ LLVMX86AMXType
>([](Type) { return true; })
// clang-format on
.Default([](Type) { return false; });
diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
index db184ae8e6e8331..ea990ca7aefbe03 100644
--- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
@@ -63,6 +63,8 @@ class TypeFromLLVMIRTranslatorImpl {
return Float128Type::get(&context);
if (type->isX86_FP80Ty())
return Float80Type::get(&context);
+ if (type->isX86_AMXTy())
+ return LLVM::LLVMX86AMXType::get(&context);
if (type->isPPC_FP128Ty())
return LLVM::LLVMPPCFP128Type::get(&context);
if (type->isLabelTy())
diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
index 659150272388010..c7a533eddce84be 100644
--- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
@@ -67,6 +67,9 @@ class TypeToLLVMIRTranslatorImpl {
.Case([this](LLVM::LLVMMetadataType) {
return llvm::Type::getMetadataTy(context);
})
+ .Case([this](LLVM::LLVMX86AMXType) {
+ return llvm::Type::getX86_AMXTy(context);
+ })
.Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType,
LLVM::LLVMPointerType, LLVM::LLVMStructType,
LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType,
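
Together with the `TypeFromLLVM`/`TypeToLLVM` hooks above, the type also round-trips through LLVM IR translation. A hedged sketch of the correspondence; the intrinsic call line is an assumption about the exported form, using LLVM's internal AMX tilezero intrinsic:

```mlir
// MLIR (LLVM dialect):
//   %t = "amx.tilezero"(%m, %n) : (i16, i16) -> !llvm.x86_amx
// LLVM IR (after translation), roughly:
//   %t = call x86_amx @llvm.x86.tilezero.internal(i16 %m, i16 %n)
```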
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index 25d57353c905d06..46f3b7448148b36 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -4,21 +4,21 @@
func.func @rowheight() {
// expected-error at +1 {{'amx.tile_zero' op bad row height: 17}}
- %0 = amx.tile_zero : vector<17x16xbf16>
+ %0 = amx.tile_zero : <17x16xbf16>
}
// -----
func.func @colwidth() {
// expected-error at +1 {{'amx.tile_zero' op bad column width: 65}}
- %0 = amx.tile_zero : vector<16x65xi8>
+ %0 = amx.tile_zero : <16x65xi8>
}
// -----
func.func @col4bytemultiple() {
// expected-error at +1 {{'amx.tile_zero' op bad column width: 5}}
- %0 = amx.tile_zero : vector<16x5xi8>
+ %0 = amx.tile_zero : <16x5xi8>
}
// -----
@@ -26,7 +26,7 @@ func.func @col4bytemultiple() {
func.func @memtilesize(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error at +1 {{'amx.tile_load' op bad column width: 68}}
- %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into vector<16x17xf32>
+ %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into <16x17xf32>
}
// -----
@@ -34,15 +34,15 @@ func.func @memtilesize(%arg0: memref<?x?xf32>) {
func.func @memindexsize(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error at +1 {{'amx.tile_load' op requires 2 indices}}
- %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into vector<16x16xf32>
+ %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into <16x16xf32>
}
// -----
func.func @multsize() {
- %0 = amx.tile_zero : vector<8x8xbf16>
- %1 = amx.tile_zero : vector<8x8xbf16>
- %2 = amx.tile_zero : vector<4x4xf32>
+ %0 = amx.tile_zero : <8x8xbf16>
+ %1 = amx.tile_zero : <8x8xbf16>
+ %2 = amx.tile_zero : <4x4xf32>
// expected-error at +1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
- %3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32>
+ %3 = amx.tile_mulf %0, %1, %2 : <8x8xbf16>, <8x8xbf16>, <4x4xf32>
}
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 3cacbd0044f8251..82fb0a9c99f71de 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -14,33 +14,49 @@
// CHECK: amx.tilestored64
func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_zero : vector<16x64xi8>
- %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
- %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
- %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
- amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
- %5 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
- amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, vector<16x16xi32>
- %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
- amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, vector<16x16xi32>
- %7 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
- amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, vector<16x16xi32>
+ %1 = amx.tile_zero : <16x64xi8>
+ %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
+ %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into <16x16xi32>
+ %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
+ %5 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, <16x16xi32>
+ %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, <16x16xi32>
+ %7 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, <16x16xi32>
return
}
-// CHECK-LABEL: mulf(
+// CHECK-LABEL: mulbf16(
// CHECK: amx.tilezero
// CHECK: amx.tileloadd64
// CHECK: amx.tileloadd64
// CHECK: amx.tdpbf16ps
// CHECK: amx.tilestored64
-func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
+func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_zero : vector<16x32xbf16>
- %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
- %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
- %4 = amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
- amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
+ %1 = amx.tile_zero : <16x32xbf16>
+ %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
+ %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
+ %4 = amx.tile_mulf %1, %2, %3 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+ amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
+ return
+}
+
+// CHECK-LABEL: mulfp16(
+// CHECK: amx.tilezero
+// CHECK: amx.tileloadd64
+// CHECK: amx.tileloadd64
+// CHECK: amx.tdpfp16ps
+// CHECK: amx.tilestored64
+func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
+ %0 = arith.constant 0 : index
+ %1 = amx.tile_zero : <16x32xf16>
+ %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into <16x32xf16>
+ %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
+ %4 = amx.tile_mulf %1, %2, %3 : <16x32xf16>, <16x32xf16>, <16x16xf32>
+ amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
return
}
diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir
index f2ac5e47f6c3576..62bc071cab1e029 100644
--- a/mlir/test/Dialect/AMX/roundtrip.mlir
+++ b/mlir/test/Dialect/AMX/roundtrip.mlir
@@ -1,49 +1,49 @@
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
// CHECK-LABEL: tzero
-// CHECK: amx.tile_zero : vector<16x16xbf16>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, vector<16x16xbf16>
+// CHECK: amx.tile_zero : <16x16xbf16>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, <16x16xbf16>
func.func @tzero(%arg0: memref<?x?xbf16>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_zero : vector<16x16xbf16>
- amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, vector<16x16xbf16>
+ %1 = amx.tile_zero : <16x16xbf16>
+ amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, <16x16xbf16>
return
}
// CHECK-LABEL: tmulf
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into vector<16x32xbf16>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into vector<16x16xf32>
-// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, vector<16x16xf32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into <16x32xbf16>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into <16x16xf32>
+// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, <16x16xf32>
func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into vector<16x32xbf16>
- %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into vector<16x16xf32>
- %3 = amx.tile_mulf %1, %1, %2 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
- amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, vector<16x16xf32>
+ %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
+ %3 = amx.tile_mulf %1, %1, %2 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+ amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, <16x16xf32>
return
}
// CHECK-LABEL: tmuli
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
-// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into vector<16x16xi32>
-// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, vector<16x16xi32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
+// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into <16x16xi32>
+// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : <16x64xi8>, <16x64xi8>, <16x16xi32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, <16x16xi32>
// Verify the parsing/printing of the sign-extension annotation.
// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
- %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
- %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
- %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
- amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
+ %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into <16x64xi8>
+ %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into <16x16xi32>
+ %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
// Verify the various `zext` combinations.
- %5 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
- %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
- %7 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ %5 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ %7 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
return
}
From 3a07dfc460a533fbf86440655cef736d93357f88 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Mon, 7 Oct 2024 16:46:03 +0000
Subject: [PATCH 2/4] Fix review comments.
Signed-off-by: Ilya Enkovich <ilya.enkovich at intel.com>
---
.../Conversion/LLVMCommon/TypeConverter.h | 2 +-
mlir/include/mlir/Dialect/AMX/AMX.td | 67 +++++++++++--------
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 4 +-
mlir/test/Dialect/AMX/invalid.mlir | 18 ++---
mlir/test/Dialect/AMX/legalize-for-llvm.mlir | 42 ++++++------
mlir/test/Dialect/AMX/roundtrip.mlir | 50 +++++++-------
6 files changed, 97 insertions(+), 86 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index bd4b3e73f074108..6e01f70ee351235 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -259,7 +259,7 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a 1D vector type into an LLVM vector type.
FailureOr<Type> convertVectorType(VectorType type) const;
- /// Convert AMX tile type x86_amx type.
+ /// Convert an AMX tile type to the x86_amx type.
Type convertAMXTileType(amx::TileType type) const;
/// Options for customizing the llvm lowering.
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 8ef5ac25fbbddff..8a51df1ea183fc7 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -73,7 +73,7 @@ def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
let cppFunctionName = "isValidTileTypeElementType";
}
-def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSemantics]> {
+def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
let summary = "AMX 2D tile to be used by AMX operations.";
let description = [{
@@ -111,15 +111,23 @@ def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSeman
let skipDefaultBuilders = 1;
}
-def IsAMXTilePred : CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">;
-
-def IsAMX2DTilePred : And<[IsAMXTilePred,
+def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
-class AMX2DTileOf<list<Type> allowedTypes> :
- ShapedContainerType<allowedTypes, IsAMX2DTilePred, "tile",
+class AMXTileOf<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
"::mlir::amx::TileType">;
+def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
+
+def AMXTileF32 : AMXTileOf<[F32]>;
+
+def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
+
+def AMXTileI32 : AMXTileOf<[I32]>;
+
+def AMXTileI8 : AMXTileOf<[I8]>;
+
//===----------------------------------------------------------------------===//
// AMX Op and IntrOp definitions.
//===----------------------------------------------------------------------===//
@@ -151,17 +159,16 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
Example:
```mlir
- %0 = amx.tile_zero : <16x16xbf16>
+ %0 = amx.tile_zero : !amx.tile<16x16xbf16>
```
}];
- let results = (outs
- AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
+ let results = (outs AnyAMXTile:$res);
let extraClassDeclaration = [{
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}
}];
- let assemblyFormat = "attr-dict `:` type($res)";
+ let assemblyFormat = "attr-dict `:` qualified(type($res))";
let hasVerifier = 1;
}
@@ -180,13 +187,12 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
Example:
```mlir
- %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into <16x64xi8>
+ %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
Variadic<Index>:$indices);
- let results = (outs
- AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
+ let results = (outs AnyAMXTile:$res);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -196,7 +202,7 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
}
}];
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
- "type($base) `into` type($res)";
+ "type($base) `into` qualified(type($res))";
let hasVerifier = 1;
}
@@ -211,12 +217,12 @@ def TileStoreOp : AMX_Op<"tile_store"> {
Example:
```mlir
- amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, <16x64xi8>
+ amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
Variadic<Index>:$indices,
- AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$val);
+ AnyAMXTile:$val);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -226,7 +232,7 @@ def TileStoreOp : AMX_Op<"tile_store"> {
}
}];
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
- "type($base) `,` type($val)";
+ "type($base) `,` qualified(type($val))";
let hasVerifier = 1;
}
@@ -246,13 +252,14 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
Example:
```mlir
- %0 = amx.tile_mulf %a, %b, %c : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
+ %0 = amx.tile_mulf %a, %b, %c
+ : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
```
}];
- let arguments = (ins AMX2DTileOf<[F16, BF16]>:$lhs,
- AMX2DTileOf<[F16, BF16]>:$rhs,
- AMX2DTileOf<[F32]>:$acc);
- let results = (outs AMX2DTileOf<[F32]>:$res);
+ let arguments = (ins AMXTileF16OrBF16:$lhs,
+ AMXTileF16OrBF16:$rhs,
+ AMXTileF32:$acc);
+ let results = (outs AMXTileF32:$res);
let extraClassDeclaration = [{
TileType getLhsTileType() {
return ::llvm::cast<TileType>(getLhs().getType());
@@ -265,7 +272,8 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
}
}];
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
- "type($lhs) `,` type($rhs) `,` type($acc) ";
+ "qualified(type($lhs)) `,` qualified(type($rhs))"
+ " `,` qualified(type($acc)) ";
let hasVerifier = 1;
}
@@ -284,16 +292,17 @@ def TileMulIOp : AMX_Op<"tile_muli", [
Example:
```mlir
- %0 = amx.tile_muli %a zext, %b zext, %c : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ %0 = amx.tile_muli %a zext, %b zext, %c
+ : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
```
}];
- let arguments = (ins AMX2DTileOf<[I8]>:$lhs,
- AMX2DTileOf<[I8]>:$rhs,
- AMX2DTileOf<[I32]>:$acc,
+ let arguments = (ins AMXTileI8:$lhs,
+ AMXTileI8:$rhs,
+ AMXTileI32:$acc,
UnitAttr:$isZextLhs,
UnitAttr:$isZextRhs
);
- let results = (outs AMX2DTileOf<[I32]>:$res);
+ let results = (outs AMXTileI32:$res);
let extraClassDeclaration = [{
TileType getLhsTileType() {
return ::llvm::cast<TileType>(getLhs().getType());
@@ -306,7 +315,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [
}
}];
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
- "type($lhs) `,` type($rhs) `,` type($acc) ";
+ "qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
let hasVerifier = 1;
}
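
With the mnemonic shortened from `amx.tile` to `tile` and the assembly formats switched to `qualified(type(...))`, the tile type is now spelled `!amx.tile<...>` consistently, which is what the updated tests below expect. For example:

```mlir
%0 = amx.tile_zero : !amx.tile<16x16xbf16>
%1 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xbf16>, !amx.tile<16x16xbf16>
```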
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 3172d93ef65bc99..14fb8934040de9e 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -158,10 +158,12 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs());
- else
+ else if (aType.getElementType().isF16())
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs());
+ else
+ llvm_unreachable("Unexpected element type for amx.mulf");
return success();
}
};
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index 46f3b7448148b36..8febe1605e33a4a 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -4,21 +4,21 @@
func.func @rowheight() {
// expected-error at +1 {{'amx.tile_zero' op bad row height: 17}}
- %0 = amx.tile_zero : <17x16xbf16>
+ %0 = amx.tile_zero : !amx.tile<17x16xbf16>
}
// -----
func.func @colwidth() {
// expected-error at +1 {{'amx.tile_zero' op bad column width: 65}}
- %0 = amx.tile_zero : <16x65xi8>
+ %0 = amx.tile_zero : !amx.tile<16x65xi8>
}
// -----
func.func @col4bytemultiple() {
// expected-error at +1 {{'amx.tile_zero' op bad column width: 5}}
- %0 = amx.tile_zero : <16x5xi8>
+ %0 = amx.tile_zero : !amx.tile<16x5xi8>
}
// -----
@@ -26,7 +26,7 @@ func.func @col4bytemultiple() {
func.func @memtilesize(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error at +1 {{'amx.tile_load' op bad column width: 68}}
- %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into <16x17xf32>
+ %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
}
// -----
@@ -34,15 +34,15 @@ func.func @memtilesize(%arg0: memref<?x?xf32>) {
func.func @memindexsize(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error at +1 {{'amx.tile_load' op requires 2 indices}}
- %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into <16x16xf32>
+ %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
}
// -----
func.func @multsize() {
- %0 = amx.tile_zero : <8x8xbf16>
- %1 = amx.tile_zero : <8x8xbf16>
- %2 = amx.tile_zero : <4x4xf32>
+ %0 = amx.tile_zero : !amx.tile<8x8xbf16>
+ %1 = amx.tile_zero : !amx.tile<8x8xbf16>
+ %2 = amx.tile_zero : !amx.tile<4x4xf32>
// expected-error at +1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
- %3 = amx.tile_mulf %0, %1, %2 : <8x8xbf16>, <8x8xbf16>, <4x4xf32>
+ %3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<8x8xbf16>, !amx.tile<4x4xf32>
}
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 82fb0a9c99f71de..cb827e6f8eca26a 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -14,17 +14,17 @@
// CHECK: amx.tilestored64
func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_zero : <16x64xi8>
- %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
- %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into <16x16xi32>
- %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
- amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
- %5 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
- amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, <16x16xi32>
- %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
- amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, <16x16xi32>
- %7 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
- amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, <16x16xi32>
+ %1 = amx.tile_zero : !amx.tile<16x64xi8>
+ %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+ %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
+ %4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+ amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
+ %5 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+ amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, !amx.tile<16x16xi32>
+ %6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+ amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, !amx.tile<16x16xi32>
+ %7 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+ amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, !amx.tile<16x16xi32>
return
}
@@ -36,11 +36,11 @@ func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
// CHECK: amx.tilestored64
func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_zero : <16x32xbf16>
- %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
- %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
- %4 = amx.tile_mulf %1, %2, %3 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
- amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
+ %1 = amx.tile_zero : !amx.tile<16x32xbf16>
+ %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+ %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
+ %4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+ amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
return
}
@@ -52,11 +52,11 @@ func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
// CHECK: amx.tilestored64
func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_zero : <16x32xf16>
- %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into <16x32xf16>
- %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
- %4 = amx.tile_mulf %1, %2, %3 : <16x32xf16>, <16x32xf16>, <16x16xf32>
- amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
+ %1 = amx.tile_zero : !amx.tile<16x32xf16>
+ %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into !amx.tile<16x32xf16>
+ %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
+ %4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32>
+ amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
return
}
diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir
index 62bc071cab1e029..1b7f781ae173d39 100644
--- a/mlir/test/Dialect/AMX/roundtrip.mlir
+++ b/mlir/test/Dialect/AMX/roundtrip.mlir
@@ -1,49 +1,49 @@
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
// CHECK-LABEL: tzero
-// CHECK: amx.tile_zero : <16x16xbf16>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, <16x16xbf16>
+// CHECK: amx.tile_zero : !amx.tile<16x16xbf16>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, !amx.tile<16x16xbf16>
func.func @tzero(%arg0: memref<?x?xbf16>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_zero : <16x16xbf16>
- amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, <16x16xbf16>
+ %1 = amx.tile_zero : !amx.tile<16x16xbf16>
+ amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, !amx.tile<16x16xbf16>
return
}
// CHECK-LABEL: tmulf
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into <16x32xbf16>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into <16x16xf32>
-// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, <16x16xf32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into !amx.tile<16x16xf32>
+// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, !amx.tile<16x16xf32>
func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
- %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
- %3 = amx.tile_mulf %1, %1, %2 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
- amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, <16x16xf32>
+ %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
+ %3 = amx.tile_mulf %1, %1, %2 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+ amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, !amx.tile<16x16xf32>
return
}
// CHECK-LABEL: tmuli
-// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
-// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
-// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into <16x16xi32>
-// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : <16x64xi8>, <16x64xi8>, <16x16xi32>
-// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, <16x16xi32>
+// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
+// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
+// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into !amx.tile<16x16xi32>
+// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, !amx.tile<16x16xi32>
// Verify the parsing/printing of the sign-extension annotation.
// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
- %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into <16x64xi8>
- %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into <16x16xi32>
- %4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
- amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
+ %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+ %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
+ %4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+ amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
// Verify the various `zext` combinations.
- %5 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
- %6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
- %7 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
+ %5 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+ %6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
+ %7 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
return
}
From b6802a396db25f12a88cb284c4322c7b42aa3744 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Fri, 1 Nov 2024 21:19:33 +0000
Subject: [PATCH 3/4] Adjust new test.
---
mlir/test/Dialect/AMX/legalize-for-llvm.mlir | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index cb827e6f8eca26a..8085f5f59fcaf01 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -79,11 +79,11 @@ func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
- %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into vector<16x32xbf16>
- %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into vector<16x32xbf16>
- amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, vector<16x32xbf16>
- amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, vector<16x32xbf16>
- amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, vector<16x32xbf16>
+ %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into !amx.tile<16x32xbf16>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
+ %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into !amx.tile<16x32xbf16>
+ amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, !amx.tile<16x32xbf16>
+ amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
+ amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, !amx.tile<16x32xbf16>
return
}
>From 37a98e6b0a9e79f705c24fd75083227604a35ac8 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Fri, 1 Nov 2024 21:31:05 +0000
Subject: [PATCH 4/4] Remove amx::TileType conversion from LLVMTypeConverter.
Signed-off-by: Ilya Enkovich <ilya.enkovich at intel.com>
---
.../Conversion/LLVMCommon/TypeConverter.h | 4 ---
mlir/include/mlir/Dialect/AMX/Transforms.h | 8 ++++--
mlir/include/mlir/InitAllExtensions.h | 2 ++
mlir/lib/Conversion/LLVMCommon/CMakeLists.txt | 1 -
.../Conversion/LLVMCommon/TypeConverter.cpp | 7 ------
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 25 ++++++++++++++++++-
6 files changed, 32 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 6e01f70ee351235..d79b90f840ce836 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -15,7 +15,6 @@
#define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
-#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -259,9 +258,6 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a 1D vector type into an LLVM vector type.
FailureOr<Type> convertVectorType(VectorType type) const;
- /// Convert an AMX tile type to the x86_amx type.
- Type convertAMXTileType(amx::TileType type) const;
-
/// Options for customizing the llvm lowering.
LowerToLLVMOptions options;
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index d00ac52e274f9f7..7391ec2ff6b14a4 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -14,16 +14,20 @@ namespace mlir {
class LLVMConversionTarget;
class LLVMTypeConverter;
class RewritePatternSet;
+class DialectRegistry;
/// Collect a set of patterns to lower AMX ops to ops that map to LLVM
/// intrinsics.
-void populateAMXLegalizeForLLVMExportPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
/// Configure the target to support lowering AMX ops to ops that map to LLVM
/// intrinsics.
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
+/// Register LLVM conversion interface for AMX dialect.
+void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
+
} // namespace mlir
#endif // MLIR_DIALECT_AMX_TRANSFORMS_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 2a241fa4b192fee..1f2ef26b450701a 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -24,6 +24,7 @@
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
@@ -70,6 +71,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertNVVMToLLVMInterface(registry);
registerConvertOpenMPToLLVMInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
+ registerConvertAMXToLLVMInterface(registry);
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 39199e4affccfac..568d9339aaabcb4 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -12,7 +12,6 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
Core
LINK_LIBS PUBLIC
- MLIRAMXDialect
MLIRIR
MLIRLLVMDialect
MLIRSupport
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 07e6dd55f855191..4e7758bf46d9cfc 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -67,7 +67,6 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
return std::nullopt;
return llvmType;
});
- addConversion([&](amx::TileType type) { return convertAMXTileType(type); });
// LLVM-compatible types are legal, so add a pass-through conversion. Do this
// before the conversions below since conversions are attempted in reverse
@@ -595,12 +594,6 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
return vectorType;
}
-/// Convert an AMX tile type to LLVM x86_amx type.
-/// Shape and element type of the tile are ignored.
-Type LLVMTypeConverter::convertAMXTileType(amx::TileType type) const {
- return LLVM::LLVMX86AMXType::get(&getContext());
-}
-
/// Convert a type in the context of the default or bare pointer calling
/// convention. Calling convention sensitive types, such as MemRefType and
/// UnrankedMemRefType, are converted following the specific rules for the
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 14fb8934040de9e..4eac371d4c1ae4f 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/AMX/Transforms.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
@@ -208,9 +209,12 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
} // namespace
void mlir::populateAMXLegalizeForLLVMExportPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
TileMulFConversion, TileMulIConversion>(converter);
+ converter.addConversion([&](amx::TileType type) {
+ return LLVM::LLVMX86AMXType::get(&converter.getContext());
+ });
}
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
@@ -220,3 +224,22 @@ void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
TileMulFOp>();
}
+
+namespace {
+/// Implement the interface to convert AMX to LLVM.
+struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+ using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+
+ void populateConvertToLLVMConversionPatterns(
+ ConversionTarget &target, LLVMTypeConverter &typeConverter,
+ RewritePatternSet &patterns) const final {
+ populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
+ }
+};
+} // namespace
+
+void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
+ dialect->addInterfaces<AMXToLLVMDialectInterface>();
+ });
+}
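Not part of the patch itself, but for readers wiring this up outside of registerAllExtensions(): below is a minimal sketch of how a downstream tool could pick up the new AMX-to-LLVM interface. Only registerConvertAMXToLLVMInterface() and the AMX dialect registration come from this patch; the helper name and the standalone-context setup are illustrative assumptions, not code from the PR.

  #include <memory>

  #include "mlir/Dialect/AMX/AMXDialect.h"
  #include "mlir/Dialect/AMX/Transforms.h"
  #include "mlir/IR/DialectRegistry.h"
  #include "mlir/IR/MLIRContext.h"

  // Hypothetical helper: build a context whose AMX dialect carries the
  // ConvertToLLVMPatternInterface added by this patch, so a generic
  // convert-to-llvm style pipeline can lower amx ops without an
  // AMX-specific pass being scheduled explicitly.
  static std::unique_ptr<mlir::MLIRContext> makeContextWithAMXToLLVM() {
    mlir::DialectRegistry registry;
    registry.insert<mlir::amx::AMXDialect>();
    mlir::registerConvertAMXToLLVMInterface(registry);
    return std::make_unique<mlir::MLIRContext>(registry);
  }

With a context built this way, amx ops should be lowered by the common ConvertToLLVM infrastructure through the interface registered above, which is the point of moving the amx::TileType conversion out of LLVMTypeConverter.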