[Mlir-commits] [mlir] [MLIR] Add MemRefElementTypeInterface to gpu.mma_matrix (PR #132312)

Uday Bondhugula llvmlistbot at llvm.org
Thu Mar 20 23:02:27 PDT 2025


https://github.com/bondhugula updated https://github.com/llvm/llvm-project/pull/132312

>From f89c3326da94491f9be361ff079a4704242a57d6 Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Fri, 21 Mar 2025 04:18:45 +0530
Subject: [PATCH] [MLIR] Add MemRefElementTypeInterface to gpu.mma_matrix

Add MemRefElementTypeInterface to gpu.mma_matrix and introduce an
interface method that would allow analyses and cost models to work with
it. This enables creation of memrefs of mma_matrix type, which in turn
enables seamless fusion in the presence of affine loads/stores on such mma memrefs
or forwarding of stores to loads out of the box.
---
 mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h |  6 +++-
 .../include/mlir/Dialect/Ptr/IR/PtrDialect.td |  4 +++
 mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 16 +++++++++--
 mlir/lib/Dialect/Affine/Analysis/Utils.cpp    |  3 ++
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        |  7 +++++
 mlir/test/Dialect/Affine/loop-fusion-4.mlir   | 28 +++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestTypeDefs.td    |  4 +++
 7 files changed, 64 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 7b53594a1c8e2..aefa50947f758 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -128,7 +128,8 @@ struct MMAMatrixStorageType : public TypeStorage {
 ///           : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
 // TODO: consider moving this to ODS.
 class MMAMatrixType
-    : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
+    : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType,
+                            MemRefElementTypeInterface::Trait> {
 public:
   using Base::Base;
 
@@ -163,6 +164,9 @@ class MMAMatrixType
   /// Get elementType of a single element.
   Type getElementType() const;
 
+  /// Implementation for MemRefElementTypeInterface.
+  unsigned getAnalysisSizeInBytes() const;
+
   /// The general form of operation this type supports is given by the equation
   /// C += A*B. This function returns which operand in the given equation is
   /// held by this type. String returned can be one of"AOp", "BOp" and "COp".
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index 857e68cec8c76..3f8bb0c6ea90a 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -62,6 +62,10 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
       return $_get(memorySpace.getContext(), memorySpace);
     }]>
   ];
+  let extraClassDeclaration = [{
+    /// Best effort size for analysis purposes.
+    unsigned getAnalysisSizeInBytes() { return 8; }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..001d0d9f3e756 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -74,10 +74,20 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
     For example, scalar values such as integers can implement this interface,
     but indicator types such as `void` or `unit` should not.
 
-    The interface currently has no methods and is used by types to opt into
-    being memref elements. This may change in the future, in particular to
-    require types to provide their size or alignment given a data layout.
+    The interface currently has one method and is mainly used by types to opt
+    into being memref elements. This may change in the future, in particular to
+    require types to provide actual size or alignment given a data layout.
   }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Returns the size of the element type in bytes for purposes such as
+      analysis. Such a size is meant to be used in analysis cost models as a
+      best-effort estimate in the absence of a data layout, rather than for
+      target-specific lowering, which would require a data layout.
+    }],
+    "unsigned", "getAnalysisSizeInBytes">,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 86aba7b187535..312eaedaa13c3 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -1341,6 +1341,9 @@ mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
           vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
     else
       return std::nullopt;
+  } else if (auto memrefEltType = dyn_cast<MemRefElementTypeInterface>(
+                 memRefType.getElementType())) {
+    sizeInBits = memrefEltType.getAnalysisSizeInBytes() * 8;
   } else {
     return std::nullopt;
   }
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 976432ea37120..04b8c901b50da 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -149,6 +149,13 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
          elementType.isInteger(32);
 }
 
+unsigned MMAMatrixType::getAnalysisSizeInBytes() const {
+  // Element count times per-element byte size. The element type is expected
+  // to be int or float; each element's bit width is rounded up to whole bytes.
+  return ShapedType::getNumElements(getShape()) *
+         llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8);
+}
+
 LogicalResult
 MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index 4b9eca45492fb..e948f8ad74bc9 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -666,3 +666,31 @@ func.func @unrolled(%arg0: memref<2x4xf32>, %arg1: memref<1x2x4xf32>) {
   // PRODUCER-CONSUMER-MAXIMAL:          affine.load %{{.*}}[0, %{{.*}}, %{{.*}}]
   return
 }
+
+// Fusion of affine loads/stores on memrefs with !gpu.mma_matrix elements.
+
+// PRODUCER-CONSUMER-LABEL: func @gpu_mma_cast
+func.func @gpu_mma_cast(%a: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>, %b: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>, %c: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>) {
+  affine.for %i = 0 to 8 {
+    affine.for %j = 0 to 4 {
+      %v = affine.load %a[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+      affine.store %v, %b[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+    }
+  }
+
+  affine.for %i = 0 to 8 {
+    affine.for %j = 0 to 4 {
+      %v = affine.load %b[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+      affine.store %v, %c[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+    }
+  }
+  // PRODUCER-CONSUMER:      affine.for %{{.*}} = 0 to 8 {
+  // PRODUCER-CONSUMER-NEXT:   affine.for %{{.*}} = 0 to 4 {
+  // PRODUCER-CONSUMER-NEXT:     affine.load
+  // PRODUCER-CONSUMER-NEXT:     affine.store
+  // PRODUCER-CONSUMER-NEXT:     affine.load
+  // PRODUCER-CONSUMER-NEXT:     affine.store
+
+  return
+  // PRODUCER-CONSUMER: return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index f1c31658c13ac..c3aac18917ba7 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -169,6 +169,10 @@ def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [
 def TestMemRefElementType : Test_Type<"TestMemRefElementType",
                                       [MemRefElementTypeInterface]> {
   let mnemonic = "memref_element";
+  // Fixed 1-byte size reported through MemRefElementTypeInterface.
+  let extraClassDeclaration = [{
+    unsigned getAnalysisSizeInBytes() const { return 1; }
+  }];
 }
 
 def TestTypeTrait : NativeTypeTrait<"TestTypeTrait">;



More information about the Mlir-commits mailing list