[Mlir-commits] [mlir] [mlir][MemRef][GPU] Migrate GPU dialect ops to IndexedAccessOpInterface (PR #190380)

Krzysztof Drewniak llvmlistbot at llvm.org
Fri Apr 3 11:10:29 PDT 2026


https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/190380

This commit migrates the handling of GPU dialect ops in fold-memref-alias-ops from hard-coded support to the new IndexedAccessOpInterface, which also adds expand_shape folding support for those ops.

Once other memref-dialect passes are migrated to use this interface, this will allow us to break the dependency between the memref and gpu dialects.

>From 2f187ca656c01584f5d072a565741e81d3df7cfc Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Fri, 3 Apr 2026 18:07:43 +0000
Subject: [PATCH] [mlir][MemRef][GPU] Migrate GPU dialect ops to
 IndexedAccessOpInterface

This commit migrates the handling of GPU dialect ops in
fold-memref-alias-ops from hard-coded support to the new
IndexedAccessOpInterface, which also adds expand_shape folding
support for those ops.

Once other memref-dialect passes are migrated to use this interface,
this will allow us to break the dependency between the memref and gpu
dialects.
---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    |   8 ++
 .../Transforms/IndexedAccessOpInterfaceImpl.h |  21 ++++
 mlir/lib/Dialect/GPU/CMakeLists.txt           |   1 +
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        |   3 +
 .../IndexedAccessOpInterfaceImpl.cpp          | 115 ++++++++++++++++++
 .../MemRef/Transforms/FoldMemRefAliasOps.cpp  |  20 ---
 mlir/lib/RegisterAllDialects.cpp              |   2 +
 .../Dialect/GPU/fold-memref-alias-ops.mlir    |  79 ++++++++++++
 8 files changed, 229 insertions(+), 20 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/GPU/Transforms/IndexedAccessOpInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/GPU/Transforms/IndexedAccessOpInterfaceImpl.cpp
 create mode 100644 mlir/test/Dialect/GPU/fold-memref-alias-ops.mlir

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index f0a4dd44c8f67..c055f73e9c99f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1828,6 +1828,7 @@ def GPU_SetDefaultDeviceOp : GPU_Op<"set_default_device",
   let assemblyFormat = "attr-dict $devIndex";
 }
 
+// Promises IndexedAccessOpInterface.
 def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
     [MemoryEffects<[MemRead]>]>{
 
@@ -1845,6 +1846,9 @@ def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
     matrix which eventually allows the lowering to determine the size of each
     row.  If the `transpose` attribute is present then the op does a transposed load.
 
+    The memory indices along each dimension must be in-bounds for that dimension
+    as with an ordinary `memref.load`.
+
     For integer types, the resulting `!gpu.mma_matrix` type needs to specify the
     signedness of the data if the matrix type is an `A` or `B` operand for
     `gpu.subgroup_mma_compute`.
@@ -1874,6 +1878,7 @@ def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
   let hasVerifier = 1;
 }
 
+// Promises IndexedAccessOpInterface.
 def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
     [MemoryEffects<[MemWrite]>]>{
 
@@ -1893,6 +1898,9 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
     This op is often meant to be used along with `gpu.subgroup_mma_load_matrix` and
     `gpu.subgroup_mma_compute`.
 
+    The memory indices along each dimension must be in-bounds for that dimension
+    as with an ordinary `memref.store`.
+
     Example:
 
     ```mlir
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/IndexedAccessOpInterfaceImpl.h b/mlir/include/mlir/Dialect/GPU/Transforms/IndexedAccessOpInterfaceImpl.h
new file mode 100644
index 0000000000000..d8a56545fd115
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/IndexedAccessOpInterfaceImpl.h
@@ -0,0 +1,21 @@
+//===- IndexedAccessOpInterfaceImpl.h - -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GPU_TRANSFORMS_INDEXEDACCESSOPINTERFACEIMPL_H
+#define MLIR_DIALECT_GPU_TRANSFORMS_INDEXEDACCESSOPINTERFACEIMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace gpu {
+void registerIndexedAccessOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace gpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_GPU_TRANSFORMS_INDEXEDACCESSOPINTERFACEIMPL_H
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index f2f010a771b77..547812da0ab97 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -34,6 +34,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/EliminateBarriers.cpp
   Transforms/GlobalIdRewriter.cpp
   Transforms/KernelOutlining.cpp
+  Transforms/IndexedAccessOpInterfaceImpl.cpp
   Transforms/MemoryPromotion.cpp
   Transforms/ModuleToBinary.cpp
   Transforms/NVVMAttachTarget.cpp
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 8039f3952eea6..f0fbd093d3446 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -301,6 +301,9 @@ void GPUDialect::initialize() {
                             BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp,
                             LaneIdOp, SubgroupIdOp, GlobalIdOp, NumSubgroupsOp,
                             SubgroupSizeOp, LaunchOp, SubgroupBroadcastOp>();
+  declarePromisedInterfaces<memref::IndexedAccessOpInterface,
+                            SubgroupMmaLoadMatrixOp,
+                            SubgroupMmaStoreMatrixOp>();
 }
 
 static std::string getSparseHandleKeyword(SparseHandleKind kind) {
diff --git a/mlir/lib/Dialect/GPU/Transforms/IndexedAccessOpInterfaceImpl.cpp b/mlir/lib/Dialect/GPU/Transforms/IndexedAccessOpInterfaceImpl.cpp
new file mode 100644
index 0000000000000..2423dd6b4d037
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/IndexedAccessOpInterfaceImpl.cpp
@@ -0,0 +1,115 @@
+//===- IndexedAccessOpInterfaceImpl.cpp -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Transforms/IndexedAccessOpInterfaceImpl.h"
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+using namespace mlir::gpu;
+
+/// Given a GPU matrix type that will be loaded or stored, the leading dimension
+/// of the matrix in memory, and whether or not the matrix is transposed,
+/// compute the size of the linear memory that the load/store spans as
+/// dC + leadingDim * (dR - 1) where dR and dC are the non-contiguous and
+/// contiguous matrix dimensions, respectively (we step to the start of row
+/// dR - 1 and then access the first dC elements of it).
+static int64_t get1DAccessSize(MMAMatrixType matrixType, int64_t leadingDim,
+                               bool transpose) {
+  assert(matrixType.getShape().size() == 2 && "expected matrices to be 2D");
+
+  int64_t c = matrixType.getShape()[1];
+  int64_t r = matrixType.getShape()[0];
+  if (transpose)
+    std::swap(c, r); // Transposed: shape[0] is the contiguous extent.
+  return c + leadingDim * (r - 1);
+}
+
+namespace {
+struct SubgroupMmaLoadMatrixOpImpl final
+    : IndexedAccessOpInterface::ExternalModel<SubgroupMmaLoadMatrixOpImpl,
+                                              SubgroupMmaLoadMatrixOp> {
+  TypedValue<MemRefType> getAccessedMemref(Operation *op) const {
+    return cast<SubgroupMmaLoadMatrixOp>(op).getSrcMemref();
+  }
+
+  Operation::operand_range getIndices(Operation *op) const {
+    return cast<SubgroupMmaLoadMatrixOp>(op).getIndices();
+  }
+
+  /// Reports the access as a single 1-D extent (see get1DAccessSize) so that
+  /// linearization and expand/collapse_shape folding are both clearly allowed.
+  SmallVector<int64_t> getAccessedShape(Operation *op) const {
+    auto loadOp = cast<SubgroupMmaLoadMatrixOp>(op);
+    return {get1DAccessSize(cast<MMAMatrixType>(loadOp.getRes().getType()),
+                            loadOp.getLeadDimension().getZExtValue(),
+                            loadOp.getTranspose().value_or(false))};
+  }
+
+  std::optional<SmallVector<Value>>
+  updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
+                         ValueRange newIndices) const {
+    auto loadOp = cast<SubgroupMmaLoadMatrixOp>(op);
+    rewriter.modifyOpInPlace(loadOp, [&]() {
+      loadOp.getSrcMemrefMutable().assign(newMemref);
+      loadOp.getIndicesMutable().assign(newIndices);
+    });
+    return std::nullopt; // The op was rewritten in place above.
+  }
+
+  bool hasInboundsIndices(Operation *) const { return true; }
+};
+
+struct SubgroupMmaStoreMatrixOpImpl final
+    : IndexedAccessOpInterface::ExternalModel<SubgroupMmaStoreMatrixOpImpl,
+                                              SubgroupMmaStoreMatrixOp> {
+  TypedValue<MemRefType> getAccessedMemref(Operation *op) const {
+    return cast<SubgroupMmaStoreMatrixOp>(op).getDstMemref();
+  }
+
+  Operation::operand_range getIndices(Operation *op) const {
+    return cast<SubgroupMmaStoreMatrixOp>(op).getIndices();
+  }
+
+  /// Reports the access as a single 1-D extent (see get1DAccessSize) so that
+  /// linearization and expand/collapse_shape folding are both clearly allowed.
+  SmallVector<int64_t> getAccessedShape(Operation *op) const {
+    auto storeOp = cast<SubgroupMmaStoreMatrixOp>(op);
+    return {get1DAccessSize(storeOp.getSrc().getType(),
+                            storeOp.getLeadDimension().getZExtValue(),
+                            storeOp.getTranspose().value_or(false))};
+  }
+
+  std::optional<SmallVector<Value>>
+  updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
+                         ValueRange newIndices) const {
+    auto storeOp = cast<SubgroupMmaStoreMatrixOp>(op);
+    rewriter.modifyOpInPlace(storeOp, [&]() {
+      storeOp.getDstMemrefMutable().assign(newMemref);
+      storeOp.getIndicesMutable().assign(newIndices);
+    });
+    return std::nullopt; // The op was rewritten in place above.
+  }
+
+  bool hasInboundsIndices(Operation *) const { return true; }
+};
+} // namespace
+
+void mlir::gpu::registerIndexedAccessOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
+    SubgroupMmaLoadMatrixOp::attachInterface<SubgroupMmaLoadMatrixOpImpl>(*ctx);
+    SubgroupMmaStoreMatrixOp::attachInterface<SubgroupMmaStoreMatrixOpImpl>(
+        *ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 6f2752932422a..023b7a3d86b2f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -92,14 +92,6 @@ static Value getMemRefOperand(vector::TransferWriteOp op) {
   return op.getBase();
 }
 
-static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
-  return op.getSrcMemref();
-}
-
-static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
-  return op.getDstMemref();
-}
-
 //===----------------------------------------------------------------------===//
 // Patterns
 //===----------------------------------------------------------------------===//
@@ -372,11 +364,6 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
                 subViewOp.getDroppedDims())),
             op.getPadding(), op.getMask(), op.getInBoundsAttr());
       })
-      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
-        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
-            op, op.getType(), subViewOp.getSource(), sourceIndices,
-            op.getLeadDimension(), op.getTransposeAttr());
-      })
       .Case([&](nvgpu::LdMatrixOp op) {
         rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
             op, op.getType(), subViewOp.getSource(), sourceIndices,
@@ -543,11 +530,6 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
             op, subViewOp.getSource(), sourceIndices, op.getMask(),
             op.getValueToStore());
       })
-      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
-        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
-            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
-            op.getLeadDimension(), op.getTransposeAttr());
-      })
       .DefaultUnreachable("unexpected operation");
   return success();
 }
@@ -885,11 +867,9 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
       LoadOpOfSubViewOpFolder<vector::LoadOp>,
       LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
       LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
-      LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
       StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
       StoreOpOfSubViewOpFolder<vector::StoreOp>,
       StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
-      StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
       LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
       LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
       LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp
index ea5698f39c0b0..7a79e3408f1b8 100644
--- a/mlir/lib/RegisterAllDialects.cpp
+++ b/mlir/lib/RegisterAllDialects.cpp
@@ -38,6 +38,7 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/GPU/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/GPU/Transforms/IndexedAccessOpInterfaceImpl.h"
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -168,6 +169,7 @@ void mlir::registerAllDialects(DialectRegistry &registry) {
   cf::registerBufferizableOpInterfaceExternalModels(registry);
   cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
   gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
+  gpu::registerIndexedAccessOpInterfaceExternalModels(registry);
   gpu::registerValueBoundsOpInterfaceExternalModels(registry);
   LLVM::registerInlinerInterface(registry);
   NVVM::registerInlinerInterface(registry);
diff --git a/mlir/test/Dialect/GPU/fold-memref-alias-ops.mlir b/mlir/test/Dialect/GPU/fold-memref-alias-ops.mlir
new file mode 100644
index 0000000000000..f5ff5e29681f4
--- /dev/null
+++ b/mlir/test/Dialect/GPU/fold-memref-alias-ops.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-opt -fold-memref-alias-ops -split-input-file %s | FileCheck %s
+
+func.func @fold_gpu_subgroup_mma_load_matrix_1d(%src: memref<?xvector<4xf32>>, %offset: index, %i: index) -> !gpu.mma_matrix<16x16xf16, "COp"> {
+  %subview = memref.subview %src[%offset] [81920] [1] : memref<?xvector<4xf32>> to memref<81920xvector<4xf32>, strided<[1], offset: ?>>
+  %matrix = gpu.subgroup_mma_load_matrix %subview[%i] {leadDimension = 160 : index} : memref<81920xvector<4xf32>, strided<[1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp">
+  return %matrix: !gpu.mma_matrix<16x16xf16, "COp">
+}
+
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//      CHECK: func.func @fold_gpu_subgroup_mma_load_matrix_1d
+// CHECK-SAME: (%[[SRC:.+]]: memref<?xvector<4xf32>>, %[[OFFSET:.+]]: index, %[[I:.+]]: index)
+//      CHECK:   %[[APPLY:.+]] = affine.apply #[[MAP]]()[%[[OFFSET]], %[[I]]]
+//      CHECK:   %[[LOAD:.+]] = gpu.subgroup_mma_load_matrix %[[SRC]][%[[APPLY]]] {leadDimension = 160 : index} : memref<?xvector<4xf32>> -> !gpu.mma_matrix<16x16xf16, "COp">
+//      CHECK:   return %[[LOAD]]
+
+// -----
+
+func.func @fold_gpu_subgroup_mma_store_matrix_1d(%dst: memref<?xvector<4xf32>>, %offset: index, %i: index, %matrix: !gpu.mma_matrix<16x16xf16, "COp">) {
+  %subview = memref.subview %dst[%offset] [81920] [1] : memref<?xvector<4xf32>> to memref<81920xvector<4xf32>, strided<[1], offset: ?>>
+  gpu.subgroup_mma_store_matrix %matrix, %subview[%i] {leadDimension = 160 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<81920xvector<4xf32>, strided<[1], offset: ?>>
+  return
+}
+
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//      CHECK: func.func @fold_gpu_subgroup_mma_store_matrix_1d
+// CHECK-SAME: (%[[DST:.+]]: memref<?xvector<4xf32>>, %[[OFFSET:.+]]: index, %[[I0:.+]]: index, %[[VAL:.+]]: !gpu.mma_matrix<16x16xf16, "COp">)
+//      CHECK:   %[[APPLY:.+]] = affine.apply #[[MAP]]()[%[[OFFSET]], %[[I0]]]
+//      CHECK:   gpu.subgroup_mma_store_matrix %[[VAL]], %[[DST]][%[[APPLY]]] {leadDimension = 160 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<?xvector<4xf32>>
+
+// -----
+
+// CHECK-LABEL: func.func @fold_gpu_subgroup_mma_load_matrix_2d
+//  CHECK-SAME: %[[SRC:.+]]: memref<128x128xf32>
+func.func @fold_gpu_subgroup_mma_load_matrix_2d(%arg0 : memref<128x128xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> !gpu.mma_matrix<16x16xf16, "COp"> {
+  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[256, 1], offset: ?>>
+  // CHECK: gpu.subgroup_mma_load_matrix %[[SRC]][{{.+}}] {leadDimension = 32 : index} : memref<128x128xf32> -> !gpu.mma_matrix<16x16xf16, "COp">
+  %matrix = gpu.subgroup_mma_load_matrix %subview[%arg3, %arg4] {leadDimension = 32 : index} : memref<64x32xf32, strided<[256, 1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp">
+  return %matrix : !gpu.mma_matrix<16x16xf16, "COp">
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_gpu_subgroup_mma_store_matrix_2d
+//  CHECK-SAME: %[[DST:.+]]: memref<128x128xf32>
+func.func @fold_gpu_subgroup_mma_store_matrix_2d(%arg0 : memref<128x128xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %matrix: !gpu.mma_matrix<16x16xf16, "COp">) {
+  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[256, 1], offset: ?>>
+  // CHECK: gpu.subgroup_mma_store_matrix %{{.+}}, %[[DST]][{{.+}}] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf32>
+  gpu.subgroup_mma_store_matrix %matrix, %subview[%arg3, %arg4] {leadDimension = 32 : index} :  !gpu.mma_matrix<16x16xf16, "COp">, memref<64x32xf32, strided<[256, 1], offset: ?>>
+  return
+}
+
+// -----
+
+func.func @fold_gpu_subgroup_mma_load_matrix_expand_shape(%src: memref<4096xf32>, %i: index, %j: index) -> !gpu.mma_matrix<16x16xf16, "COp"> {
+  %expand = memref.expand_shape %src [[0, 1]] output_shape [64, 64] : memref<4096xf32> into memref<64x64xf32>
+  %matrix = gpu.subgroup_mma_load_matrix %expand[%i, %j] {leadDimension = 64 : index} : memref<64x64xf32> -> !gpu.mma_matrix<16x16xf16, "COp">
+  return %matrix: !gpu.mma_matrix<16x16xf16, "COp">
+}
+
+//      CHECK: func.func @fold_gpu_subgroup_mma_load_matrix_expand_shape
+// CHECK-SAME: (%[[SRC:.+]]: memref<4096xf32>, %[[I:.+]]: index, %[[J:.+]]: index)
+//      CHECK:   %[[LIN:.+]] = affine.linearize_index disjoint [%[[I]], %[[J]]] by (64, 64)
+//      CHECK:   %[[LOAD:.+]] = gpu.subgroup_mma_load_matrix %[[SRC]][%[[LIN]]] {leadDimension = 64 : index}
+//      CHECK:   return %[[LOAD]]
+
+// -----
+
+func.func @fold_gpu_subgroup_mma_store_matrix_expand_shape(%dst: memref<4096xf32>, %i: index, %j: index, %matrix: !gpu.mma_matrix<16x16xf16, "COp">) {
+  %expand = memref.expand_shape %dst [[0, 1]] output_shape [64, 64] : memref<4096xf32> into memref<64x64xf32>
+  gpu.subgroup_mma_store_matrix %matrix, %expand[%i, %j] {leadDimension = 64 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<64x64xf32>
+  return
+}
+
+//      CHECK: func.func @fold_gpu_subgroup_mma_store_matrix_expand_shape
+// CHECK-SAME: (%[[DST:.+]]: memref<4096xf32>, %[[I:.+]]: index, %[[J:.+]]: index, %[[MATRIX:.+]]: !gpu.mma_matrix<16x16xf16, "COp">)
+//      CHECK:   %[[LIN:.+]] = affine.linearize_index disjoint [%[[I]], %[[J]]] by (64, 64)
+//      CHECK:   gpu.subgroup_mma_store_matrix %[[MATRIX]], %[[DST]][%[[LIN]]] {leadDimension = 64 : index}
+//      CHECK:   return
+



More information about the Mlir-commits mailing list