[Mlir-commits] [mlir] [MLIR] Add MemRefElementTypeInterface to gpu.mma_matrix (PR #132312)
Uday Bondhugula
llvmlistbot at llvm.org
Thu Mar 20 23:02:27 PDT 2025
https://github.com/bondhugula updated https://github.com/llvm/llvm-project/pull/132312
>From f89c3326da94491f9be361ff079a4704242a57d6 Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Fri, 21 Mar 2025 04:18:45 +0530
Subject: [PATCH] [MLIR] Add MemRefElementTypeInterface to gpu.mma_matrix
Add MemRefElementTypeInterface to gpu.mma_matrix and introduce an
interface method that would allow analyses and cost models to work with
it. This enables creation of memrefs of mma_matrix type, which in turn
enables seamless fusion in the presence of affine load/stores on such mma memrefs
or forwarding of stores to loads out of the box.
---
mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h | 6 +++-
.../include/mlir/Dialect/Ptr/IR/PtrDialect.td | 4 +++
mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 16 +++++++++--
mlir/lib/Dialect/Affine/Analysis/Utils.cpp | 3 ++
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 7 +++++
mlir/test/Dialect/Affine/loop-fusion-4.mlir | 28 +++++++++++++++++++
mlir/test/lib/Dialect/Test/TestTypeDefs.td | 4 +++
7 files changed, 64 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 7b53594a1c8e2..aefa50947f758 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -128,7 +128,8 @@ struct MMAMatrixStorageType : public TypeStorage {
/// : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
// TODO: consider moving this to ODS.
class MMAMatrixType
- : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
+ : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType,
+ MemRefElementTypeInterface::Trait> {
public:
using Base::Base;
@@ -163,6 +164,9 @@ class MMAMatrixType
/// Get elementType of a single element.
Type getElementType() const;
+ /// Implementation for MemRefElementTypeInterface.
+ unsigned getAnalysisSizeInBytes() const;
+
/// The general form of operation this type supports is given by the equation
/// C += A*B. This function returns which operand in the given equation is
/// held by this type. String returned can be one of"AOp", "BOp" and "COp".
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index 857e68cec8c76..3f8bb0c6ea90a 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -62,6 +62,10 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
return $_get(memorySpace.getContext(), memorySpace);
}]>
];
+ let extraClassDeclaration = [{
+ /// Best effort size for analysis purposes.
+ unsigned getAnalysisSizeInBytes() { return 8; }
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..001d0d9f3e756 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -74,10 +74,20 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
For example, scalar values such as integers can implement this interface,
but indicator types such as `void` or `unit` should not.
- The interface currently has no methods and is used by types to opt into
- being memref elements. This may change in the future, in particular to
- require types to provide their size or alignment given a data layout.
+ The interface currently has one method and is mainly used by types to opt
+ into being memref elements. This may change in the future, in particular to
+ require types to provide actual size or alignment given a data layout.
}];
+
+ let methods = [
+ InterfaceMethod<[{
+ Returns the size of the element type in bytes for purposes such as
+ analysis. Such a size is meant to be used in analysis cost models as a
+ best effort in the absence of a data layout, as opposed to
+ target-specific lowering, which would require a data layout.
+ }],
+ "unsigned", "getAnalysisSizeInBytes">,
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 86aba7b187535..312eaedaa13c3 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -1341,6 +1341,9 @@ mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
else
return std::nullopt;
+ } else if (auto memrefEltType = dyn_cast<MemRefElementTypeInterface>(
+ memRefType.getElementType())) {
+ sizeInBits = memrefEltType.getAnalysisSizeInBytes() * 8;
} else {
return std::nullopt;
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 976432ea37120..04b8c901b50da 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -149,6 +149,13 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
elementType.isInteger(32);
}
+unsigned MMAMatrixType::getAnalysisSizeInBytes() const {
+ // Total bytes = number of matrix elements * bytes per element. The element
+ // type is expected to be int or float; sub-byte widths round up to a byte.
+ return ShapedType::getNumElements(getShape()) *
+ llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8);
+}
+
LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index 4b9eca45492fb..e948f8ad74bc9 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -666,3 +666,31 @@ func.func @unrolled(%arg0: memref<2x4xf32>, %arg1: memref<1x2x4xf32>) {
// PRODUCER-CONSUMER-MAXIMAL: affine.load %{{.*}}[0, %{{.*}}, %{{.*}}]
return
}
+
+// Test for fusion of affine load/store on memrefs of MMA type.
+
+// PRODUCER-CONSUMER-LABEL: func @gpu_mma_cast
+func.func @gpu_mma_cast(%a: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>, %b: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>, %c: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>) {
+ affine.for %i = 0 to 8 {
+ affine.for %j = 0 to 4 {
+ %v = affine.load %a[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+ affine.store %v, %b[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+ }
+ }
+
+ affine.for %i = 0 to 8 {
+ affine.for %j = 0 to 4 {
+ %v = affine.load %b[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+ affine.store %v, %c[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+ }
+ }
+ // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 8 {
+ // PRODUCER-CONSUMER-NEXT: affine.for %{{.*}} = 0 to 4 {
+ // PRODUCER-CONSUMER-NEXT: affine.load
+ // PRODUCER-CONSUMER-NEXT: affine.store
+ // PRODUCER-CONSUMER-NEXT: affine.load
+ // PRODUCER-CONSUMER-NEXT: affine.store
+
+ return
+ // PRODUCER-CONSUMER: return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index f1c31658c13ac..c3aac18917ba7 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -169,6 +169,10 @@ def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [
def TestMemRefElementType : Test_Type<"TestMemRefElementType",
[MemRefElementTypeInterface]> {
let mnemonic = "memref_element";
+
+ let extraClassDeclaration = [{
+ unsigned getAnalysisSizeInBytes() const { return 1; }
+ }];
}
def TestTypeTrait : NativeTypeTrait<"TestTypeTrait">;
More information about the Mlir-commits
mailing list