[Mlir-commits] [mlir] [mlir][vector-to-gpu]: Extend MMA Lowerings (PR #176785)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 19 09:51:51 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Jack Frankland (FranklandJack)

<details>
<summary>Changes</summary>

Add support for lowering `vector.transfer_read` and `vector.transfer_write` ops with non-minor-identity permutation maps to `gpu.subgroup_mma_load_matrix` and `gpu.subgroup_mma_store_matrix`. If the permutation map is a "strided minor identity", i.e. it skips some intermediate dimensions, e.g. (d0, d1, d2) -> (d0, d2), the resulting stride can be expressed through the `leadDimension` attribute of the load/store op, which then steps over the skipped intermediate dimensions when loading or storing.

---
Full diff: https://github.com/llvm/llvm-project/pull/176785.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+37-13) 
- (modified) mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir (+20) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 98434357f826f..28448ad7106d9 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -116,9 +116,26 @@ static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
          permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
 }
 
-// Return the stide for the second-to-last dimension of |type| if it is a memref
+// Return true if the given map represents a minor identity map with "strided"
+// dimensions i.e. (d0, d1, ..., dn-1) -> (da, dn-1) where 0 <= a < n-1.
+//
+// Permuted, broadcast and repeated dimensions are not supported.
+static bool isStridedMinorIdentity(AffineMap permutationMap) {
+  if (permutationMap.getNumResults() != 2)
+    return false;
+  const unsigned nDim = permutationMap.getNumDims();
+  if (permutationMap.getResult(1) !=
+      getAffineDimExpr(nDim - 1, permutationMap.getContext()))
+    return false;
+  // The outer result must be a dim strictly more major than the inner one.
+  auto outerResult = dyn_cast<AffineDimExpr>(permutationMap.getResult(0));
+  return outerResult && outerResult.getPosition() < nDim - 1;
+}
+
+// Return the stide for the "row" dimension of |type| if it is a memref
 // and has a constant stride.
-static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
+static std::optional<int64_t>
+getStaticallyKnownRowStride(ShapedType type, AffineMap permutationMap) {
   auto memrefType = dyn_cast<MemRefType>(type);
   if (!memrefType)
     return false;
@@ -130,7 +147,13 @@ static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
   if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
       strides.back() != 1)
     return std::nullopt;
-  int64_t stride = strides[strides.size() - 2];
+
+  // Only strided minor identities read rows along their outer result dim.
+  size_t stridePosition = strides.size() - 2;
+  if (isStridedMinorIdentity(permutationMap))
+    stridePosition =
+        cast<AffineDimExpr>(permutationMap.getResult(0)).getPosition();
+  int64_t stride = strides[stridePosition];
   if (stride == ShapedType::kDynamic)
     return std::nullopt;
   return stride;
@@ -141,7 +164,9 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
   if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
       readOp.getVectorType().getRank() != 2)
     return false;
-  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
+
+  AffineMap map = readOp.getPermutationMap();
+  if (!getStaticallyKnownRowStride(readOp.getShapedType(), map))
     return false;
 
   // Only allow integer types if the signedness can be inferred.
@@ -150,13 +175,12 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
                                  !isa<arith::ExtUIOp>(*readOp->user_begin())))
       return false;
 
-  AffineMap map = readOp.getPermutationMap();
   MLIRContext *ctx = readOp.getContext();
   AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
   AffineExpr zero = getAffineConstantExpr(0, ctx);
   auto broadcastInnerDim =
       AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
-  return map.isMinorIdentity() || map == broadcastInnerDim ||
+  return isStridedMinorIdentity(map) || map == broadcastInnerDim ||
          isTransposeMatrixLoadMap(map);
 }
 
@@ -170,12 +194,12 @@ transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
   if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
       writeOp.getVectorType().getRank() != 2)
     return false;
-  if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
+
+  AffineMap map = writeOp.getPermutationMap();
+  if (!getStaticallyKnownRowStride(writeOp.getShapedType(), map))
     return false;
   // TODO: Support transpose once it is added to GPU dialect ops.
-  if (!writeOp.getPermutationMap().isMinorIdentity())
-    return false;
-  return true;
+  return isStridedMinorIdentity(map);
 }
 
 /// Return true if the constant is a splat to a 2D vector so that it can be
@@ -547,14 +571,14 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
   assert(transferReadSupportsMMAMatrixType(op) &&
          "expected convertible operation");
 
+  AffineMap map = op.getPermutationMap();
   std::optional<int64_t> stride =
-      getStaticallyKnownRowStride(op.getShapedType());
+      getStaticallyKnownRowStride(op.getShapedType(), map);
   if (!stride.has_value()) {
     LDBG() << "no stride";
     return rewriter.notifyMatchFailure(op, "no stride");
   }
 
-  AffineMap map = op.getPermutationMap();
   bool isTranspose = isTransposeMatrixLoadMap(map);
 
   // Handle broadcast by setting the stride to 0.
@@ -597,7 +621,7 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
 
   assert(transferWriteSupportsMMAMatrixType(op));
   std::optional<int64_t> stride =
-      getStaticallyKnownRowStride(op.getShapedType());
+      getStaticallyKnownRowStride(op.getShapedType(), op.getPermutationMap());
   if (!stride.has_value()) {
     LDBG() << "no stride";
     return rewriter.notifyMatchFailure(op, "no stride");
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index ef72901750479..71a748d33831e 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -555,3 +555,23 @@ func.func @addf(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memre
   vector.transfer_write %C, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @matmul_with_strides
+//   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+//   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 96 : index} : memref<16x6x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+//   CHECK-DAG:   %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 144 : index} : memref<16x9x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 144 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x9x16xf16>
+func.func @matmul_with_strides(%arg0: memref<16x16xf16>, %arg1: memref<16x6x16xf16>, %arg2: memref<16x9x16xf16>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>, in_bounds = [true, true]} : memref<16x6x16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>} : memref<16x9x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>} : vector<16x16xf16>, memref<16x9x16xf16>
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/176785


More information about the Mlir-commits mailing list